testgenie-py 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. testgen/__init__.py +0 -0
  2. testgen/analyzer/__init__.py +0 -0
  3. testgen/analyzer/ast_analyzer.py +149 -0
  4. testgen/analyzer/contracts/__init__.py +0 -0
  5. testgen/analyzer/contracts/contract.py +13 -0
  6. testgen/analyzer/contracts/no_exception_contract.py +16 -0
  7. testgen/analyzer/contracts/nonnull_contract.py +15 -0
  8. testgen/analyzer/fuzz_analyzer.py +106 -0
  9. testgen/analyzer/random_feedback_analyzer.py +291 -0
  10. testgen/analyzer/reinforcement_analyzer.py +75 -0
  11. testgen/analyzer/test_case_analyzer.py +46 -0
  12. testgen/analyzer/test_case_analyzer_context.py +58 -0
  13. testgen/controller/__init__.py +0 -0
  14. testgen/controller/cli_controller.py +194 -0
  15. testgen/controller/docker_controller.py +169 -0
  16. testgen/docker/Dockerfile +22 -0
  17. testgen/docker/poetry.lock +361 -0
  18. testgen/docker/pyproject.toml +22 -0
  19. testgen/generator/__init__.py +0 -0
  20. testgen/generator/code_generator.py +66 -0
  21. testgen/generator/doctest_generator.py +208 -0
  22. testgen/generator/generator.py +55 -0
  23. testgen/generator/pytest_generator.py +77 -0
  24. testgen/generator/test_generator.py +26 -0
  25. testgen/generator/unit_test_generator.py +84 -0
  26. testgen/inspector/__init__.py +0 -0
  27. testgen/inspector/inspector.py +61 -0
  28. testgen/main.py +13 -0
  29. testgen/models/__init__.py +0 -0
  30. testgen/models/analysis_context.py +56 -0
  31. testgen/models/function_metadata.py +61 -0
  32. testgen/models/generator_context.py +63 -0
  33. testgen/models/test_case.py +8 -0
  34. testgen/presentation/__init__.py +0 -0
  35. testgen/presentation/cli_view.py +12 -0
  36. testgen/q_table/global_q_table.json +1 -0
  37. testgen/reinforcement/__init__.py +0 -0
  38. testgen/reinforcement/abstract_state.py +7 -0
  39. testgen/reinforcement/agent.py +153 -0
  40. testgen/reinforcement/environment.py +215 -0
  41. testgen/reinforcement/statement_coverage_state.py +33 -0
  42. testgen/service/__init__.py +0 -0
  43. testgen/service/analysis_service.py +260 -0
  44. testgen/service/cfg_service.py +55 -0
  45. testgen/service/generator_service.py +169 -0
  46. testgen/service/service.py +389 -0
  47. testgen/sqlite/__init__.py +0 -0
  48. testgen/sqlite/db.py +84 -0
  49. testgen/sqlite/db_service.py +219 -0
  50. testgen/tree/__init__.py +0 -0
  51. testgen/tree/node.py +7 -0
  52. testgen/tree/tree_utils.py +79 -0
  53. testgen/util/__init__.py +0 -0
  54. testgen/util/coverage_utils.py +168 -0
  55. testgen/util/coverage_visualizer.py +154 -0
  56. testgen/util/file_utils.py +110 -0
  57. testgen/util/randomizer.py +122 -0
  58. testgen/util/utils.py +143 -0
  59. testgen/util/z3_utils/__init__.py +0 -0
  60. testgen/util/z3_utils/ast_to_z3.py +99 -0
  61. testgen/util/z3_utils/branch_condition.py +72 -0
  62. testgen/util/z3_utils/constraint_extractor.py +36 -0
  63. testgen/util/z3_utils/variable_finder.py +10 -0
  64. testgen/util/z3_utils/z3_test_case.py +94 -0
  65. testgenie_py-0.1.0.dist-info/METADATA +24 -0
  66. testgenie_py-0.1.0.dist-info/RECORD +68 -0
  67. testgenie_py-0.1.0.dist-info/WHEEL +4 -0
  68. testgenie_py-0.1.0.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,260 @@
1
+ import inspect
2
+ import ast
3
+ import time
4
+ from types import ModuleType
5
+ from typing import Dict, List
6
+
7
+ import testgen
8
+ import testgen.util.file_utils
9
+ import testgen.util.file_utils as file_utils
10
+ import testgen.util.utils
11
+ from testgen.analyzer.ast_analyzer import ASTAnalyzer
12
+ from testgen.analyzer.fuzz_analyzer import FuzzAnalyzer
13
+ from testgen.analyzer.random_feedback_analyzer import RandomFeedbackAnalyzer
14
+ from testgen.models.test_case import TestCase
15
+ from testgen.analyzer.test_case_analyzer_context import TestCaseAnalyzerContext
16
+ from testgen.reinforcement.agent import ReinforcementAgent
17
+ from testgen.reinforcement.environment import ReinforcementEnvironment
18
+ from testgen.reinforcement.statement_coverage_state import StatementCoverageState
19
+ from testgen.models.analysis_context import AnalysisContext
20
+ from testgen.models.function_metadata import FunctionMetadata
21
+
22
# Constants for test strategies
AST_STRAT = 1
FUZZ_STRAT = 2
RANDOM_STRAT = 3
REINFORCE_STRAT = 4


class AnalysisService:
    """Analyze a Python source file and produce test cases for its functions.

    ``test_strategy`` holds one of the ``*_STRAT`` constants defined above.
    The AST, fuzz and random-feedback strategies run through a
    ``TestCaseAnalyzerContext``; the reinforcement strategy runs
    :meth:`do_reinforcement_learning` directly.
    """

    def __init__(self):
        self.file_path = None
        # When set, get_function_data only keeps methods of the class with this name.
        self.class_name = None
        self.test_case_analyzer_context = TestCaseAnalyzerContext(None, None)
        self.test_strategy = 0  # 0 = no strategy selected yet (constants start at 1)
        # "train" runs Q-learning; any other value only collects test cases.
        self.reinforcement_mode = "train"

    def generate_test_cases(self) -> List[TestCase]:
        """Generate test cases using the current strategy.

        Returns:
            The reinforcement-learning results when the strategy is
            ``REINFORCE_STRAT``; otherwise the ``test_cases`` gathered by the
            analyzer context after running its ``do_logic()``.
        """
        if self.test_strategy == REINFORCE_STRAT:
            return self.do_reinforcement_learning(self.file_path)
        self.test_case_analyzer_context.do_logic()
        return self.test_case_analyzer_context.test_cases

    def create_analysis_context(self, filepath: str) -> AnalysisContext:
        """Load *filepath* as a module and bundle it with its function metadata."""
        filename = file_utils.get_filename(filepath)
        module = file_utils.load_module(filepath)
        class_name = self.get_class_name(module)
        function_data = self.get_function_data(filename, module, class_name)
        return AnalysisContext(filepath, filename, class_name, module, function_data)

    def get_function_data(self, filename: str, module: ModuleType, class_name: str | None) -> List[FunctionMetadata]:
        """Collect metadata for the functions and public methods in *module*'s source.

        Module-level ``def`` statements are always included. For each top-level
        class, only methods whose names do not start with ``_`` are included,
        and only for the class named ``self.class_name`` when that attribute is
        set. ``async def`` functions are not collected. Every entry is tagged
        with the *class_name* argument.
        """
        # NOTE(review): the class filter uses self.class_name while entries are
        # tagged with the class_name argument; confirm this split is intended.
        function_metadata_list: List[FunctionMetadata] = []

        tree = ast.parse(inspect.getsource(module))

        for node in tree.body:
            if isinstance(node, ast.FunctionDef):
                function_metadata_list.append(self._create_function_metadata(filename, module, class_name, node))
            elif isinstance(node, ast.ClassDef):
                if self.class_name and node.name != self.class_name:
                    continue
                for class_node in node.body:
                    # Private/dunder methods (leading underscore) are skipped.
                    if isinstance(class_node, ast.FunctionDef) and not class_node.name.startswith('_'):
                        function_metadata_list.append(
                            self._create_function_metadata(filename, module, class_name, class_node))

        return function_metadata_list

    def do_reinforcement_learning(self, filepath: str, mode: str | None = None) -> List[TestCase]:
        """Run the reinforcement-learning agent over every function in *filepath*.

        Args:
            filepath: Source file to analyze.
            mode: ``"train"`` runs ``agent.do_q_learning()``; any other value runs
                ``agent.collect_test_cases()``. Defaults to ``self.reinforcement_mode``.

        Returns:
            All test cases found, de-duplicated per function on (func_name, inputs).

        The Q-table is loaded from and saved back to ``q_table/global_q_table.json``
        relative to the current working directory.
        """
        mode = mode or self.reinforcement_mode
        module: ModuleType = testgen.util.file_utils.load_module(filepath)
        tree: ast.Module = testgen.util.file_utils.load_and_parse_file_for_tree(filepath)
        functions: List[ast.FunctionDef] = testgen.util.utils.get_functions(tree)
        all_test_cases: List[TestCase] = []

        q_table = self._load_q_table()

        for function in functions:
            print(f"\nStarting reinforcement learning for function {function.name}")
            function_test_cases: List[TestCase] = []

            # The environment needs a state and the state needs the environment,
            # so build the environment with a placeholder state, then attach the real one.
            environment = ReinforcementEnvironment(filepath, function, module, function_test_cases, state=StatementCoverageState(None))
            environment.state = StatementCoverageState(environment)

            # The agent starts from the shared (global) Q-table.
            agent = ReinforcementAgent(filepath, environment, function_test_cases, q_table)

            if mode == "train":
                new_test_cases = agent.do_q_learning()
            else:
                new_test_cases = agent.collect_test_cases()
            function_test_cases.extend(new_test_cases)

            print(f"\nTest cases for {function.name}: {len(function_test_cases)}")

            current_coverage: float = environment.run_tests()
            print(f"Current coverage: {function.name}: {current_coverage}")

            q_table.update(agent.q_table)

            unique_test_cases = self._unique_test_cases(function_test_cases)
            all_test_cases.extend(unique_test_cases)
            # Report the coverage actually measured (previously a constant 0.0 was printed).
            print(f"Final coverage for {function.name}: {current_coverage}%")
            print(f"Final test cases for {function.name}: {len(unique_test_cases)}")

        self._save_q_table(q_table)

        print("\nReinforcement Learning Complete")
        print(f"Total test cases found: {len(all_test_cases)}")
        return all_test_cases

    @staticmethod
    def _unique_test_cases(test_cases: List[TestCase]) -> List[TestCase]:
        """Return *test_cases* without duplicates, keeping first occurrences in order.

        Two cases are duplicates when their ``func_name`` and inputs match.
        Inputs are converted to a hashable key first, so nested lists, dicts or
        sets no longer raise ``TypeError``.
        """
        seen = set()
        unique: List[TestCase] = []
        for case in test_cases:
            key = (case.func_name, AnalysisService._freeze(case.inputs))
            if key not in seen:
                seen.add(key)
                unique.append(case)
        return unique

    @staticmethod
    def _freeze(value):
        """Return a hashable stand-in for *value* for use as a de-duplication key.

        Lists and tuples become tuples (so ``[1, 2]`` and ``(1, 2)`` match),
        dicts become a tagged tuple of items, and sets become frozensets.
        Any other unhashable object falls back to its ``repr``.
        """
        if isinstance(value, (list, tuple)):
            return tuple(AnalysisService._freeze(item) for item in value)
        if isinstance(value, dict):
            return (dict, tuple((key, AnalysisService._freeze(item)) for key, item in value.items()))
        if isinstance(value, (set, frozenset)):
            return frozenset(AnalysisService._freeze(item) for item in value)
        try:
            hash(value)
        except TypeError:
            return repr(value)
        return value

    def _create_function_metadata(self, filename: str, module: ModuleType, class_name: str | None,
                                  func_node: ast.FunctionDef) -> FunctionMetadata:
        """Wrap *func_node* and its parameter annotations in a FunctionMetadata."""
        return FunctionMetadata(filename, module, class_name, func_node.name, func_node, self._get_params(func_node))

    def _get_params(self, func_node: ast.FunctionDef) -> Dict[str, str | None]:
        """Map each positional parameter name to its annotation source text.

        Positional-only and regular positional parameters are included, in
        declaration order; ``self`` is skipped. Unannotated parameters map to
        ``None``. Keyword-only parameters and ``*args``/``**kwargs`` are not included.
        """
        param_types: Dict[str, str | None] = {}
        for arg in [*func_node.args.posonlyargs, *func_node.args.args]:
            if arg.arg == 'self':
                continue
            param_types[arg.arg] = ast.unparse(arg.annotation) if arg.annotation else None
        return param_types

    def set_test_strategy(self, strategy: int, module_name: str, class_name: str):
        """Select the analysis strategy and prepare its analyzer for ``self.file_path``.

        *module_name* and *class_name* are accepted for interface compatibility but
        are not used.

        Raises:
            NotImplementedError: *strategy* is not one of the ``*_STRAT`` constants.
                This is raised before any state changes or file loading.
        """
        analyzer_types = {
            AST_STRAT: ASTAnalyzer,
            FUZZ_STRAT: FuzzAnalyzer,
            RANDOM_STRAT: RandomFeedbackAnalyzer,
        }
        if strategy not in analyzer_types and strategy != REINFORCE_STRAT:
            raise NotImplementedError(f"Test strategy {strategy} not implemented")

        self.test_strategy = strategy
        analysis_context = self.create_analysis_context(self.file_path)

        # REINFORCE_STRAT has no analyzer; generate_test_cases dispatches it directly.
        analyzer_type = analyzer_types.get(strategy)
        if analyzer_type is not None:
            analyzer = analyzer_type(analysis_context)
            self.test_case_analyzer_context = TestCaseAnalyzerContext(analysis_context, analyzer)

    def set_file_path(self, path: str):
        """Set the file path for analysis."""
        self.file_path = path

    def set_reinforcement_mode(self, mode: str):
        """Set the default mode used by do_reinforcement_learning ("train" or other)."""
        self.reinforcement_mode = mode

    @staticmethod
    def get_class_name(module: ModuleType) -> str | None:
        """Return the name of a class defined in *module* itself, or None.

        Imported classes are ignored. ``inspect.getmembers`` yields members sorted
        by name, so with several classes the alphabetically first is returned.
        """
        for name, cls in inspect.getmembers(module, inspect.isclass):
            if cls.__module__ == module.__name__:
                return name
        return None

    @staticmethod
    def _save_q_table(q_table):
        """Write *q_table* to ``q_table/global_q_table.json`` (relative to the cwd).

        Keys ``(state, action)`` are serialized as ``"<str(state)>|<action>"``. The
        given table replaces the file's contents. do_reinforcement_learning passes
        the table it loaded from this file plus the agents' updates, so the old
        file does not need to be re-read here. The data is written to a temporary
        file and then moved into place, so a failed dump leaves the previous table
        intact. Errors are printed, not raised.
        """
        import json
        import os

        q_table_dir = "q_table"
        os.makedirs(q_table_dir, exist_ok=True)
        q_table_path = os.path.join(q_table_dir, "global_q_table.json")

        serializable_q_table = {f"{str(state)}|{action}": value for (state, action), value in q_table.items()}

        tmp_path = q_table_path + ".tmp"
        try:
            with open(tmp_path, 'w') as f:
                json.dump(serializable_q_table, f)
            os.replace(tmp_path, q_table_path)
            print(f"Q-table saved to {q_table_path}")
        except (OSError, TypeError, ValueError) as e:
            print(f"Error saving Q-table: {e}")

    @staticmethod
    def _load_q_table():
        """Load the Q-table written by _save_q_table, or return ``{}``.

        Each key is split on its LAST ``|``. The state part is parsed with
        ``ast.literal_eval`` and the action is kept as a string. Entries whose
        key has no separator, or whose state cannot be parsed into a hashable
        value, are skipped individually. A missing or unreadable file yields an
        empty table.
        """
        import json
        import os

        q_table_dir = "q_table"
        q_table_path = os.path.join(q_table_dir, "global_q_table.json")

        if not os.path.exists(q_table_path):
            print(f"No existing Q-table found at {q_table_path}")
            return {}

        try:
            with open(q_table_path, 'r') as f:
                serialized_q_table = json.load(f)

            q_table = {}
            for key, value in serialized_q_table.items():
                # rpartition keeps a '|' inside the state repr from breaking the split.
                state_str, separator, action = key.rpartition('|')
                if not separator:
                    print(f"Skipping invalid state: {key}")
                    continue
                try:
                    state = ast.literal_eval(state_str)
                    q_table[(state, action)] = value
                except (ValueError, SyntaxError, TypeError):
                    # TypeError: the parsed state is unhashable (e.g. a list literal).
                    print(f"Skipping invalid state: {state_str}")

            print(f"Loaded global Q-table with {len(q_table)} entries")
            return q_table
        except Exception as e:
            # Best effort: a corrupt or unexpected file must not stop the run.
            print(f"Error loading Q-table: {e}")
            return {}
@@ -0,0 +1,55 @@
1
+ import os
2
+ from typing import List
3
+ from testgen.models.test_case import TestCase
4
+ from testgen.util.coverage_visualizer import CoverageVisualizer
5
+ from testgen.service.analysis_service import AnalysisService
6
+
7
+
8
class CFGService:
    """Service for generating and managing Control Flow Graph visualizations."""

    def __init__(self):
        self.analysis_service = AnalysisService()
        # Set by initialize_visualizer(); visualize_test_coverage() requires it.
        self.visualizer = None

    def initialize_visualizer(self, service):
        """Create a CoverageVisualizer and pass *service* to its ``set_service``."""
        self.visualizer = CoverageVisualizer()
        self.visualizer.set_service(service)

    @staticmethod
    def create_visualization_directory() -> str:
        """Return ``<cwd>/visualize``, creating the directory if it does not exist."""
        visualization_dir = os.path.join(os.getcwd(), "visualize")
        if not os.path.exists(visualization_dir):
            # exist_ok tolerates the directory appearing between the check and the create.
            os.makedirs(visualization_dir, exist_ok=True)
            print(f"Created visualization directory: {visualization_dir}")
        return visualization_dir

    @staticmethod
    def get_versioned_filename(directory: str, base_filename: str) -> str:
        """Return a PNG path in *directory* that does not exist yet.

        Tries ``<base>.png`` first, then ``<base>_v1.png``, ``<base>_v2.png``, ...
        """
        output_path = os.path.join(directory, f"{base_filename}.png")
        version = 1
        while os.path.exists(output_path):
            output_path = os.path.join(directory, f"{base_filename}_v{version}.png")
            version += 1
        return output_path

    @staticmethod
    def _module_stem(file_path: str) -> str:
        """Return the base name of *file_path* with only a trailing ``.py`` removed."""
        return os.path.basename(file_path).removesuffix('.py')

    def visualize_test_coverage(self, file_path: str, test_cases: List[TestCase]) -> str | None:
        """Render one coverage-colored CFG image per function in *file_path*.

        Images are written to the ``visualize`` directory as
        ``<module>_<function>_coverage[_vN].png``.

        Returns:
            The visualization directory path.

        Raises:
            RuntimeError: initialize_visualizer() has not been called.
        """
        if self.visualizer is None:
            raise RuntimeError("initialize_visualizer() must be called before visualize_test_coverage()")

        visualization_dir = self.create_visualization_directory()
        analysis_context = self.analysis_service.create_analysis_context(file_path)

        # str.replace('.py', '') would also strip ".py" from the middle of a name.
        filename = self._module_stem(file_path)

        for func in analysis_context.function_data:
            self.visualizer.get_covered_lines(file_path, func.func_def, test_cases)
            base_filename = f"{filename}_{func.function_name}_coverage"
            output_filepath = self.get_versioned_filename(visualization_dir, base_filename)
            self.visualizer.generate_colored_cfg(func.function_name, output_filepath)

        print(f"Generated CFG visualizations in {visualization_dir}")
        return visualization_dir
@@ -0,0 +1,169 @@
1
+ import os
2
+ from types import ModuleType
3
+ from typing import List
4
+
5
+ from testgen.models.test_case import TestCase
6
+ from testgen.generator.code_generator import CodeGenerator
7
+ from testgen.generator.doctest_generator import DocTestGenerator
8
+ from testgen.generator.pytest_generator import PyTestGenerator
9
+ from testgen.generator.unit_test_generator import UnitTestGenerator
10
+ from testgen.inspector.inspector import Inspector
11
+ from testgen.tree.node import Node
12
+ from testgen.tree.tree_utils import build_binary_tree
13
+ from testgen.models.generator_context import GeneratorContext
14
+
15
# Constants for test formats
UNITTEST_FORMAT = 1
PYTEST_FORMAT = 2
DOCTEST_FORMAT = 3


class GeneratorService:
    """Write generated tests (unittest, pytest or doctest) and stub code for a source file."""

    # Labels used in the "SETTING TEST FORMAT TO ..." message; also the set of valid formats.
    _FORMAT_LABELS = {
        UNITTEST_FORMAT: "UNITTEST",
        PYTEST_FORMAT: "PYTEST",
        DOCTEST_FORMAT: "DOCTEST",
    }

    def __init__(self, filepath: str, output_path: str, test_format: int = UNITTEST_FORMAT):
        self.filepath = filepath
        self.output_path = output_path
        self.test_format = test_format
        self.code_generator = CodeGenerator()
        # Use the generator matching test_format. Unknown formats fall back to the
        # unittest generator; set_test_format() is the strict entry point.
        self.test_generator = self._create_test_generator(test_format)
        self.generated_file_path = None

    @staticmethod
    def _create_test_generator(test_format: int):
        """Return a fresh generator for *test_format* (unittest for anything unrecognized)."""
        if test_format == PYTEST_FORMAT:
            return PyTestGenerator(generator_context=None)
        if test_format == DOCTEST_FORMAT:
            return DocTestGenerator(generator_context=None)
        return UnitTestGenerator(generator_context=None)

    def set_test_format(self, test_format: int):
        """Switch to the generator for *test_format*.

        Raises:
            NotImplementedError: *test_format* is not a known format. The current
                format and generator are left unchanged in that case.
        """
        label = self._FORMAT_LABELS.get(test_format)
        if label is None:
            raise NotImplementedError(f"Test format {test_format} not implemented")
        print(f"SETTING TEST FORMAT TO {label}")
        self.test_format = test_format
        self.test_generator = self._create_test_generator(test_format)

    def generate_test_file(self, module: ModuleType, class_name: str | None, test_cases: List[TestCase], output_path=None) -> str:
        """Generate and save a test file for *test_cases*.

        Returns:
            ``self.filepath`` for the doctest format; otherwise the test file path
            chosen by get_test_file_path().
        """
        filename = self.get_filename(self.filepath)
        output_path = self.get_test_file_path(module.__name__, output_path)

        context = GeneratorContext(
            filepath=self.filepath,
            filename=filename,
            class_name=class_name,
            module=module,
            output_path=output_path,
            test_cases=test_cases
        )
        self.test_generator.generator_context = context

        self.test_generator.generate_test_header()
        print("GENERATE TEST FILE: Generated test header")

        self.generate_function_tests(test_cases)
        print("GENERATE TEST FILE: Generate function tests")
        print()

        is_doctest = self.test_format == DOCTEST_FORMAT
        if is_doctest:
            print("SAVING DOCT TEST FILE")
        self.test_generator.save_file()
        return self.filepath if is_doctest else output_path

    def generate_function_tests(self, test_cases: List[TestCase]) -> None:
        """Emit one test function per test case, named ``<func_name>_<index>``."""
        for index, test_case in enumerate(test_cases):
            unique_func_name = f"{test_case.func_name}_{index}"
            cases = [(test_case.inputs, test_case.expected)]
            self.test_generator.generate_test_function(unique_func_name, test_case.func_name, cases)

    def get_test_file_path(self, module_name: str, specified_path=None) -> str:
        """Return the path of ``test_<module_name lowercased>.py``.

        If *specified_path* is an existing directory, use it. Otherwise use
        ``<cwd>/tests``, creating it if missing. An ``__init__.py`` is ensured in
        whichever directory is used. When ``<cwd>/tests`` exists but is not a
        directory, a message is printed and the path under it is still returned.
        """
        test_filename = f"test_{module_name.lower()}.py"

        if specified_path is not None:
            if os.path.isdir(specified_path):
                self.ensure_init_py(specified_path)
                return os.path.join(specified_path, test_filename)
            if os.path.exists(specified_path):
                print(f"Specified directory path: {specified_path} is not a directory.")
            else:
                print(f"Specified directory path: {specified_path} does not exist.")

        test_dir = os.path.join(os.getcwd(), "tests")
        if not os.path.exists(test_dir):
            print(f"Test directory path: {test_dir} does not exist, creating it.")
            os.mkdir(test_dir)
            self.ensure_init_py(test_dir)
        elif os.path.isdir(test_dir):
            self.ensure_init_py(test_dir)
        else:
            print(f"Test directory path: {test_dir} is not a directory.")

        return os.path.join(test_dir, test_filename)

    def ensure_init_py(self, directory: str):
        """Create an empty ``__init__.py`` in *directory* if one does not exist."""
        init_file = os.path.join(directory, "__init__.py")
        if not os.path.exists(init_file):
            with open(init_file, "w"):
                pass
            print(f"Created __init__.py in {directory}")

    def generate_function_code(self, file_path: str, class_name: str | None, functions: list) -> str:
        """Write generated implementations of *functions* to a ``generated_*`` file.

        With a *class_name*, the file is ``generated_<class_name lowercased>.py``.
        It is started by ``code_generator.generate_class`` and each function's
        code is written indented as a member. Without one, the file is
        ``generated_<basename of file_path>``, truncated first.

        Returns:
            The generated file path (also stored in ``self.generated_file_path``).
        """
        trees = self.build_func_trees(functions)
        is_class_method = bool(class_name)

        if is_class_method:
            print(class_name)
            self.generated_file_path = f"generated_{class_name.lower()}.py"
            # generate_class returns a file object; close it before appending below.
            self.code_generator.generate_class(class_name).close()
        else:
            print(self.get_filename(file_path))
            self.generated_file_path = f"generated_{self.get_filename(file_path)}"
            # Create/truncate the output so function code is appended to an empty file.
            with open(self.generated_file_path, "w"):
                pass

        with open(self.generated_file_path, "a") as file:
            for func, root, params in trees:
                code = self.code_generator.generate_code_from_tree(func.__name__, root, params, func, is_class_method)
                if is_class_method:
                    # NOTE(review): member indent is the single space shown in this listing;
                    # confirm it matches the indentation generate_class expects.
                    for line in code.split("\n"):
                        file.write(f" {line}\n")
                else:
                    file.write(code)
                file.write("\n")

        return self.generated_file_path

    def build_func_trees(self, functions: list):
        """Return ``(func, root, params)`` for each ``(name, func)`` pair in *functions*.

        ``root`` is a binary tree built with one level per non-self parameter.
        """
        tree_list = []
        for _name, func in functions:
            signature = Inspector.get_signature(func)
            params = Inspector.get_params_not_self(signature)
            root = Node(None)
            build_binary_tree(root, 0, len(params))
            tree_list.append((func, root, params))
        return tree_list

    @staticmethod
    def get_filename(filepath: str) -> str:
        """Get filename from filepath."""
        return os.path.basename(filepath)