testgenie-py 0.3.7__py3-none-any.whl → 0.3.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. testgen/analyzer/ast_analyzer.py +2 -11
  2. testgen/analyzer/fuzz_analyzer.py +1 -6
  3. testgen/analyzer/random_feedback_analyzer.py +20 -293
  4. testgen/analyzer/reinforcement_analyzer.py +59 -57
  5. testgen/analyzer/test_case_analyzer_context.py +0 -6
  6. testgen/controller/cli_controller.py +35 -29
  7. testgen/controller/docker_controller.py +1 -0
  8. testgen/db/dao.py +68 -0
  9. testgen/db/dao_impl.py +226 -0
  10. testgen/{sqlite → db}/db.py +15 -6
  11. testgen/generator/pytest_generator.py +2 -10
  12. testgen/generator/unit_test_generator.py +2 -11
  13. testgen/main.py +1 -3
  14. testgen/models/coverage_data.py +56 -0
  15. testgen/models/db_test_case.py +65 -0
  16. testgen/models/function.py +56 -0
  17. testgen/models/function_metadata.py +11 -1
  18. testgen/models/generator_context.py +32 -2
  19. testgen/models/source_file.py +29 -0
  20. testgen/models/test_result.py +38 -0
  21. testgen/models/test_suite.py +20 -0
  22. testgen/reinforcement/agent.py +1 -27
  23. testgen/reinforcement/environment.py +11 -93
  24. testgen/reinforcement/statement_coverage_state.py +5 -4
  25. testgen/service/analysis_service.py +31 -22
  26. testgen/service/cfg_service.py +3 -1
  27. testgen/service/coverage_service.py +115 -0
  28. testgen/service/db_service.py +140 -0
  29. testgen/service/generator_service.py +77 -20
  30. testgen/service/logging_service.py +2 -2
  31. testgen/service/service.py +62 -231
  32. testgen/service/test_executor_service.py +145 -0
  33. testgen/util/coverage_utils.py +38 -116
  34. testgen/util/coverage_visualizer.py +10 -9
  35. testgen/util/file_utils.py +10 -111
  36. testgen/util/randomizer.py +0 -26
  37. testgen/util/utils.py +197 -38
  38. {testgenie_py-0.3.7.dist-info → testgenie_py-0.3.8.dist-info}/METADATA +1 -1
  39. testgenie_py-0.3.8.dist-info/RECORD +72 -0
  40. testgen/inspector/inspector.py +0 -59
  41. testgen/presentation/__init__.py +0 -0
  42. testgen/presentation/cli_view.py +0 -12
  43. testgen/sqlite/__init__.py +0 -0
  44. testgen/sqlite/db_service.py +0 -239
  45. testgen/testgen.db +0 -0
  46. testgenie_py-0.3.7.dist-info/RECORD +0 -67
  47. testgen/{inspector → db}/__init__.py +0 -0
  48. {testgenie_py-0.3.7.dist-info → testgenie_py-0.3.8.dist-info}/WHEEL +0 -0
  49. {testgenie_py-0.3.7.dist-info → testgenie_py-0.3.8.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,115 @@
1
+ import os
2
+ import subprocess
3
+ import time
4
+
5
+ from testgen.models.coverage_data import CoverageData
6
+ from testgen.service.logging_service import get_logger
7
+ from testgen.service.db_service import DBService
8
+
9
class CoverageService:
    """Measure test coverage of a source file by shelling out to coverage.py."""

    # Test-format identifiers (not referenced inside this class itself).
    UNITTEST_FORMAT = 1
    PYTEST_FORMAT = 2
    DOCTEST_FORMAT = 3

    def __init__(self):
        self.logger = get_logger()

    def run_coverage(self, test_file: str, source_file: str) -> CoverageData:
        """Run *test_file* under coverage and return the figures for *source_file*.

        Args:
            test_file: path of the test script to execute.
            source_file: path of the file whose coverage is reported.

        Returns:
            The CoverageData produced by ``collect_coverage``.

        Raises:
            RuntimeError: wrapping (and chained from) any exception that
                escapes ``collect_coverage``.
        """
        # self.wait_for_file(test_file)
        self.logger.info(f"Running coverage on test file: {test_file}")
        self.logger.info(f"Source file to measure: {source_file}")

        try:
            return self.collect_coverage(test_file, source_file)
        except Exception as e:
            self.logger.error(f"Error running coverage: {str(e)}")
            # "from e" keeps the original exception attached as __cause__.
            raise RuntimeError(f"Error running coverage subprocess: {e}") from e

    def collect_coverage(self, test_file: str, source_file: str) -> CoverageData:
        """Execute the coverage.py CLI and parse its text report.

        Runs, in the current working directory: ``coverage run --source=.``
        on *test_file*, ``coverage report`` restricted to *source_file*
        (whose stdout is parsed), and ``coverage json`` (its output is not
        used here).

        Returns:
            The parsed CoverageData; if any command exits with a non-zero
            status, an all-zero CoverageData with ``source_file_id=-1``.
        """
        import sys

        # Use the interpreter that is running testgen rather than whatever
        # "python" resolves to on PATH, which may be another environment
        # (without coverage installed) or may not exist at all.
        python = sys.executable or "python"
        try:
            # NOTE(review): coverage run likely propagates the test script's
            # exit status, so a run with failing tests probably lands in the
            # except branch below - confirm that this is intended.
            subprocess.run([python, "-m", "coverage", "run", "--source=.", test_file], check=True)
            result = subprocess.run(
                [python, "-m", "coverage", "report", source_file],
                check=True,
                capture_output=True,
                text=True,
            )
            coverage_output = result.stdout
            subprocess.run([python, "-m", "coverage", "json"], check=True)
            return self.parse_coverage_data(coverage_output, source_file)
        except subprocess.CalledProcessError as e:
            self.logger.error(f"Error collecting coverage: {str(e)}")
            return CoverageData(
                coverage_type="file",
                executed_lines=0,
                missed_lines=0,
                branch_coverage=0.0,
                source_file_id=-1,
                function_id=None,
            )

    def parse_coverage_data(self, coverage_output: str, file_path: str) -> CoverageData:
        """Extract line-coverage figures for *file_path* from report text.

        Each report row that mentions *file_path* is expected to end with
        three columns: total statements, missed statements, and a percentage
        such as ``85%`` (this assumes the default, non-branch report layout).
        The first such row that parses cleanly wins; if none does, the counts
        stay 0.

        Note: the ``branch_coverage`` field receives that final percentage
        column as a fraction (percentage / 100). TODO(review): confirm that
        storing line coverage under this field name is intended.

        Returns:
            A file-level CoverageData with ``source_file_id=-1``.
        """
        executed_lines = missed_lines = 0
        branch_coverage = 0.0

        for line in coverage_output.strip().split('\n'):
            if file_path not in line:
                continue
            parts = line.split()
            if len(parts) < 4:
                continue
            try:
                total_lines = int(parts[-3])
                missed_lines = int(parts[-2])
                executed_lines = total_lines - missed_lines
                branch_coverage = float(parts[-1].strip('%')) / 100
                break
            except (ValueError, IndexError) as e:
                self.logger.error(f"Error parsing coverage data: {e}")

        # TODO: look up the real source_file_id (e.g. via DBService) if needed.
        source_file_id = -1

        return CoverageData(
            coverage_type="file",
            executed_lines=executed_lines,
            missed_lines=missed_lines,
            branch_coverage=branch_coverage,
            source_file_id=source_file_id,
            function_id=None,
        )

    # Save coverage data to database -- currently not working, so it is kept
    # disabled below. TODO: re-enable once the DB write path works.
    #
    # def save_coverage_data(self, db_service: DBService, coverage_data: CoverageData, file_path: str) -> None:
    #     if db_service is None:
    #         self.logger.debug("Skipping database operations - no DB service provided")
    #         return
    #
    #     try:
    #         source_file_id = db_service.get_source_file_id_by_path(file_path)
    #         if source_file_id == -1:
    #             self.logger.error(f"Source file not found in database: {file_path}")
    #             return
    #
    #         db_service.insert_coverage_data(
    #             file_name=file_path,
    #             executed_lines=coverage_data.executed_lines,
    #             missed_lines=coverage_data.missed_lines,
    #             branch_coverage=coverage_data.branch_coverage,
    #             source_file_id=source_file_id
    #         )
    #     except Exception as e:
    #         self.logger.error(f"Error saving coverage data to database: {e}")

    @staticmethod
    def wait_for_file(file_path, retries=5, delay=1):
        """Wait for the generated file to appear.

        Checks for *file_path* up to *retries* times, sleeping *delay*
        seconds between checks.

        Raises:
            FileNotFoundError: if the file still does not exist afterwards.
        """
        while retries > 0 and not os.path.exists(file_path):
            time.sleep(delay)
            retries -= 1
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File '{file_path}' not found after waiting.")
@@ -0,0 +1,140 @@
1
+ import os
2
+ from datetime import datetime
3
+ from types import ModuleType
4
+ from typing import List, Dict, Any
5
+ import testgen.util.utils as utils
6
+ from testgen.models.db_test_case import DBTestCase
7
+ from testgen.models.function import Function
8
+ from testgen.models.source_file import SourceFile
9
+ from testgen.models.test_case import TestCase
10
+
11
+ from testgen.db.dao_impl import DaoImpl
12
+
13
class DBService:
    """Service facade over ``DaoImpl`` for persisting generated tests and results.

    Most methods delegate one-to-one to the DAO; the helpers at the bottom
    translate generated test cases and execution results into DB rows.
    """

    def __init__(self, db_name="testgen.db"):
        self.dao = DaoImpl(db_name)

    def close(self):
        """Close the database connection."""
        self.dao.close()

    def insert_test_suite(self, name: str) -> int:
        return self.dao.insert_test_suite(name)

    def insert_source_file(self, path: str, lines_of_code: int, last_modified) -> int:
        return self.dao.insert_source_file(path, lines_of_code, last_modified)

    def insert_function(self, name: str, params, start_line: int, end_line: int, source_file_id: int) -> int:
        return self.dao.insert_function(name, params, start_line, end_line, source_file_id)

    def insert_test_case(self, test_case: TestCase, test_suite_id: int, function_id: int, test_method_type: int) -> int:
        return self.dao.insert_test_case(test_case, test_suite_id, function_id, test_method_type)

    def insert_test_result(self, test_case_id: int, status: bool, error: str | None = None) -> int:
        return self.dao.insert_test_result(test_case_id, status, error)

    def insert_coverage_data(self, file_name: str, executed_lines: int, missed_lines: int,
                             branch_coverage: float, source_file_id: int, function_id=None) -> int:
        # Forward the caller's function_id. It used to be hard-coded to None,
        # which silently discarded any value passed in.
        return self.dao.insert_coverage_data(
            file_name, executed_lines, missed_lines, branch_coverage, source_file_id,
            function_id=function_id,
        )

    def get_test_suites(self):
        return self.dao.get_test_suites()

    def get_source_file_id_by_path(self, filepath: str) -> int:
        return self.dao.get_source_file_id_by_path(filepath)

    def get_test_suite_id_by_name(self, name: str) -> int:
        return self.dao.get_test_suite_id_by_name(name)

    def get_functions_by_file(self, filepath: str) -> List[Function]:
        return self.dao.get_functions_by_file(filepath)

    def get_function_id_by_name_file_id_start(self, name: str, source_file_id: int, start_line: int) -> int:
        return self.dao.get_function_by_name_file_id_start(name, source_file_id, start_line)

    def get_test_cases_by_function(self, function_name):
        return self.dao.get_test_cases_by_function(function_name)

    def get_coverage_by_file(self, file_path):
        return self.dao.get_coverage_by_file(file_path)

    def get_test_file_data(self, file_path: str):
        return self.dao.get_test_file_data(file_path)

    def get_test_case_id_by_func_id_input_expected(self, function_id: int, inputs: str, expected: str) -> int:
        return self.dao.get_test_case_id_by_func_id_input_expected(function_id, inputs, expected)

    def save_test_generation_data(self, file_path: str, test_cases: list, test_strategy: int, module: ModuleType, class_name: str | None):
        """Save test generation data to the database.

        Inserts, in order:
        - the source file;
        - a test suite, named after *class_name* or, when that is falsy, the
          module's name;
        - every function found in *file_path*;
        - one test case per entry of *test_cases*, linked to the id of its
          matched function.
        """
        source_file_data = self._get_source_file_data(file_path)
        source_file_id = self.insert_source_file(
            source_file_data.path, source_file_data.lines_of_code, source_file_data.last_modified
        )

        test_suite_name = class_name if class_name else module.__name__
        test_suite_id = self.insert_test_suite(test_suite_name)

        # Functions must be inserted before test cases so that their ids can
        # be looked up during matching.
        function_data = self._get_function_data(file_path)
        for function in function_data:
            self.insert_function(function.name, function.params, function.start_line, function.end_line, source_file_id)

        test_cases_data = self._get_test_cases_data(source_file_id, test_suite_id, function_data, test_cases, test_strategy)
        for test_case in test_cases_data:
            self.insert_test_case(
                TestCase(test_case.test_function, test_case.inputs, test_case.expected_output),
                test_case.test_suite_id,
                test_case.function_id,
                test_strategy,
            )

    def _get_source_file_data(self, file_path: str) -> SourceFile:
        """Return a SourceFile holding *file_path*'s line count and mtime."""
        # A context manager closes the handle deterministically; the bare
        # open() used previously leaked it until garbage collection.
        with open(file_path, 'r') as source:
            lines_of_code = sum(1 for _ in source)
        last_modified_time = os.path.getmtime(file_path)
        return SourceFile(file_path, lines_of_code, last_modified_time)

    def _get_function_data(self, file_path: str) -> List[Function]:
        return utils.get_list_of_functions(file_path)

    def _get_test_cases_data(self, source_file_id: int, test_suite_id: int, function_data: List[Function], test_cases: List[TestCase], test_strategy: int) -> List[DBTestCase]:
        """Convert generated TestCases into DBTestCase rows stamped with the current time."""
        db_test_cases = []
        for test_case in test_cases:
            function_id = self.match_test_case_to_function_for_id(source_file_id, test_case, function_data)
            db_test_case = DBTestCase(test_case.expected, test_case.inputs, test_case.func_name, datetime.now(), test_strategy, test_suite_id, function_id)
            db_test_cases.append(db_test_case)
        return db_test_cases

    def match_test_case_to_function_for_id(self, source_file_id: int, test_case: TestCase, functions: List[Function]) -> int:
        """Return the DB id of the function that *test_case* exercises, or -1.

        The test case's (possibly dotted) function name is compared with the
        last dotted segment of each function name. If exactly one function
        matches, its id is returned. If several match, the first whose
        parameter count equals the number of test inputs wins.
        """
        import ast

        func_name = test_case.func_name.split(".")[-1]

        # Compare whole name segments: an endswith() check would also match
        # unrelated functions such as "badd" for a test of "add".
        candidate_functions = [f for f in functions if f.name.split(".")[-1] == func_name]

        if not candidate_functions:
            return -1

        if len(candidate_functions) == 1:
            only = candidate_functions[0]
            return self.get_function_id_by_name_file_id_start(only.name, source_file_id, only.start_line)

        # Disambiguate by parameter count. NOTE(review): any non-dict inputs
        # value counts as one argument - confirm for tuple/list inputs.
        input_param_count = len(test_case.inputs) if isinstance(test_case.inputs, dict) else 1

        for function in candidate_functions:
            params = function.params
            if isinstance(params, str):
                # Params may be stored as the string form of a dict.
                # literal_eval only accepts Python literals, so unlike eval()
                # it cannot execute code read back from the database.
                try:
                    params = ast.literal_eval(params)
                except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                    continue

            param_count = len(params) if isinstance(params, dict) else 0
            if param_count == input_param_count:
                return self.get_function_id_by_name_file_id_start(function.name, source_file_id, function.start_line)

        return -1

    def _get_test_results(self, filepath: str, execution_results, test_format: int):
        """Insert one test-result row per entry of *execution_results*.

        Each result's ``name`` is parsed back into a test case, which is
        matched to a function and then to its stored test-case id; the
        result's ``status`` and ``error`` are recorded against that id.
        """
        source_file_id = self.get_source_file_id_by_path(filepath)
        functions = self.get_functions_by_file(filepath)
        for result in execution_results:
            test_case = utils.parse_test_case_from_result_name(result.name, test_format)
            function_id = self.match_test_case_to_function_for_id(source_file_id, test_case, functions)
            test_case_id = self.get_test_case_id_by_func_id_input_expected(function_id, str(test_case.inputs), test_case.expected)
            self.insert_test_result(test_case_id, result.status, result.error)
@@ -8,11 +8,11 @@ from testgen.generator.code_generator import CodeGenerator
8
8
  from testgen.generator.doctest_generator import DocTestGenerator
9
9
  from testgen.generator.pytest_generator import PyTestGenerator
10
10
  from testgen.generator.unit_test_generator import UnitTestGenerator
11
- from testgen.inspector.inspector import Inspector
12
11
  from testgen.service.logging_service import get_logger
13
12
  from testgen.tree.node import Node
14
13
  from testgen.tree.tree_utils import build_binary_tree
15
14
  from testgen.models.generator_context import GeneratorContext
15
+ from testgen.util.file_utils import find_project_root
16
16
 
17
17
  # Constants for test formats
18
18
  UNITTEST_FORMAT = 1
@@ -49,23 +49,9 @@ class GeneratorService:
49
49
  filename = self.get_filename(self.filepath)
50
50
  output_path = self.get_test_file_path(module.__name__, output_path)
51
51
 
52
- # Determine the actual class name used in the module
53
- actual_class_name = class_name
54
- if 'generated_' in self.filepath and class_name:
55
- # For generated classes, find the actual class name in the module
56
- for name, obj in inspect.getmembers(module):
57
- if inspect.isclass(obj):
58
- actual_class_name = name
59
- break
52
+ actual_class_name = self.resolve_class_name(module, class_name)
60
53
 
61
- context = GeneratorContext(
62
- filepath=self.filepath,
63
- filename=filename,
64
- class_name=actual_class_name, # Use the actual class name
65
- module=module,
66
- output_path=output_path,
67
- test_cases=test_cases
68
- )
54
+ context = self.get_generator_context(self.filepath, module, actual_class_name, test_cases, output_path)
69
55
 
70
56
  self.test_generator.generator_context = context
71
57
 
@@ -73,7 +59,6 @@ class GeneratorService:
73
59
 
74
60
  self.generate_function_tests(test_cases)
75
61
 
76
-
77
62
  if self.test_format == DOCTEST_FORMAT:
78
63
  self.logger.debug("SAVING DOCT TEST FILE")
79
64
  self.test_generator.save_file()
@@ -172,17 +157,89 @@ class GeneratorService:
172
157
 
173
158
  return self.generated_file_path
174
159
 
160
def get_generator_context(self, filepath: str, module: ModuleType, class_name: str | None, test_cases: List[TestCase], output_path: str) -> GeneratorContext:
    """Build the GeneratorContext handed to the test generator for *filepath*.

    ``is_package`` is True when the file's directory contains an
    ``__init__.py``. ``package_name`` and ``import_path`` come from
    ``_resolve_package_name_and_import_path``.
    """
    file_dir = os.path.dirname(os.path.abspath(filepath))
    is_package = os.path.exists(os.path.join(file_dir, '__init__.py'))

    package_name, import_path = self._resolve_package_name_and_import_path(filepath)

    generator_context = GeneratorContext(
        filepath=filepath,
        filename=self.get_filename(filepath),
        class_name=class_name,
        module=module,
        output_path=output_path,
        test_cases=test_cases,
        is_package=is_package,
        package_name=package_name,
        import_path=import_path,
    )

    # Diagnostics go through the service logger rather than a bare print(),
    # so they respect the configured log level and do not pollute stdout.
    self.logger.debug(f"Generator Context: {generator_context.filepath} {generator_context.filename} {generator_context.class_name} {generator_context.is_package} {generator_context.package_name} {generator_context.import_path}")

    return generator_context
181
+
182
+ @staticmethod
183
+ def _resolve_package_name_and_import_path(filepath: str) -> tuple:
184
+ if not os.path.exists(filepath) or not filepath.endswith('.py'):
185
+ raise ValueError(f"Invalid Python file: {filepath}")
186
+
187
+ # Get the directory and filename
188
+ file_dir = os.path.dirname(os.path.abspath(filepath))
189
+ module_name = os.path.splitext(os.path.basename(filepath))[0]
190
+
191
+ # Find the project root by looking for setup.py or a .git directory
192
+ project_root = find_project_root(file_dir)
193
+
194
+ # Build the import path based on the file's location relative to the project root
195
+ if project_root:
196
+ rel_path = os.path.relpath(file_dir, project_root)
197
+ if rel_path == '.':
198
+ # File is directly in the project root
199
+ import_path = module_name
200
+ package_name = ''
201
+ else:
202
+ # File is in a subdirectory
203
+ path_parts = rel_path.replace('\\', '/').split('/')
204
+ # Filter out any empty parts
205
+ path_parts = [part for part in path_parts if part]
206
+
207
+ if path_parts:
208
+ package_name = path_parts[0]
209
+ # Construct the full import path
210
+ import_path = '.'.join(path_parts) + '.' + module_name
211
+ else:
212
+ package_name = ''
213
+ import_path = module_name
214
+ else:
215
+ # Fallback if we can't find a project root
216
+ package_name = ''
217
+ import_path = module_name
218
+
219
+ return package_name, import_path
220
+
def build_func_trees(self, functions: list):
    """Build binary trees for function signatures.

    Args:
        functions: iterable of ``(name, func)`` pairs; the name is not used.

    Returns:
        A list of ``(func, root, params)`` tuples, where:
        - ``params`` lists the function's parameter names other than ``self``;
        - ``root`` is a ``Node(None)`` filled in by
          ``build_binary_tree(root, 0, len(params))``.
    """
    tree_list = []
    for _name, func in functions:
        signature = inspect.signature(func)
        # Iterating the mapping yields parameter names directly; the
        # Parameter objects themselves are not needed here.
        params = [param for param in signature.parameters if param != 'self']
        root = Node(None)
        build_binary_tree(root, 0, len(params))
        tree_list.append((func, root, params))
    return tree_list
185
231
 
232
+ def resolve_class_name(self, module: ModuleType, class_name: str):
233
+ # Determine the actual class name used in the module
234
+ actual_class_name = class_name
235
+ if 'generated_' in self.filepath and class_name:
236
+ # For generated classes, find the actual class name in the module
237
+ for name, obj in inspect.getmembers(module):
238
+ if inspect.isclass(obj):
239
+ actual_class_name = name
240
+ break
241
+ return actual_class_name
242
+
186
243
  @staticmethod
187
244
  def get_filename(filepath: str) -> str:
188
245
  """Get filename from filepath."""
@@ -96,7 +96,7 @@ def get_logger():
96
96
  logger = LoggingService.get_instance()
97
97
 
98
98
  # If logger hasn't been initialized yet, set up a basic configuration
99
- if not LoggingService._initialized:
100
- logger.initialize(debug_mode=False, console_output=True)
99
+ """if not LoggingService._initialized:
100
+ logger.initialize(debug_mode=False, console_output=True)"""
101
101
 
102
102
  return logger