testgenie-py 0.3.7__py3-none-any.whl → 0.3.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- testgen/analyzer/ast_analyzer.py +2 -11
- testgen/analyzer/fuzz_analyzer.py +1 -6
- testgen/analyzer/random_feedback_analyzer.py +20 -293
- testgen/analyzer/reinforcement_analyzer.py +59 -57
- testgen/analyzer/test_case_analyzer_context.py +0 -6
- testgen/controller/cli_controller.py +35 -29
- testgen/controller/docker_controller.py +1 -0
- testgen/db/dao.py +68 -0
- testgen/db/dao_impl.py +226 -0
- testgen/{sqlite → db}/db.py +15 -6
- testgen/generator/pytest_generator.py +2 -10
- testgen/generator/unit_test_generator.py +2 -11
- testgen/main.py +1 -3
- testgen/models/coverage_data.py +56 -0
- testgen/models/db_test_case.py +65 -0
- testgen/models/function.py +56 -0
- testgen/models/function_metadata.py +11 -1
- testgen/models/generator_context.py +30 -3
- testgen/models/source_file.py +29 -0
- testgen/models/test_result.py +38 -0
- testgen/models/test_suite.py +20 -0
- testgen/reinforcement/agent.py +1 -27
- testgen/reinforcement/environment.py +11 -93
- testgen/reinforcement/statement_coverage_state.py +5 -4
- testgen/service/analysis_service.py +31 -22
- testgen/service/cfg_service.py +3 -1
- testgen/service/coverage_service.py +115 -0
- testgen/service/db_service.py +140 -0
- testgen/service/generator_service.py +77 -20
- testgen/service/logging_service.py +2 -2
- testgen/service/service.py +62 -231
- testgen/service/test_executor_service.py +145 -0
- testgen/util/coverage_utils.py +38 -116
- testgen/util/coverage_visualizer.py +10 -9
- testgen/util/file_utils.py +10 -111
- testgen/util/randomizer.py +0 -26
- testgen/util/utils.py +197 -38
- {testgenie_py-0.3.7.dist-info → testgenie_py-0.3.9.dist-info}/METADATA +1 -1
- testgenie_py-0.3.9.dist-info/RECORD +72 -0
- testgen/inspector/inspector.py +0 -59
- testgen/presentation/__init__.py +0 -0
- testgen/presentation/cli_view.py +0 -12
- testgen/sqlite/__init__.py +0 -0
- testgen/sqlite/db_service.py +0 -239
- testgen/testgen.db +0 -0
- testgenie_py-0.3.7.dist-info/RECORD +0 -67
- /testgen/{inspector → db}/__init__.py +0 -0
- {testgenie_py-0.3.7.dist-info → testgenie_py-0.3.9.dist-info}/WHEEL +0 -0
- {testgenie_py-0.3.7.dist-info → testgenie_py-0.3.9.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,65 @@
|
|
1
|
+
class DBTestCase:
    """Model of a test case record together with the ids that link it to other records.

    Every field is exposed as a plain read/write property whose value lives on a
    private ``_<field>`` attribute. The ``*_id`` fields presumably reference rows
    managed by the DB layer (TODO confirm against the DAO).
    """

    def __init__(self, expected_output, inputs, test_function: str, last_run_time, test_method_type: int,
                 test_suite_id: int, function_id: int):
        # Assign through the property setters defined below; each setter stores
        # the value on the matching private attribute.
        self.expected_output = expected_output
        self.inputs = inputs
        self.test_function = test_function
        self.last_run_time = last_run_time
        self.test_method_type = test_method_type
        self.test_suite_id = test_suite_id
        self.function_id = function_id

    @property
    def expected_output(self):
        """Expected output recorded for this test case."""
        return self._expected_output

    @expected_output.setter
    def expected_output(self, new_value):
        self._expected_output = new_value

    @property
    def inputs(self):
        """Inputs recorded for this test case (stored as given, not copied)."""
        return self._inputs

    @inputs.setter
    def inputs(self, new_value):
        self._inputs = new_value

    @property
    def test_function(self) -> str:
        """Name/text of the test function associated with this case."""
        return self._test_function

    @test_function.setter
    def test_function(self, new_value: str) -> None:
        self._test_function = new_value

    @property
    def last_run_time(self):
        """Last run time recorded for this case; its type is not constrained here."""
        return self._last_run_time

    @last_run_time.setter
    def last_run_time(self, new_value) -> None:
        self._last_run_time = new_value

    @property
    def test_method_type(self) -> int:
        """Integer code for the test method type (encoding defined elsewhere)."""
        return self._test_method_type

    @test_method_type.setter
    def test_method_type(self, new_value: int) -> None:
        self._test_method_type = new_value

    @property
    def test_suite_id(self) -> int:
        """Id of the test suite this case belongs to."""
        return self._test_suite_id

    @test_suite_id.setter
    def test_suite_id(self, new_value: int) -> None:
        self._test_suite_id = new_value

    @property
    def function_id(self) -> int:
        """Id of the function this case exercises."""
        return self._function_id

    @function_id.setter
    def function_id(self, new_value: int) -> None:
        self._function_id = new_value
|
@@ -0,0 +1,56 @@
|
|
1
|
+
class Function:
    """Model describing a function: its name, parameters, line span, size and owning file id.

    Each field is a read/write property backed by a private ``_<field>`` attribute.
    """

    def __init__(self, name: str, params, start_line: int, end_line: int, num_lines: int, source_file_id: int):
        # Every assignment passes through the property setter below, which keeps
        # the value on the corresponding private attribute.
        self.name = name
        self.params = params
        self.start_line = start_line
        self.end_line = end_line
        self.num_lines = num_lines
        self.source_file_id = source_file_id

    @property
    def name(self) -> str:
        """Name of the function."""
        return self._name

    @name.setter
    def name(self, new_name: str) -> None:
        self._name = new_name

    @property
    def params(self) -> str:
        """Parameters of the function.

        NOTE(review): annotated as ``str`` but never validated; callers may store a
        different representation — confirm against the DAO.
        """
        return self._params

    @params.setter
    def params(self, new_params: str) -> None:
        self._params = new_params

    @property
    def start_line(self) -> int:
        """First line of the function."""
        return self._start_line

    @start_line.setter
    def start_line(self, new_line: int) -> None:
        self._start_line = new_line

    @property
    def end_line(self) -> int:
        """Last line of the function."""
        return self._end_line

    @end_line.setter
    def end_line(self, new_line: int) -> None:
        self._end_line = new_line

    @property
    def num_lines(self) -> int:
        """Line count stored for the function (not derived from start/end here)."""
        return self._num_lines

    @num_lines.setter
    def num_lines(self, new_count: int) -> None:
        self._num_lines = new_count

    @property
    def source_file_id(self) -> int:
        """Id of the source file the function belongs to."""
        return self._source_file_id

    @source_file_id.setter
    def source_file_id(self, new_id: int) -> None:
        self._source_file_id = new_id
|
@@ -1,12 +1,14 @@
|
|
1
1
|
import ast
|
2
2
|
from types import ModuleType
|
3
|
+
from typing import Any
|
3
4
|
|
4
5
|
|
5
6
|
class FunctionMetadata:
|
6
|
-
def __init__(self, filename: str, module: ModuleType, class_name: str, function_name: str, func_def: ast.FunctionDef, params: dict):
|
7
|
+
def __init__(self, filename: str, module: ModuleType, class_name: str, func: Any, function_name: str, func_def: ast.FunctionDef, params: dict):
|
7
8
|
self._filename: str = filename
|
8
9
|
self._module: ModuleType = module
|
9
10
|
self._class_name: str = class_name
|
11
|
+
self._func: Any = func
|
10
12
|
self._function_name: str = function_name
|
11
13
|
self._func_def: ast.FunctionDef = func_def
|
12
14
|
self._params: dict = params
|
@@ -34,6 +36,14 @@ class FunctionMetadata:
|
|
34
36
|
@class_name.setter
|
35
37
|
def class_name(self, class_name: str):
|
36
38
|
self._class_name = class_name
|
39
|
+
|
40
|
+
@property
|
41
|
+
def func(self) -> Any:
|
42
|
+
return self._func
|
43
|
+
|
44
|
+
@func.setter
|
45
|
+
def func(self, func: Any):
|
46
|
+
self._func = func
|
37
47
|
|
38
48
|
@property
|
39
49
|
def function_name(self) -> str:
|
@@ -1,18 +1,21 @@
|
|
1
1
|
from types import ModuleType
|
2
2
|
from typing import List
|
3
|
-
|
4
3
|
from testgen.models.test_case import TestCase
|
5
4
|
|
6
5
|
|
7
6
|
class GeneratorContext:
|
8
|
-
def __init__(self, filepath: str, filename: str, class_name:str | None, module: ModuleType, output_path: str,
|
7
|
+
def __init__(self, filepath: str, filename: str, class_name:str | None, module: ModuleType, output_path: str,
|
8
|
+
test_cases: List[TestCase], is_package: bool, package_name: str, import_path: str):
|
9
9
|
self._filepath: str = filepath
|
10
10
|
self._filename: str = filename
|
11
11
|
self._class_name: str = class_name
|
12
12
|
self._module: ModuleType = module
|
13
13
|
self._output_path: str = output_path
|
14
14
|
self._test_cases: List[TestCase] = test_cases
|
15
|
-
|
15
|
+
self._is_package: bool = is_package
|
16
|
+
self._package_name: str = package_name
|
17
|
+
self._import_path: str = import_path
|
18
|
+
|
16
19
|
@property
|
17
20
|
def filepath(self) -> str:
|
18
21
|
return self._filepath
|
@@ -61,3 +64,27 @@ class GeneratorContext:
|
|
61
64
|
def test_cases(self, value: List[TestCase]) -> None:
|
62
65
|
self._test_cases = value
|
63
66
|
|
67
|
+
@property
|
68
|
+
def is_package(self) -> bool:
|
69
|
+
return self._is_package
|
70
|
+
|
71
|
+
@is_package.setter
|
72
|
+
def is_package(self, value: bool) -> None:
|
73
|
+
self._is_package = value
|
74
|
+
|
75
|
+
@property
|
76
|
+
def package_name(self) -> str:
|
77
|
+
return self._package_name
|
78
|
+
|
79
|
+
@package_name.setter
|
80
|
+
def package_name(self, value: str) -> None:
|
81
|
+
self._package_name = value
|
82
|
+
|
83
|
+
@property
|
84
|
+
def import_path(self) -> str:
|
85
|
+
return self._import_path
|
86
|
+
|
87
|
+
@import_path.setter
|
88
|
+
def import_path(self, value: str) -> None:
|
89
|
+
self._import_path = value
|
90
|
+
|
@@ -0,0 +1,29 @@
|
|
1
|
+
class SourceFile:
    """Model of a source file: its path, line count and last-modified marker.

    Fields are read/write properties over private ``_<field>`` attributes.
    """

    def __init__(self, path: str, lines_of_code: int, last_modified):
        # Delegate storage to the property setters below.
        self.path = path
        self.lines_of_code = lines_of_code
        self.last_modified = last_modified

    @property
    def path(self) -> str:
        """Path of the file, stored exactly as given."""
        return self._path

    @path.setter
    def path(self, new_path: str) -> None:
        self._path = new_path

    @property
    def lines_of_code(self) -> int:
        """Number of lines of code recorded for the file."""
        return self._lines_of_code

    @lines_of_code.setter
    def lines_of_code(self, new_count: int) -> None:
        self._lines_of_code = new_count

    @property
    def last_modified(self):
        """Last-modified value recorded for the file; its type is not constrained here."""
        return self._last_modified

    @last_modified.setter
    def last_modified(self, new_stamp) -> None:
        self._last_modified = new_stamp
|
@@ -0,0 +1,38 @@
|
|
1
|
+
class TestResult:
    """Outcome recorded for one test case run: status flag, error text and timing.

    Fields are read/write properties over private ``_<field>`` attributes.
    """

    def __init__(self, test_case_id: int, status: bool, error: str, execution_time):
        # Hand each value to its property setter, which stores it privately.
        self.test_case_id = test_case_id
        self.status = status
        self.error = error
        self.execution_time = execution_time

    @property
    def test_case_id(self) -> int:
        """Id of the test case this result belongs to."""
        return self._test_case_id

    @test_case_id.setter
    def test_case_id(self, new_id: int) -> None:
        self._test_case_id = new_id

    @property
    def status(self) -> bool:
        """Status flag of the run (annotated bool; not validated here)."""
        return self._status

    @status.setter
    def status(self, new_status: bool) -> None:
        self._status = new_status

    @property
    def error(self) -> str:
        """Error text recorded for the run."""
        return self._error

    @error.setter
    def error(self, new_error: str) -> None:
        self._error = new_error

    @property
    def execution_time(self):
        """Execution time recorded for the run; unit and type are not fixed here."""
        return self._execution_time

    @execution_time.setter
    def execution_time(self, new_time) -> None:
        self._execution_time = new_time
|
@@ -0,0 +1,20 @@
|
|
1
|
+
class TestSuite:
    """Model of a named test suite and its creation date.

    Both fields are read/write properties over private ``_<field>`` attributes.
    """

    def __init__(self, name: str, creation_date):
        # Store via the property setters defined below.
        self.name = name
        self.creation_date = creation_date

    @property
    def name(self) -> str:
        """Name of the suite."""
        return self._name

    @name.setter
    def name(self, new_name: str) -> None:
        self._name = new_name

    @property
    def creation_date(self):
        """Creation date recorded for the suite; its type is not constrained here."""
        return self._creation_date

    @creation_date.setter
    def creation_date(self, new_date) -> None:
        self._creation_date = new_date
|
testgen/reinforcement/agent.py
CHANGED
@@ -128,30 +128,4 @@ class ReinforcementAgent:
|
|
128
128
|
print(f"UPDATING Q TABLE FOR STATE: {state}, ACTION: {action} WITH REWARD: {reward}")
|
129
129
|
new_q = (1 - self.learning_rate) * current_q + self.learning_rate * (reward + max_next_q)
|
130
130
|
|
131
|
-
self.q_table[(state, action)] = new_q
|
132
|
-
|
133
|
-
"""def optimize_test_suit(self, current_state, executable_statements):
|
134
|
-
# Try to optimize test cases by repeatedly performing remove actions if reached full coverage
|
135
|
-
test_case_count = current_state[1]
|
136
|
-
optimization_attempts = min(10, test_case_count - 1)
|
137
|
-
|
138
|
-
for _ in range(optimization_attempts):
|
139
|
-
if test_case_count <= 1:
|
140
|
-
break
|
141
|
-
|
142
|
-
action = "remove"
|
143
|
-
next_state, reward = self.env.step(action)
|
144
|
-
|
145
|
-
new_covered = next_state[0]
|
146
|
-
new_uncovered = [stmt for stmt in executable_statements if stmt not in new_covered]
|
147
|
-
|
148
|
-
if len(new_uncovered) == 0:
|
149
|
-
current_state = next_state
|
150
|
-
test_case_count = current_state[2]
|
151
|
-
print(f"Optimized to {test_case_count} test cases.")
|
152
|
-
else:
|
153
|
-
# Add a test case back if removing broke coverage
|
154
|
-
self.env.step("add")
|
155
|
-
break
|
156
|
-
|
157
|
-
return current_state"""
|
131
|
+
self.q_table[(state, action)] = new_q
|
@@ -4,6 +4,7 @@ from typing import List, Tuple
|
|
4
4
|
|
5
5
|
import coverage
|
6
6
|
|
7
|
+
from testgen.models.function_metadata import FunctionMetadata
|
7
8
|
from testgen.service.logging_service import get_logger
|
8
9
|
import testgen.util.coverage_utils
|
9
10
|
import testgen.util.file_utils
|
@@ -13,11 +14,9 @@ from testgen.models.test_case import TestCase
|
|
13
14
|
|
14
15
|
|
15
16
|
class ReinforcementEnvironment:
|
16
|
-
def __init__(self,
|
17
|
-
self.
|
18
|
-
self.
|
19
|
-
self.module = module
|
20
|
-
self.class_name = class_name
|
17
|
+
def __init__(self, filepath: str, function_data: FunctionMetadata, initial_test_cases: List[TestCase], state: AbstractState):
|
18
|
+
self.filepath = filepath
|
19
|
+
self.function_data = function_data
|
21
20
|
self.initial_test_cases = initial_test_cases
|
22
21
|
self.test_cases = initial_test_cases.copy()
|
23
22
|
self.state = state
|
@@ -35,13 +34,13 @@ class ReinforcementEnvironment:
|
|
35
34
|
|
36
35
|
# Execute action
|
37
36
|
if action == "add":
|
38
|
-
self.test_cases.append(randomizer.new_random_test_case(self.
|
37
|
+
self.test_cases.append(randomizer.new_random_test_case(self.filepath, self.function_data.class_name, self.function_data.func_def))
|
39
38
|
elif action == "merge" and len(self.test_cases) > 1:
|
40
|
-
self.test_cases.append(randomizer.combine_cases(self.module, self.test_cases))
|
39
|
+
self.test_cases.append(randomizer.combine_cases(self.function_data.module, self.test_cases))
|
41
40
|
elif action == "remove" and len(self.test_cases) > 1:
|
42
41
|
self.test_cases = randomizer.remove_case(self.test_cases)
|
43
42
|
elif action == "z3":
|
44
|
-
self.test_cases = randomizer.get_z3_test_cases(self.
|
43
|
+
self.test_cases = randomizer.get_z3_test_cases(self.filepath, self.function_data.class_name, self.function_data.func_def, self.test_cases)
|
45
44
|
else:
|
46
45
|
raise ValueError("Invalid action")
|
47
46
|
|
@@ -91,81 +90,6 @@ class ReinforcementEnvironment:
|
|
91
90
|
|
92
91
|
print(f"Final reward {reward}")
|
93
92
|
return reward
|
94
|
-
|
95
|
-
|
96
|
-
def get_all_executable_statements(self):
|
97
|
-
"""Get all executable statements including else branches"""
|
98
|
-
import ast
|
99
|
-
|
100
|
-
test_cases = [tc for tc in self.test_cases if tc.func_name == self.fut.name]
|
101
|
-
|
102
|
-
executable_lines = set()
|
103
|
-
if not test_cases:
|
104
|
-
self.logger.debug("Warning: No test cases available to determine executable statements")
|
105
|
-
from testgen.util.randomizer import new_random_test_case
|
106
|
-
temp_case = new_random_test_case(self.file_name, self.class_name, self.fut)
|
107
|
-
analysis = testgen.util.coverage_utils.get_coverage_analysis(self.file_name, self.class_name, self.fut.name, temp_case.inputs)
|
108
|
-
executable_lines.update(analysis[1]) # Add executable lines from coverage analysis
|
109
|
-
else:
|
110
|
-
analysis = testgen.util.coverage_utils.get_coverage_analysis(self.file_name, self.class_name, self.fut.name, test_cases[0].inputs)
|
111
|
-
|
112
|
-
executable_lines.update(analysis[1]) # Add executable lines from coverage analysis
|
113
|
-
# Get standard executable lines from coverage.py
|
114
|
-
executable_lines = list(executable_lines)
|
115
|
-
|
116
|
-
# Parse the source file to find else branches
|
117
|
-
with open(self.file_name, 'r') as f:
|
118
|
-
source = f.read()
|
119
|
-
|
120
|
-
# Parse the code
|
121
|
-
tree = ast.parse(source)
|
122
|
-
# Find our specific function
|
123
|
-
for node in ast.walk(tree):
|
124
|
-
if isinstance(node, ast.ClassDef) and node.name == self.class_name:
|
125
|
-
# If we have a class, find the method
|
126
|
-
for method in node.body:
|
127
|
-
if isinstance(method, ast.FunctionDef) and method.name == self.fut.name:
|
128
|
-
# Find all if statements in this method
|
129
|
-
for if_node in ast.walk(method):
|
130
|
-
if isinstance(if_node, ast.If) and if_node.orelse:
|
131
|
-
# There's an else branch
|
132
|
-
if isinstance(if_node.orelse[0], ast.If):
|
133
|
-
# This is an elif - already counted
|
134
|
-
continue
|
135
|
-
|
136
|
-
# Get the line number of the first statement in the else block
|
137
|
-
# and subtract 1 to get the 'else:' line
|
138
|
-
else_line = if_node.orelse[0].lineno - 1
|
139
|
-
|
140
|
-
# Check if this is actually an else line (not a nested if)
|
141
|
-
with open(self.file_name, 'r') as f:
|
142
|
-
lines = f.readlines()
|
143
|
-
if else_line <= len(lines):
|
144
|
-
line_content = lines[else_line - 1].strip()
|
145
|
-
if line_content == "else:":
|
146
|
-
if else_line not in executable_lines:
|
147
|
-
executable_lines.append(else_line)
|
148
|
-
if isinstance(node, ast.FunctionDef) and node.name == self.fut.name:
|
149
|
-
# Find all if statements in this function
|
150
|
-
for if_node in ast.walk(node):
|
151
|
-
if isinstance(if_node, ast.If) and if_node.orelse:
|
152
|
-
# There's an else branch
|
153
|
-
if isinstance(if_node.orelse[0], ast.If):
|
154
|
-
# This is an elif - already counted
|
155
|
-
continue
|
156
|
-
|
157
|
-
# Get the line number of the first statement in the else block
|
158
|
-
# and subtract 1 to get the 'else:' line
|
159
|
-
else_line = if_node.orelse[0].lineno - 1
|
160
|
-
|
161
|
-
# Check if this is actually an else line (not a nested if)
|
162
|
-
with open(self.file_name, 'r') as f:
|
163
|
-
lines = f.readlines()
|
164
|
-
if else_line <= len(lines):
|
165
|
-
line_content = lines[else_line - 1].strip()
|
166
|
-
if line_content == "else:":
|
167
|
-
if else_line not in executable_lines:
|
168
|
-
executable_lines.append(else_line)
|
169
93
|
|
170
94
|
return sorted(executable_lines)
|
171
95
|
|
@@ -180,13 +104,7 @@ class ReinforcementEnvironment:
|
|
180
104
|
# Execute all test cases
|
181
105
|
for test_case in self.test_cases:
|
182
106
|
try:
|
183
|
-
|
184
|
-
if self.class_name:
|
185
|
-
class_obj = getattr(module, self.class_name)
|
186
|
-
instance = class_obj()
|
187
|
-
func = getattr(instance, self.fut.name)
|
188
|
-
else:
|
189
|
-
func = getattr(module, self.fut.name)
|
107
|
+
func = self.function_data.func
|
190
108
|
_ = func(*test_case.inputs)
|
191
109
|
except Exception as e:
|
192
110
|
import traceback
|
@@ -195,7 +113,7 @@ class ReinforcementEnvironment:
|
|
195
113
|
self.cov.stop()
|
196
114
|
|
197
115
|
# Get detailed coverage data including branches
|
198
|
-
file_path = os.path.abspath(self.
|
116
|
+
file_path = os.path.abspath(self.filepath)
|
199
117
|
data = self.cov.get_data()
|
200
118
|
|
201
119
|
# Extract function-specific coverage
|
@@ -225,13 +143,13 @@ class ReinforcementEnvironment:
|
|
225
143
|
import ast
|
226
144
|
|
227
145
|
try:
|
228
|
-
with open(self.
|
146
|
+
with open(self.filepath, 'r') as f:
|
229
147
|
source = f.read()
|
230
148
|
|
231
149
|
tree = ast.parse(source)
|
232
150
|
|
233
151
|
for node in ast.walk(tree):
|
234
|
-
if isinstance(node, ast.FunctionDef) and node.name == self.
|
152
|
+
if isinstance(node, ast.FunctionDef) and node.name == self.function_data.function_name:
|
235
153
|
# Find the first line of the function
|
236
154
|
start_line = node.lineno
|
237
155
|
|
@@ -14,11 +14,11 @@ class StatementCoverageState(AbstractState):
|
|
14
14
|
"""Returns calculated coverage and length of test cases in a tuple"""
|
15
15
|
all_covered_statements = set()
|
16
16
|
for test_case in self.environment.test_cases:
|
17
|
-
analysis = testgen.util.coverage_utils.get_coverage_analysis(self.environment.
|
17
|
+
analysis = testgen.util.coverage_utils.get_coverage_analysis(self.environment.filepath, self.environment.function_data, test_case.inputs)
|
18
18
|
covered = testgen.util.coverage_utils.get_list_of_covered_statements(analysis)
|
19
19
|
all_covered_statements.update(covered)
|
20
20
|
|
21
|
-
executable_statements = self.environment.
|
21
|
+
executable_statements = testgen.util.coverage_utils.get_all_executable_statements(self.environment.filepath, self.environment.function_data, self.environment.test_cases)
|
22
22
|
|
23
23
|
if not executable_statements or executable_statements == 0:
|
24
24
|
calc_coverage = 0.0
|
@@ -26,10 +26,11 @@ class StatementCoverageState(AbstractState):
|
|
26
26
|
calc_coverage: float = (len(all_covered_statements) / len(executable_statements)) * 100
|
27
27
|
|
28
28
|
self.logger.debug(f"GET STATE ALL COVERED STATEMENTS: {all_covered_statements}")
|
29
|
-
self.logger.debug(f"GET STATE ALL EXECUTABLE STATEMENTS: {
|
29
|
+
self.logger.debug(f"GET STATE ALL EXECUTABLE STATEMENTS: {executable_statements}")
|
30
30
|
self.logger.debug(f"GET STATE FLOAT COVERAGE: {calc_coverage}")
|
31
31
|
|
32
32
|
if calc_coverage >= 100:
|
33
|
-
print(f"!!!!!!!!FULLY COVERED FUNCTION: {self.environment.
|
33
|
+
print(f"!!!!!!!!FULLY COVERED FUNCTION: {self.environment.function_data.function_name}!!!!!!!!")
|
34
|
+
|
34
35
|
return calc_coverage, len(self.environment.test_cases)
|
35
36
|
|
@@ -2,9 +2,10 @@ import inspect
|
|
2
2
|
import ast
|
3
3
|
import time
|
4
4
|
from types import ModuleType
|
5
|
-
from typing import Dict, List
|
5
|
+
from typing import Dict, List, Any
|
6
6
|
|
7
7
|
import testgen
|
8
|
+
from testgen.analyzer.reinforcement_analyzer import ReinforcementAnalyzer
|
8
9
|
from testgen.service.logging_service import get_logger
|
9
10
|
import testgen.util.file_utils
|
10
11
|
import testgen.util.file_utils as file_utils
|
@@ -37,11 +38,8 @@ class AnalysisService:
|
|
37
38
|
|
38
39
|
def generate_test_cases(self) -> List[TestCase]:
|
39
40
|
"""Generate test cases using the current strategy."""
|
40
|
-
|
41
|
-
|
42
|
-
else:
|
43
|
-
self.test_case_analyzer_context.do_logic()
|
44
|
-
return self.test_case_analyzer_context.test_cases
|
41
|
+
self.test_case_analyzer_context.do_strategy(30)
|
42
|
+
return self.test_case_analyzer_context.test_cases
|
45
43
|
|
46
44
|
def create_analysis_context(self, filepath: str) -> AnalysisContext:
|
47
45
|
"""Create an analysis context for the given file."""
|
@@ -88,6 +86,7 @@ class AnalysisService:
|
|
88
86
|
mode = mode or self.reinforcement_mode
|
89
87
|
module: ModuleType = testgen.util.file_utils.load_module(filepath)
|
90
88
|
tree: ast.Module = testgen.util.file_utils.load_and_parse_file_for_tree(filepath)
|
89
|
+
list_of_function_data: List[FunctionMetadata] = self.get_function_data(filepath, module, class_name)
|
91
90
|
functions: List[ast.FunctionDef] = testgen.util.utils.get_functions(tree)
|
92
91
|
self.class_name = class_name
|
93
92
|
time_limit: int = 30
|
@@ -95,14 +94,14 @@ class AnalysisService:
|
|
95
94
|
|
96
95
|
q_table = self._load_q_table()
|
97
96
|
|
98
|
-
for function in
|
99
|
-
print(f"\nStarting reinforcement learning for function {function.
|
97
|
+
for function in list_of_function_data:
|
98
|
+
print(f"\nStarting reinforcement learning for function {function.function_name}")
|
100
99
|
start_time = time.time()
|
101
100
|
function_test_cases: List[TestCase] = []
|
102
101
|
best_coverage: float = 0.0
|
103
102
|
|
104
103
|
# Create environment and agent once per function
|
105
|
-
environment = ReinforcementEnvironment(filepath, function,
|
104
|
+
environment = ReinforcementEnvironment(filepath, function, function_test_cases, state=StatementCoverageState(None))
|
106
105
|
environment.state = StatementCoverageState(environment)
|
107
106
|
|
108
107
|
# Create agent with existing Q-table
|
@@ -115,10 +114,10 @@ class AnalysisService:
|
|
115
114
|
new_test_cases = agent.collect_test_cases()
|
116
115
|
function_test_cases.extend(new_test_cases)
|
117
116
|
|
118
|
-
print(f"\nNumber of test cases for {function.
|
117
|
+
print(f"\nNumber of test cases for {function.function_name}: {len(function_test_cases)}")
|
119
118
|
|
120
119
|
current_coverage: float = environment.run_tests()
|
121
|
-
print(f"Current coverage: {function.
|
120
|
+
print(f"Current coverage: {function.function_name}: {current_coverage}")
|
122
121
|
|
123
122
|
q_table.update(agent.q_table)
|
124
123
|
|
@@ -134,8 +133,8 @@ class AnalysisService:
|
|
134
133
|
unique_test_cases.append(case)
|
135
134
|
|
136
135
|
all_test_cases.extend(unique_test_cases)
|
137
|
-
print(f"Final coverage for {function.
|
138
|
-
print(f"Final test cases for {function.
|
136
|
+
print(f"Final coverage for {function.function_name}: {best_coverage}%")
|
137
|
+
print(f"Final test cases for {function.function_name}: {len(unique_test_cases)}")
|
139
138
|
|
140
139
|
self._save_q_table(q_table)
|
141
140
|
|
@@ -146,12 +145,23 @@ class AnalysisService:
|
|
146
145
|
def _create_function_metadata(self, filename: str, module: ModuleType, class_name: str | None,
|
147
146
|
func_node: ast.FunctionDef) -> FunctionMetadata:
|
148
147
|
function_name = func_node.name
|
149
|
-
|
148
|
+
func = self._get_func_attr(function_name, module, class_name)
|
150
149
|
param_types = self._get_params(func_node)
|
151
150
|
|
152
|
-
return FunctionMetadata(filename, module, class_name, function_name, func_node, param_types)
|
153
|
-
|
154
|
-
|
151
|
+
return FunctionMetadata(filename, module, class_name, func, function_name, func_node, param_types)
|
152
|
+
|
153
|
+
@staticmethod
|
154
|
+
def _get_func_attr(function_name: str, module: ModuleType, class_name: str | None) -> Any:
|
155
|
+
if class_name:
|
156
|
+
cls = getattr(module, class_name)
|
157
|
+
instance = cls()
|
158
|
+
func = getattr(instance, function_name)
|
159
|
+
else:
|
160
|
+
func = getattr(module, function_name)
|
161
|
+
return func
|
162
|
+
|
163
|
+
@staticmethod
|
164
|
+
def _get_params(func_node: ast.FunctionDef) -> Dict[str, str]:
|
155
165
|
# Extract parameter types
|
156
166
|
param_types = {}
|
157
167
|
for arg in func_node.args.args:
|
@@ -173,18 +183,17 @@ class AnalysisService:
|
|
173
183
|
|
174
184
|
if strategy == AST_STRAT:
|
175
185
|
analyzer = ASTAnalyzer(analysis_context)
|
176
|
-
self.test_case_analyzer_context = TestCaseAnalyzerContext(analysis_context, analyzer)
|
177
186
|
elif strategy == FUZZ_STRAT:
|
178
187
|
analyzer = FuzzAnalyzer(analysis_context)
|
179
|
-
self.test_case_analyzer_context = TestCaseAnalyzerContext(analysis_context, analyzer)
|
180
188
|
elif strategy == RANDOM_STRAT:
|
181
189
|
analyzer = RandomFeedbackAnalyzer(analysis_context)
|
182
|
-
self.test_case_analyzer_context = TestCaseAnalyzerContext(analysis_context, analyzer)
|
183
190
|
elif strategy == REINFORCE_STRAT:
|
184
|
-
|
191
|
+
analyzer = ReinforcementAnalyzer(analysis_context, mode=self.reinforcement_mode)
|
185
192
|
else:
|
186
193
|
raise NotImplementedError(f"Test strategy {strategy} not implemented")
|
187
|
-
|
194
|
+
|
195
|
+
self.test_case_analyzer_context = TestCaseAnalyzerContext(analysis_context, analyzer)
|
196
|
+
|
188
197
|
def set_file_path(self, path: str):
|
189
198
|
"""Set the file path for analysis."""
|
190
199
|
self.file_path = path
|
testgen/service/cfg_service.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1
1
|
import os
|
2
2
|
from typing import List
|
3
|
+
|
4
|
+
from testgen.models.function_metadata import FunctionMetadata
|
3
5
|
from testgen.models.test_case import TestCase
|
4
6
|
from testgen.service.logging_service import get_logger
|
5
7
|
from testgen.util.coverage_visualizer import CoverageVisualizer
|
@@ -46,7 +48,7 @@ class CFGService:
|
|
46
48
|
filename = os.path.basename(file_path).replace('.py', '')
|
47
49
|
|
48
50
|
for func in analysis_context.function_data:
|
49
|
-
self.visualizer.get_covered_lines(file_path,
|
51
|
+
self.visualizer.get_covered_lines(file_path, func, test_cases)
|
50
52
|
|
51
53
|
base_filename = f"{filename}_{func.function_name}_coverage"
|
52
54
|
output_filepath = self.get_versioned_filename(visualization_dir, base_filename)
|