testgenie-py 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- testgen/__init__.py +0 -0
- testgen/analyzer/__init__.py +0 -0
- testgen/analyzer/ast_analyzer.py +149 -0
- testgen/analyzer/contracts/__init__.py +0 -0
- testgen/analyzer/contracts/contract.py +13 -0
- testgen/analyzer/contracts/no_exception_contract.py +16 -0
- testgen/analyzer/contracts/nonnull_contract.py +15 -0
- testgen/analyzer/fuzz_analyzer.py +106 -0
- testgen/analyzer/random_feedback_analyzer.py +291 -0
- testgen/analyzer/reinforcement_analyzer.py +75 -0
- testgen/analyzer/test_case_analyzer.py +46 -0
- testgen/analyzer/test_case_analyzer_context.py +58 -0
- testgen/controller/__init__.py +0 -0
- testgen/controller/cli_controller.py +194 -0
- testgen/controller/docker_controller.py +169 -0
- testgen/docker/Dockerfile +22 -0
- testgen/docker/poetry.lock +361 -0
- testgen/docker/pyproject.toml +22 -0
- testgen/generator/__init__.py +0 -0
- testgen/generator/code_generator.py +66 -0
- testgen/generator/doctest_generator.py +208 -0
- testgen/generator/generator.py +55 -0
- testgen/generator/pytest_generator.py +77 -0
- testgen/generator/test_generator.py +26 -0
- testgen/generator/unit_test_generator.py +84 -0
- testgen/inspector/__init__.py +0 -0
- testgen/inspector/inspector.py +61 -0
- testgen/main.py +13 -0
- testgen/models/__init__.py +0 -0
- testgen/models/analysis_context.py +56 -0
- testgen/models/function_metadata.py +61 -0
- testgen/models/generator_context.py +63 -0
- testgen/models/test_case.py +8 -0
- testgen/presentation/__init__.py +0 -0
- testgen/presentation/cli_view.py +12 -0
- testgen/q_table/global_q_table.json +1 -0
- testgen/reinforcement/__init__.py +0 -0
- testgen/reinforcement/abstract_state.py +7 -0
- testgen/reinforcement/agent.py +153 -0
- testgen/reinforcement/environment.py +215 -0
- testgen/reinforcement/statement_coverage_state.py +33 -0
- testgen/service/__init__.py +0 -0
- testgen/service/analysis_service.py +260 -0
- testgen/service/cfg_service.py +55 -0
- testgen/service/generator_service.py +169 -0
- testgen/service/service.py +389 -0
- testgen/sqlite/__init__.py +0 -0
- testgen/sqlite/db.py +84 -0
- testgen/sqlite/db_service.py +219 -0
- testgen/tree/__init__.py +0 -0
- testgen/tree/node.py +7 -0
- testgen/tree/tree_utils.py +79 -0
- testgen/util/__init__.py +0 -0
- testgen/util/coverage_utils.py +168 -0
- testgen/util/coverage_visualizer.py +154 -0
- testgen/util/file_utils.py +110 -0
- testgen/util/randomizer.py +122 -0
- testgen/util/utils.py +143 -0
- testgen/util/z3_utils/__init__.py +0 -0
- testgen/util/z3_utils/ast_to_z3.py +99 -0
- testgen/util/z3_utils/branch_condition.py +72 -0
- testgen/util/z3_utils/constraint_extractor.py +36 -0
- testgen/util/z3_utils/variable_finder.py +10 -0
- testgen/util/z3_utils/z3_test_case.py +94 -0
- testgenie_py-0.1.0.dist-info/METADATA +24 -0
- testgenie_py-0.1.0.dist-info/RECORD +68 -0
- testgenie_py-0.1.0.dist-info/WHEEL +4 -0
- testgenie_py-0.1.0.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,260 @@
|
|
1
|
+
import inspect
|
2
|
+
import ast
|
3
|
+
import time
|
4
|
+
from types import ModuleType
|
5
|
+
from typing import Dict, List
|
6
|
+
|
7
|
+
import testgen
|
8
|
+
import testgen.util.file_utils
|
9
|
+
import testgen.util.file_utils as file_utils
|
10
|
+
import testgen.util.utils
|
11
|
+
from testgen.analyzer.ast_analyzer import ASTAnalyzer
|
12
|
+
from testgen.analyzer.fuzz_analyzer import FuzzAnalyzer
|
13
|
+
from testgen.analyzer.random_feedback_analyzer import RandomFeedbackAnalyzer
|
14
|
+
from testgen.models.test_case import TestCase
|
15
|
+
from testgen.analyzer.test_case_analyzer_context import TestCaseAnalyzerContext
|
16
|
+
from testgen.reinforcement.agent import ReinforcementAgent
|
17
|
+
from testgen.reinforcement.environment import ReinforcementEnvironment
|
18
|
+
from testgen.reinforcement.statement_coverage_state import StatementCoverageState
|
19
|
+
from testgen.models.analysis_context import AnalysisContext
|
20
|
+
from testgen.models.function_metadata import FunctionMetadata
|
21
|
+
|
22
|
+
# Test-generation strategy identifiers, compared against in AnalysisService:
#   AST_STRAT       -> ASTAnalyzer
#   FUZZ_STRAT      -> FuzzAnalyzer
#   RANDOM_STRAT    -> RandomFeedbackAnalyzer
#   REINFORCE_STRAT -> AnalysisService.do_reinforcement_learning (no analyzer object)
AST_STRAT, FUZZ_STRAT, RANDOM_STRAT, REINFORCE_STRAT = 1, 2, 3, 4
|
27
|
+
|
28
|
+
class AnalysisService:
    """Coordinate test-case generation for a single source file.

    Holds the target file path, the selected strategy (one of the ``*_STRAT``
    module constants) and the analyzer context built for that strategy.
    Reinforcement learning bypasses the analyzer context and runs through
    :meth:`do_reinforcement_learning`, which persists a Q-table as JSON.
    """

    # Q-table location. NOTE(review): this is relative to the current working
    # directory, not to the installed package; confirm that is intended.
    _Q_TABLE_DIR = "q_table"
    _Q_TABLE_FILE = "global_q_table.json"

    def __init__(self):
        self.file_path = None
        # Optional class filter read by get_function_data; None means "all classes".
        self.class_name = None
        self.test_case_analyzer_context = TestCaseAnalyzerContext(None, None)
        self.test_strategy = 0
        self.reinforcement_mode = "train"

    def generate_test_cases(self) -> List[TestCase]:
        """Generate test cases using the current strategy."""
        if self.test_strategy == REINFORCE_STRAT:
            return self.do_reinforcement_learning(self.file_path)
        self.test_case_analyzer_context.do_logic()
        return self.test_case_analyzer_context.test_cases

    def create_analysis_context(self, filepath: str) -> AnalysisContext:
        """Load the module at *filepath* and bundle it with its function metadata."""
        filename = file_utils.get_filename(filepath)
        module = file_utils.load_module(filepath)
        class_name = self.get_class_name(module)
        function_data = self.get_function_data(filename, module, class_name)
        return AnalysisContext(filepath, filename, class_name, module, function_data)

    def get_function_data(self, filename: str, module: ModuleType, class_name: str | None) -> List[FunctionMetadata]:
        """Collect metadata for the top-level functions and public methods of *module*.

        Args:
            filename: file name recorded in each FunctionMetadata.
            module: loaded module whose source is parsed with ``inspect.getsource``.
            class_name: class name recorded for module-level functions.

        Returns:
            One FunctionMetadata per top-level ``def`` and per method not starting
            with ``_``. Each method is tagged with the class that defines it.
            Classes other than ``self.class_name`` are skipped when that filter is set.
        """
        tree = ast.parse(inspect.getsource(module))
        function_metadata_list: List[FunctionMetadata] = []

        for node in tree.body:
            if isinstance(node, ast.FunctionDef):
                # NOTE(review): module-level functions are tagged with the module's
                # detected class name as before; confirm downstream expects that.
                function_metadata_list.append(self._create_function_metadata(filename, module, class_name, node))
            elif isinstance(node, ast.ClassDef):
                if self.class_name and node.name != self.class_name:
                    continue
                for class_node in node.body:
                    if isinstance(class_node, ast.FunctionDef) and not class_node.name.startswith('_'):
                        # Use the defining class, not the first class found in the module,
                        # so modules with several classes attribute methods correctly.
                        function_metadata_list.append(
                            self._create_function_metadata(filename, module, node.name, class_node)
                        )

        return function_metadata_list

    def do_reinforcement_learning(self, filepath: str, mode: str | None = None) -> List[TestCase]:
        """Generate test cases for every function in *filepath* with the Q-learning agent.

        Args:
            filepath: path of the module under test.
            mode: ``"train"`` runs ``agent.do_q_learning()``; any other value runs
                ``agent.collect_test_cases()``. Defaults to ``self.reinforcement_mode``.

        Returns:
            Test cases for all functions, de-duplicated per function by
            ``(func_name, inputs)``.

        Side effects:
            Loads and saves the Q-table and prints progress to stdout.
        """
        mode = mode or self.reinforcement_mode
        module: ModuleType = testgen.util.file_utils.load_module(filepath)
        tree: ast.Module = testgen.util.file_utils.load_and_parse_file_for_tree(filepath)
        functions: List[ast.FunctionDef] = testgen.util.utils.get_functions(tree)
        all_test_cases: List[TestCase] = []

        q_table = self._load_q_table()

        for function in functions:
            print(f"\nStarting reinforcement learning for function {function.name}")
            # This same list object is handed to the environment and the agent.
            # They may append to it too, hence the de-duplication below.
            function_test_cases: List[TestCase] = []

            # The state needs the environment and the environment needs a state:
            # build with a placeholder state, then attach the real one.
            environment = ReinforcementEnvironment(filepath, function, module, function_test_cases, state=StatementCoverageState(None))
            environment.state = StatementCoverageState(environment)

            agent = ReinforcementAgent(filepath, environment, function_test_cases, q_table)

            if mode == "train":
                new_test_cases = agent.do_q_learning()
            else:
                new_test_cases = agent.collect_test_cases()
            function_test_cases.extend(new_test_cases)

            print(f"\nTest cases for {function.name}: {len(function_test_cases)}")

            current_coverage: float = environment.run_tests()
            print(f"Current coverage: {function.name}: {current_coverage}")

            # Merge the agent's values back into the table that is saved at the end.
            q_table.update(agent.q_table)

            unique_test_cases = self._deduplicate_test_cases(function_test_cases)
            all_test_cases.extend(unique_test_cases)
            # Report the measured coverage. A never-updated placeholder always printed 0.0.
            print(f"Final coverage for {function.name}: {current_coverage}%")
            print(f"Final test cases for {function.name}: {len(unique_test_cases)}")

        self._save_q_table(q_table)

        print("\nReinforcement Learning Complete")
        print(f"Total test cases found: {len(all_test_cases)}")
        return all_test_cases

    @staticmethod
    def _deduplicate_test_cases(test_cases: List[TestCase]) -> List[TestCase]:
        """Return *test_cases* without repeats, keeping each first occurrence.

        Two cases are duplicates when they share ``func_name`` and inputs (a list
        of inputs is compared as a tuple). Inputs that are still unhashable, e.g.
        ones containing lists or dicts, are compared by ``repr`` instead of raising.
        """
        seen = set()
        unique_test_cases: List[TestCase] = []
        for case in test_cases:
            case_inputs = tuple(case.inputs) if isinstance(case.inputs, list) else case.inputs
            case_key = (case.func_name, case_inputs)
            try:
                hash(case_key)
            except TypeError:
                # The 3-tuple shape keeps repr-based keys distinct from hashable ones.
                case_key = (case.func_name, "repr", repr(case_inputs))
            if case_key not in seen:
                seen.add(case_key)
                unique_test_cases.append(case)
        return unique_test_cases

    def _create_function_metadata(self, filename: str, module: ModuleType, class_name: str | None,
                                  func_node: ast.FunctionDef) -> FunctionMetadata:
        """Build FunctionMetadata for *func_node*, including its parameter annotations."""
        param_types = self._get_params(func_node)
        return FunctionMetadata(filename, module, class_name, func_node.name, func_node, param_types)

    def _get_params(self, func_node: ast.FunctionDef) -> Dict[str, str | None]:
        """Map each parameter in ``func_node.args.args`` (except ``self``) to its annotation.

        The annotation is returned as source text via ``ast.unparse``. Unannotated
        parameters map to None. Positional-only, ``*args``, keyword-only and
        ``**kwargs`` parameters are not read.
        """
        return {
            arg.arg: ast.unparse(arg.annotation) if arg.annotation else None
            for arg in func_node.args.args
            if arg.arg != 'self'
        }

    def set_test_strategy(self, strategy: int, module_name: str, class_name: str):
        """Set the test analysis strategy and build the matching analyzer context.

        Args:
            strategy: one of AST_STRAT, FUZZ_STRAT, RANDOM_STRAT, REINFORCE_STRAT.
            module_name: currently unused.
            class_name: currently unused; the class is detected from the module instead.

        Raises:
            NotImplementedError: for an unknown strategy. The current strategy
                is left unchanged in that case.
        """
        analyzer_types = {
            AST_STRAT: ASTAnalyzer,
            FUZZ_STRAT: FuzzAnalyzer,
            RANDOM_STRAT: RandomFeedbackAnalyzer,
        }
        # Validate before changing state or importing the target module.
        if strategy != REINFORCE_STRAT and strategy not in analyzer_types:
            raise NotImplementedError(f"Test strategy {strategy} not implemented")

        self.test_strategy = strategy
        if strategy == REINFORCE_STRAT:
            # do_reinforcement_learning loads the module itself and needs no analyzer.
            return

        analysis_context = self.create_analysis_context(self.file_path)
        analyzer = analyzer_types[strategy](analysis_context)
        self.test_case_analyzer_context = TestCaseAnalyzerContext(analysis_context, analyzer)

    def set_file_path(self, path: str):
        """Set the file path for analysis."""
        self.file_path = path

    def set_reinforcement_mode(self, mode: str):
        """Set the default mode used by do_reinforcement_learning ("train" or other)."""
        self.reinforcement_mode = mode

    @staticmethod
    def get_class_name(module: ModuleType) -> str | None:
        """Get class from module otherwise return None.

        Only classes whose ``__module__`` is this module count, which excludes
        imported classes. ``inspect.getmembers`` returns members sorted by name.
        """
        for name, cls in inspect.getmembers(module, inspect.isclass):
            if cls.__module__ == module.__name__:
                return name
        return None

    @classmethod
    def _save_q_table(cls, q_table):
        """Write *q_table* to ``q_table/global_q_table.json``, replacing the file.

        The whole table passed in is written. do_reinforcement_learning loads the
        previous table first and merges into it, so nothing is lost. Keys are
        ``(state, action)`` pairs serialized as ``"<str(state)>|<action>"``
        because JSON object keys must be strings. Write errors are printed, not raised.
        """
        import json
        import os

        os.makedirs(cls._Q_TABLE_DIR, exist_ok=True)
        q_table_path = os.path.join(cls._Q_TABLE_DIR, cls._Q_TABLE_FILE)

        serializable_q_table = {
            f"{state!s}|{action}": value for (state, action), value in q_table.items()
        }

        try:
            with open(q_table_path, 'w') as f:
                json.dump(serializable_q_table, f)
            print(f"Q-table saved to {q_table_path}")
        except Exception as e:
            print(f"Error saving Q-table: {e}")

    @classmethod
    def _load_q_table(cls):
        """Load the Q-table written by _save_q_table.

        Returns:
            A dict keyed by ``(state, action)``. ``state`` is rebuilt with
            ``ast.literal_eval`` and ``action`` stays a string. Returns ``{}``
            when the file is missing or unreadable. Entries whose key cannot be
            parsed are skipped individually.
        """
        import json
        import os

        q_table_path = os.path.join(cls._Q_TABLE_DIR, cls._Q_TABLE_FILE)

        if not os.path.exists(q_table_path):
            print(f"No existing Q-table found at {q_table_path}")
            return {}

        try:
            with open(q_table_path, 'r') as f:
                serialized_q_table = json.load(f)

            q_table = {}
            for key, value in serialized_q_table.items():
                try:
                    # Split on the LAST separator so a '|' inside the state repr
                    # does not abort loading of the whole table.
                    state_str, action = key.rsplit('|', 1)
                    state = ast.literal_eval(state_str)
                    q_table[(state, action)] = value
                except (ValueError, SyntaxError, TypeError):
                    print(f"Skipping invalid state: {key}")

            print(f"Loaded global Q-table with {len(q_table)} entries")
            return q_table
        except Exception as e:
            print(f"Error loading Q-table: {e}")
            return {}
|
@@ -0,0 +1,55 @@
|
|
1
|
+
import os
|
2
|
+
from typing import List
|
3
|
+
from testgen.models.test_case import TestCase
|
4
|
+
from testgen.util.coverage_visualizer import CoverageVisualizer
|
5
|
+
from testgen.service.analysis_service import AnalysisService
|
6
|
+
|
7
|
+
|
8
|
+
class CFGService:
    """Service for generating and managing Control Flow Graph visualizations."""

    def __init__(self):
        self.analysis_service = AnalysisService()
        # Set by initialize_visualizer(); visualize_test_coverage requires it.
        self.visualizer = None

    def initialize_visualizer(self, service):
        """Create the CoverageVisualizer and hand it *service* via ``set_service``."""
        self.visualizer = CoverageVisualizer()
        self.visualizer.set_service(service)

    @staticmethod
    def create_visualization_directory() -> str:
        """Return ``<cwd>/visualize``, creating the directory if it does not exist."""
        visualization_dir = os.path.join(os.getcwd(), "visualize")
        try:
            # Creating directly (EAFP) avoids a race between an exists-check and makedirs.
            os.makedirs(visualization_dir)
        except FileExistsError:
            pass
        else:
            print(f"Created visualization directory: {visualization_dir}")
        return visualization_dir

    @staticmethod
    def get_versioned_filename(directory: str, base_filename: str) -> str:
        """Generate a versioned filename to avoid overwriting existing files.

        Tries ``<base>.png`` first, then ``<base>_v1.png``, ``<base>_v2.png``, ...
        and returns the first path that does not exist.
        """
        version = 1
        output_path = os.path.join(directory, f"{base_filename}.png")

        while os.path.exists(output_path):
            output_path = os.path.join(directory, f"{base_filename}_v{version}.png")
            version += 1

        return output_path

    def visualize_test_coverage(self, file_path: str, test_cases: List[TestCase]) -> str | None:
        """Render one coverage-coloured CFG image per function in *file_path*.

        Args:
            file_path: Python source file to analyse.
            test_cases: test cases whose coverage is drawn.

        Returns:
            The directory the images were written to (``<cwd>/visualize``).

        Raises:
            RuntimeError: if initialize_visualizer() has not been called.
        """
        if self.visualizer is None:
            raise RuntimeError("Visualizer not initialized; call initialize_visualizer() first")

        visualization_dir = self.create_visualization_directory()

        analysis_context = self.analysis_service.create_analysis_context(file_path)

        # Strip only a trailing ".py"; str.replace would also remove ".py" from
        # the middle of a name (e.g. "a.pyx.py").
        filename = os.path.basename(file_path).removesuffix('.py')

        for func in analysis_context.function_data:
            self.visualizer.get_covered_lines(file_path, func.func_def, test_cases)

            base_filename = f"{filename}_{func.function_name}_coverage"
            output_filepath = self.get_versioned_filename(visualization_dir, base_filename)

            self.visualizer.generate_colored_cfg(func.function_name, output_filepath)

        print(f"Generated CFG visualizations in {visualization_dir}")
        return visualization_dir
|
@@ -0,0 +1,169 @@
|
|
1
|
+
import os
|
2
|
+
from types import ModuleType
|
3
|
+
from typing import List
|
4
|
+
|
5
|
+
from testgen.models.test_case import TestCase
|
6
|
+
from testgen.generator.code_generator import CodeGenerator
|
7
|
+
from testgen.generator.doctest_generator import DocTestGenerator
|
8
|
+
from testgen.generator.pytest_generator import PyTestGenerator
|
9
|
+
from testgen.generator.unit_test_generator import UnitTestGenerator
|
10
|
+
from testgen.inspector.inspector import Inspector
|
11
|
+
from testgen.tree.node import Node
|
12
|
+
from testgen.tree.tree_utils import build_binary_tree
|
13
|
+
from testgen.models.generator_context import GeneratorContext
|
14
|
+
|
15
|
+
# Output test-format identifiers, compared against in GeneratorService:
#   UNITTEST_FORMAT -> UnitTestGenerator
#   PYTEST_FORMAT   -> PyTestGenerator
#   DOCTEST_FORMAT  -> DocTestGenerator
UNITTEST_FORMAT, PYTEST_FORMAT, DOCTEST_FORMAT = 1, 2, 3
|
19
|
+
|
20
|
+
class GeneratorService:
    """Write generated tests (unittest / pytest / doctest) and generated source code to disk."""

    # Format id -> (label printed by set_test_format, generator class).
    _TEST_GENERATORS = {
        UNITTEST_FORMAT: ("UNITTEST", UnitTestGenerator),
        PYTEST_FORMAT: ("PYTEST", PyTestGenerator),
        DOCTEST_FORMAT: ("DOCTEST", DocTestGenerator),
    }

    def __init__(self, filepath: str, output_path: str, test_format: int = UNITTEST_FORMAT):
        self.filepath = filepath
        self.output_path = output_path
        self.test_format = test_format
        self.code_generator = CodeGenerator()
        # Build the generator that matches test_format so test_format and
        # test_generator agree. Unknown formats fall back to unittest rather than
        # raising here; set_test_format() is the validating entry point.
        _label, generator_type = self._TEST_GENERATORS.get(
            test_format, self._TEST_GENERATORS[UNITTEST_FORMAT]
        )
        self.test_generator = generator_type(generator_context=None)
        self.generated_file_path = None

    def set_test_format(self, test_format: int):
        """Set the test generator format.

        Raises:
            NotImplementedError: for an unknown format. The current format and
                generator are left unchanged in that case.
        """
        if test_format not in self._TEST_GENERATORS:
            raise NotImplementedError(f"Test format {test_format} not implemented")
        label, generator_type = self._TEST_GENERATORS[test_format]
        self.test_format = test_format
        print(f"SETTING TEST FORMAT TO {label}")
        self.test_generator = generator_type(generator_context=None)

    def generate_test_file(self, module: ModuleType, class_name: str | None, test_cases: List[TestCase], output_path: str | None = None) -> str:
        """Generate a test file for the given test cases.

        Args:
            module: module under test; its ``__name__`` names the test file.
            class_name: class under test, or None for module-level functions.
            test_cases: cases to emit, one test function each.
            output_path: optional directory for the test file (see get_test_file_path).

        Returns:
            The computed test-file path. For DOCTEST_FORMAT it returns ``self.filepath``
            instead.
        """
        filename = self.get_filename(self.filepath)

        output_path = self.get_test_file_path(module.__name__, output_path)

        context = GeneratorContext(
            filepath=self.filepath,
            filename=filename,
            class_name=class_name,
            module=module,
            output_path=output_path,
            test_cases=test_cases
        )

        self.test_generator.generator_context = context
        self.test_generator.generate_test_header()
        print("GENERATE TEST FILE: Generated test header")

        self.generate_function_tests(test_cases)
        print("GENERATE TEST FILE: Generate function tests")
        print()

        if self.test_format == DOCTEST_FORMAT:
            print("SAVING DOCT TEST FILE")
            self.test_generator.save_file()
            # Presumably DocTestGenerator writes into the source module itself,
            # hence the source path is reported. TODO confirm.
            return self.filepath

        self.test_generator.save_file()
        return output_path

    def generate_function_tests(self, test_cases: List[TestCase]) -> None:
        """Emit one test function per case, suffixing the index to keep names unique."""
        for i, test_case in enumerate(test_cases):
            unique_func_name = f"{test_case.func_name}_{i}"
            cases = [(test_case.inputs, test_case.expected)]
            self.test_generator.generate_test_function(unique_func_name, test_case.func_name, cases)

    def get_test_file_path(self, module_name: str, specified_path=None) -> str:
        """Determine the path for the generated test file.

        Uses *specified_path* when it is an existing directory. Otherwise falls back
        to ``<cwd>/tests``, creating it when missing. Every directory used gets an
        ``__init__.py``. If ``<cwd>/tests`` exists but is not a directory, the file is
        placed directly in the cwd, because a path inside a regular file can never be
        written.
        """
        test_filename = f"test_{module_name.lower()}.py"

        if specified_path is not None:
            if not os.path.exists(specified_path):
                print(f"Specified directory path: {specified_path} does not exist.")
            elif not os.path.isdir(specified_path):
                print(f"Specified directory path: {specified_path} is not a directory.")
            else:
                self.ensure_init_py(specified_path)
                return os.path.join(specified_path, test_filename)

        current_dir = os.getcwd()
        test_dir = os.path.join(current_dir, "tests")

        if not os.path.exists(test_dir):
            print(f"Test directory path: {test_dir} does not exist, creating it.")
            # exist_ok guards against the directory appearing after the check above.
            os.makedirs(test_dir, exist_ok=True)
        elif not os.path.isdir(test_dir):
            print(f"Test directory path: {test_dir} is not a directory.")
            return os.path.join(current_dir, test_filename)

        self.ensure_init_py(test_dir)
        return os.path.join(test_dir, test_filename)

    def ensure_init_py(self, directory: str):
        """Ensures that an __init__.py file exists in the given directory."""
        init_file = os.path.join(directory, "__init__.py")
        if not os.path.exists(init_file):
            with open(init_file, "w"):
                pass  # Create an empty __init__.py file
            print(f"Created __init__.py in {directory}")

    def generate_function_code(self, file_path: str, class_name: str | None, functions: list) -> str:
        """Generate function code for a given class and its functions.

        Args:
            file_path: source file. Its base name names the output file when
                *class_name* is falsy.
            class_name: when set, the output is ``generated_<class_name.lower()>.py``
                and each function is written indented, as a class member.
            functions: ``(name, function)`` pairs.

        Returns:
            The path of the generated file, relative to the cwd.
        """
        trees = self.build_func_trees(functions)
        is_class_method = bool(class_name)

        if is_class_method:
            print(class_name)
            self.generated_file_path = f"generated_{class_name.lower()}.py"
            # generate_class() returns an object that must be closed; only its
            # side effect of starting the class file is needed here.
            self.code_generator.generate_class(class_name).close()
        else:
            print(self.get_filename(file_path))
            self.generated_file_path = f"generated_{self.get_filename(file_path)}"
            # Create or truncate the output file before appending functions below.
            with open(self.generated_file_path, "w"):
                pass

        # Append function implementations
        with open(self.generated_file_path, "a") as file:
            for func, root, params in trees:
                code = self.code_generator.generate_code_from_tree(func.__name__, root, params, func, is_class_method)
                if is_class_method:
                    # Indent every line so the function sits inside the class body.
                    for line in code.split("\n"):
                        file.write(f" {line}\n")
                else:
                    file.write(code)
                file.write("\n")

        return self.generated_file_path

    def build_func_trees(self, functions: list):
        """Build a binary tree per function, sized by its parameter count (excluding self).

        Returns:
            A list of ``(function, root_node, params)`` triples.
        """
        tree_list = []
        for _name, func in functions:
            signature = Inspector.get_signature(func)
            params = Inspector.get_params_not_self(signature)
            root = Node(None)
            build_binary_tree(root, 0, len(params))
            tree_list.append((func, root, params))
        return tree_list

    @staticmethod
    def get_filename(filepath: str) -> str:
        """Get filename from filepath."""
        return os.path.basename(filepath)
|