testgenie-py 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,100 @@
1
+ import logging
2
+ import os
3
+ import sys
4
+ from enum import Enum
5
+ from typing import Optional
6
+
7
+
8
class LogLevel(Enum):
    """Logging severities used by the testgen framework.

    Member values are the numeric level constants of the standard ``logging``
    module, so ``LogLevel.<NAME>.value`` can be passed directly to
    ``logging.Logger.setLevel``.
    """

    DEBUG = logging.DEBUG          # detailed diagnostic output
    INFO = logging.INFO            # routine progress messages
    WARNING = logging.WARNING      # unexpected situation, execution continues
    ERROR = logging.ERROR          # an operation failed
    CRITICAL = logging.CRITICAL    # most severe level
14
+
15
+
16
class LoggingService:
    """Centralized logging service for the testgen framework.

    Wraps the standard-library logger named ``'testgen'``. One shared
    instance is handed out by :meth:`get_instance`. :meth:`initialize`
    attaches console and/or file handlers only once; this is guarded by the
    class-level ``_initialized`` flag.
    """

    # Shared instance returned by get_instance(); created lazily.
    _instance = None
    # True once initialize() has completed; later initialize() calls are no-ops.
    _initialized = False

    @classmethod
    def get_instance(cls):
        """Return the shared LoggingService, creating it on the first call."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        # logging.getLogger returns the same object for the same name, so every
        # LoggingService instance shares the 'testgen' logger and its handlers.
        self.logger = logging.getLogger('testgen')
        self.formatter = logging.Formatter(
            '%(asctime)s [%(levelname)s] %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )
        # Populated by initialize() when the corresponding output is enabled.
        self.file_handler = None
        self.console_handler = None

    def initialize(self,
                   debug_mode: bool = False,
                   log_file: Optional[str] = None,
                   console_output: bool = True):
        """Configure the level and handlers of the 'testgen' logger.

        Only the first successful call has an effect. Subsequent calls return
        immediately.

        Args:
            debug_mode: Use DEBUG instead of INFO as the logger and handler
                level.
            log_file: Optional path of a log file. Missing parent directories
                are created. A bare file name (no directory part) is created in
                the current working directory.
            console_output: Attach a handler that writes to ``sys.stdout``.

        Raises:
            OSError: The log file or its directory could not be created. No
                handler has been attached at that point and the service stays
                uninitialized, so the call can be retried.
        """
        if LoggingService._initialized:
            return

        # Set the base logging level
        level = LogLevel.DEBUG.value if debug_mode else LogLevel.INFO.value

        # Build the file handler first. It is the only step that can fail, and
        # doing it before touching the logger means a failure cannot leave a
        # half-configured logger (for example, a console handler that a retry
        # would add a second time).
        file_handler = None
        if log_file:
            # os.path.dirname() returns '' for a bare file name, and
            # os.makedirs('') raises FileNotFoundError, so only create a
            # directory when the path actually has one.
            log_dir = os.path.dirname(log_file)
            if log_dir:
                os.makedirs(log_dir, exist_ok=True)
            file_handler = logging.FileHandler(log_file)
            file_handler.setFormatter(self.formatter)
            file_handler.setLevel(level)

        self.logger.setLevel(level)

        # Add console handler if requested
        if console_output:
            self.console_handler = logging.StreamHandler(sys.stdout)
            self.console_handler.setFormatter(self.formatter)
            self.console_handler.setLevel(level)
            self.logger.addHandler(self.console_handler)

        # Attach the (already created) file handler after the console one,
        # preserving the original handler order.
        if file_handler is not None:
            self.file_handler = file_handler
            self.logger.addHandler(file_handler)

        # Mark as initialized
        LoggingService._initialized = True
        self.info(f"Logging initialized - Debug mode: {debug_mode}")

    def debug(self, message: str):
        """Log *message* at DEBUG level."""
        self.logger.debug(message)

    def info(self, message: str):
        """Log *message* at INFO level."""
        self.logger.info(message)

    def warning(self, message: str):
        """Log *message* at WARNING level."""
        self.logger.warning(message)

    def error(self, message: str):
        """Log *message* at ERROR level."""
        self.logger.error(message)

    def critical(self, message: str):
        """Log *message* at CRITICAL level."""
        self.logger.critical(message)
89
+
90
+
91
+ # Global accessor function for easy import and use
92
def get_logger():
    """Return the shared LoggingService, configuring it on first use.

    If ``LoggingService.initialize`` has not been called yet, the service is
    initialized with INFO level and console output only (no log file).
    """
    service = LoggingService.get_instance()
    if LoggingService._initialized:
        return service

    # First caller: fall back to a default, console-only configuration.
    service.initialize(debug_mode=False, console_output=True)
    return service
@@ -14,7 +14,8 @@ from testgen.service.analysis_service import AnalysisService
14
14
  from testgen.service.generator_service import GeneratorService
15
15
  from testgen.sqlite.db_service import DBService
16
16
  from testgen.models.analysis_context import AnalysisContext
17
- from testgen.util.coverage_visualizer import CoverageVisualizer
17
+ from testgen.service.logging_service import get_logger
18
+
18
19
 
19
20
  # Constants for test strategies
20
21
  AST_STRAT = 1
@@ -29,26 +30,32 @@ DOCTEST_FORMAT = 3
29
30
 
30
31
  class Service:
31
32
  def __init__(self):
33
+ self.debug_mode: bool = False
32
34
  self.test_strategy: int = 0
33
35
  self.test_format: int = 0
34
36
  self.file_path = None
35
37
  self.generated_file_path = None
36
38
  self.class_name = None
37
39
  self.test_cases = []
40
+ self.logger = get_logger()
38
41
  self.reinforcement_mode = "train"
39
42
 
40
43
  # Initialize specialized services
41
44
  self.analysis_service = AnalysisService()
42
45
  self.generator_service = GeneratorService(None, None, None)
43
- self.db_service = DBService()
46
+ # Only initialize DB service if not running in Docker
47
+ if os.environ.get("RUNNING_IN_DOCKER") is None:
48
+ self.db_service = DBService()
49
+ else:
50
+ # Create a dummy DB service that doesn't do anything
51
+ self.db_service = None
44
52
 
45
53
  def select_all_from_db(self) -> None:
46
54
  rows = self.db_service.get_test_suites()
47
55
  for row in rows:
48
56
  print(repr(dict(row)))
49
57
 
50
- def generate_tests(self, output_path=None):
51
- """Generate tests for a class or module."""
58
+ def generate_test_cases(self) -> List[TestCase] | None:
52
59
  module = file_utils.load_module(self.file_path)
53
60
  class_name = self.analysis_service.get_class_name(module)
54
61
 
@@ -65,24 +72,38 @@ class Service:
65
72
 
66
73
  test_cases: List[TestCase] = []
67
74
  if self.test_strategy == REINFORCE_STRAT:
68
- test_cases = self.analysis_service.do_reinforcement_learning(self.file_path, self.reinforcement_mode)
75
+ test_cases = self.analysis_service.do_reinforcement_learning(self.file_path, class_name, self.reinforcement_mode)
69
76
  else:
70
77
  test_cases = self.analysis_service.generate_test_cases()
71
78
 
72
- self.test_cases = test_cases
73
-
74
- file_path_to_use = self.generated_file_path if self.test_strategy == AST_STRAT else self.file_path
75
- self.db_service.save_test_generation_data(
76
- file_path_to_use,
77
- test_cases,
78
- self.test_strategy,
79
- class_name
80
- )
81
-
82
79
  if os.environ.get("RUNNING_IN_DOCKER") is not None:
83
- print(f"Serializing test cases {test_cases}")
80
+ self.debug(f"Serializing test cases {test_cases}")
84
81
  self.serialize_test_cases(test_cases)
85
82
  return None # Exit early in analysis-only mode
83
+
84
+ return test_cases
85
+
86
+ def generate_tests(self, output_path=None):
87
+ module = file_utils.load_module(self.file_path)
88
+ class_name = self.analysis_service.get_class_name(module)
89
+
90
+ test_cases = self.generate_test_cases()
91
+
92
+ # Only process if we have test cases
93
+ if test_cases is None:
94
+ return None
95
+
96
+ self.test_cases = test_cases
97
+
98
+ # Only save to DB if not running in Docker
99
+ if os.environ.get("RUNNING_IN_DOCKER") is None:
100
+ file_path_to_use = self.generated_file_path if self.test_strategy == AST_STRAT else self.file_path
101
+ self.db_service.save_test_generation_data(
102
+ file_path_to_use,
103
+ test_cases,
104
+ self.test_strategy,
105
+ class_name
106
+ )
86
107
 
87
108
  test_file = self.generate_test_file(test_cases, output_path, module, class_name)
88
109
 
@@ -112,6 +133,8 @@ class Service:
112
133
  output_path
113
134
  )
114
135
 
136
+ print(f"Generated test file: {test_file}")
137
+
115
138
  # Ensure the test file is ready
116
139
  Service.wait_for_file(test_file)
117
140
  return test_file
@@ -127,6 +150,7 @@ class Service:
127
150
  """Run coverage analysis on the generated tests."""
128
151
  Service.wait_for_file(test_file)
129
152
  file_path_to_use = self.generated_file_path if self.test_strategy == AST_STRAT else self.file_path
153
+ print(f"File path to use for coverage: {file_path_to_use}")
130
154
  coverage_output = ""
131
155
 
132
156
  try:
@@ -166,32 +190,40 @@ class Service:
166
190
 
167
191
  def _save_coverage_data(self, coverage_output, file_path):
168
192
  """Parse coverage output and save to database."""
193
+ # Skip if running in Docker or DB service is None
194
+ if os.environ.get("RUNNING_IN_DOCKER") is not None or self.db_service is None:
195
+ self.debug("Skipping database operations in Docker container")
196
+ return
197
+
169
198
  try:
170
199
  lines = coverage_output.strip().split('\n')
171
- for line in lines:
172
- if file_path in line:
173
- parts = line.split()
174
- if len(parts) >= 4:
175
- file_name = os.path.basename(file_path)
176
- try:
177
- total_lines = int(parts[-3])
178
- missed_lines = int(parts[-2])
179
- executed_lines = total_lines - missed_lines
180
- coverage_str = parts[-1].strip('%')
181
- branch_coverage = float(coverage_str) / 100
182
-
183
- source_file_id = self._get_source_file_id(file_path)
184
-
185
- self.db_service.insert_coverage_data(
186
- file_name,
187
- executed_lines,
188
- missed_lines,
189
- branch_coverage,
190
- source_file_id
191
- )
192
- break
193
- except (ValueError, IndexError) as e:
194
- print(f"Error parsing coverage data: {e}")
200
+ if not lines:
201
+ raise ValueError("No coverage data found in the output.")
202
+ else:
203
+ for line in lines:
204
+ if file_path in line:
205
+ parts = line.split()
206
+ if len(parts) >= 4:
207
+ file_name = os.path.basename(file_path)
208
+ try:
209
+ total_lines = int(parts[-3])
210
+ missed_lines = int(parts[-2])
211
+ executed_lines = total_lines - missed_lines
212
+ coverage_str = parts[-1].strip('%')
213
+ branch_coverage = float(coverage_str) / 100
214
+
215
+ source_file_id = self._get_source_file_id(file_path)
216
+
217
+ self.db_service.insert_coverage_data(
218
+ file_name,
219
+ executed_lines,
220
+ missed_lines,
221
+ branch_coverage,
222
+ source_file_id
223
+ )
224
+ break
225
+ except (ValueError, IndexError) as e:
226
+ print(f"Error parsing coverage data: {e}")
195
227
  except Exception as e:
196
228
  print(f"Error saving coverage data: {e}")
197
229
 
@@ -261,6 +293,7 @@ class Service:
261
293
 
262
294
  def set_file_path(self, path: str):
263
295
  """Set the file path for analysis and validate it."""
296
+ print(f"Setting file path: {path}")
264
297
  if os.path.isfile(path) and path.endswith(".py"):
265
298
  self.file_path = path
266
299
  self.analysis_service.set_file_path(path)
@@ -386,4 +419,11 @@ class Service:
386
419
  def set_reinforcement_mode(self, mode: str):
387
420
  self.reinforcement_mode = mode
388
421
  if hasattr(self ,'analysis_service'):
389
- self.analysis_service.set_reinforcement_mode(mode)
422
+ self.analysis_service.set_reinforcement_mode(mode)
423
+
424
+ def set_debug_mode(self, debug: bool):
425
+ self.debug_mode = debug
426
+
427
+ def debug(self, message: str):
428
+ if self.debug_mode:
429
+ self.logger.debug(message)
testgen/testgen.db ADDED
Binary file
@@ -24,12 +24,39 @@ def get_branch_coverage(file_name, func, *args) -> list:
24
24
  return branches
25
25
 
26
26
 
27
- def get_coverage_analysis(file_name, func_name, args) -> tuple:
27
+ def get_coverage_analysis(file_name, class_name: str | None, func_name, args) -> tuple:
28
28
  tree = load_and_parse_file_for_tree(file_name)
29
29
  func_node = None
30
30
  func_start = None
31
31
  func_end = None
32
+
33
+ # Process tree body
32
34
  for i, node in enumerate(tree.body):
35
+ # Handle class methods
36
+ if isinstance(node, ast.ClassDef) and class_name is not None:
37
+ # Search within class body with its own index
38
+ for j, class_node in enumerate(node.body):
39
+ if isinstance(class_node, ast.FunctionDef) and class_node.name == func_name:
40
+ func_node = class_node
41
+ func_start = class_node.lineno
42
+
43
+ # Now correctly check if this is the last method in the class
44
+ if j == len(node.body) - 1:
45
+ # Last method in class - find maximum line in method
46
+ max_lines = [line.lineno for line in ast.walk(class_node)
47
+ if hasattr(line, 'lineno') and line.lineno]
48
+ func_end = max(max_lines) if max_lines else func_start
49
+ else:
50
+ # Not last method - use next method's line number minus 1
51
+ next_node = node.body[j + 1] # Correct index now
52
+ if hasattr(next_node, 'lineno'):
53
+ func_end = next_node.lineno - 1
54
+ else:
55
+ # Fallback using max line in method
56
+ max_lines = [line.lineno for line in ast.walk(class_node)
57
+ if hasattr(line, 'lineno') and line.lineno]
58
+ func_end = max(max_lines) if max_lines else func_start
59
+ break
33
60
  if isinstance(node, ast.FunctionDef) and node.name == func_name:
34
61
  func_node = node
35
62
  func_start = node.lineno
@@ -55,7 +82,12 @@ def get_coverage_analysis(file_name, func_name, args) -> tuple:
55
82
  cov.start()
56
83
  module = load_module(file_name)
57
84
 
58
- func = getattr(module, func_name)
85
+ if class_name is not None:
86
+ class_obj = getattr(module, class_name)
87
+ instance = class_obj()
88
+ func = getattr(instance, func_name)
89
+ else:
90
+ func = getattr(module, func_name)
59
91
 
60
92
  func(*args)
61
93
 
@@ -64,7 +96,6 @@ def get_coverage_analysis(file_name, func_name, args) -> tuple:
64
96
 
65
97
  analysis = cov.analysis2(file_name)
66
98
  analysis_list = list(analysis)
67
-
68
99
  # Filter executable and missed lines to function range
69
100
  analysis_list[1] = [line for line in analysis_list[1] if func_start <= line <= func_end]
70
101
  analysis_list[3] = [line for line in analysis_list[3] if func_start <= line <= func_end]
@@ -113,21 +144,12 @@ def get_list_of_missed_lines(analysis: tuple) -> list:
113
144
 
114
145
 
115
146
def get_list_of_covered_statements(analysis: tuple) -> list:
    """Return the executable line numbers that were not reported as missed.

    ``analysis[1]`` holds the executable statement line numbers and
    ``analysis[3]`` the missed ones. The order of ``analysis[1]`` is preserved.
    """
    executable, missed = analysis[1], analysis[3]
    return list(filter(lambda line_no: line_no not in missed, executable))
123
148
 
124
149
 
125
- def get_uncovered_lines_for_func(file_name: str, func_node: ast.FunctionDef, test_cases: List[TestCase]) -> List[int]:
150
+ def get_uncovered_lines_for_func(file_name: str, class_name: str | None, func_node: ast.FunctionDef, test_cases: List[TestCase]) -> List[int]:
126
151
  # Get normal uncovered lines
127
152
  func_name = func_node.name
128
- if not test_cases:
129
- print(f"Warning: No test cases provided {func_name}.")
130
- return []
131
153
 
132
154
  function_test_cases = [tc for tc in test_cases if tc.func_name == func_name]
133
155
  if not function_test_cases:
@@ -135,7 +157,12 @@ def get_uncovered_lines_for_func(file_name: str, func_node: ast.FunctionDef, tes
135
157
  return []
136
158
 
137
159
  module = load_module(file_name)
138
- func = getattr(module, func_name)
160
+ if class_name is not None:
161
+ class_obj = getattr(module, class_name)
162
+ instance = class_obj()
163
+ func = getattr(instance, func_name)
164
+ else:
165
+ func = getattr(module, func_name)
139
166
 
140
167
  # Run coverage
141
168
  cov = coverage.Coverage(branch=True) # Enable branch coverage
@@ -16,12 +16,12 @@ class CoverageVisualizer:
16
16
  def set_service(self, service):
17
17
  self.service = service
18
18
 
19
- def get_covered_lines(self, file_path: str, func_def: ast.FunctionDef, test_cases: List[TestCase]):
19
+ def get_covered_lines(self, file_path: str, class_name: str | None, func_def: ast.FunctionDef, test_cases: List[TestCase]):
20
20
  if func_def.name not in self.covered_lines:
21
21
  self.covered_lines[func_def.name] = set()
22
22
 
23
23
  for test_case in [tc for tc in test_cases if tc.func_name == func_def.name]:
24
- analysis = testgen.util.coverage_utils.get_coverage_analysis(file_path, func_def.name, test_case.inputs)
24
+ analysis = testgen.util.coverage_utils.get_coverage_analysis(file_path, class_name, func_def.name, test_case.inputs)
25
25
  covered = testgen.util.coverage_utils.get_list_of_covered_statements(analysis)
26
26
  if covered:
27
27
  self.covered_lines[func_def.name].update(covered)
@@ -108,3 +108,49 @@ def load_and_parse_file_for_tree(file) -> Module:
108
108
  code = f.read()
109
109
  tree = ast.parse(code)
110
110
  return tree
111
+
112
def adjust_file_path_for_docker(file_path) -> str:
    """Resolve *file_path* against the directories mounted in the Docker container.

    Candidates are tried in order and the first existing file wins:

    1. ``file_path`` as given;
    2. ``file_path`` relative to ``/controller`` (root of the mount);
    3. ``file_path`` relative to ``/controller/testgen``;
    4. for a bare file name only, ``/controller/code_to_test`` and
       ``/controller/testgen/code_to_test`` as well.

    If nothing matches, the ``.py`` files found under ``/controller`` are
    printed to help diagnose the mount, and ``file_path`` is returned
    unchanged.
    """
    print(f"Docker - adjusting path: {file_path}")

    # Try direct path first (maybe it's correct already)
    if os.path.isfile(file_path):
        print(f"Docker - found file at direct path: {file_path}")
        return file_path

    candidates = [
        os.path.join('/controller', file_path),          # relative to mount root
        os.path.join('/controller/testgen', file_path),  # relative to the testgen package
    ]
    # If it's just a filename, also search the common source directories.
    if os.path.basename(file_path) == file_path:
        candidates.extend(
            os.path.join(search_dir, file_path)
            for search_dir in ('/controller', '/controller/code_to_test',
                               '/controller/testgen/code_to_test', '/controller/testgen')
        )

    # dict.fromkeys drops repeated candidates while keeping their order.
    for candidate in dict.fromkeys(candidates):
        if os.path.isfile(candidate):
            print(f"Docker - found file at: {candidate}")
            return candidate

    # Debug output to help diagnose mount problems. This walks the tree in
    # Python instead of shelling out to `find ... | grep boolean`, which
    # needed a shell and only ever showed one hard-coded file name.
    print("Docker - available files in /controller:")
    for root, _dirs, files in os.walk('/controller'):
        for name in sorted(files):
            if name.endswith('.py'):
                print(os.path.join(root, name))

    # Return original path if we couldn't find a better match
    print(f"Docker - couldn't find file, returning original: {file_path}")
    return file_path
148
+
149
def get_project_root_in_docker(script_path=None) -> str:
    """Return the project root: the parent of the directory that holds the script.

    Args:
        script_path: Path of the entry-point script. When omitted or None,
            ``sys.argv[0]`` is used. The path is made absolute before its
            directories are taken.

    Returns:
        ``dirname(dirname(abspath(script_path)))``.
    """
    if script_path is None:
        script_path = sys.argv[0]
    script_path = os.path.abspath(script_path)
    print(f"Script path: {script_path}")
    script_dir = os.path.dirname(script_path)
    print(f"Script directory: {script_dir}")
    project_root = os.path.dirname(script_dir)
    print(f"Project root directory: {project_root}")
    return project_root
@@ -1,5 +1,6 @@
1
1
  import ast
2
2
  import random
3
+ import string
3
4
  from typing import List
4
5
 
5
6
  import testgen.util.coverage_utils
@@ -13,12 +14,12 @@ except ImportError:
13
14
  solve_branch_condition = None
14
15
  from testgen.models.test_case import TestCase
15
16
 
16
- def make_random_move(file_name: str, func_node: ast.FunctionDef, test_cases: List[TestCase]) -> List[TestCase]:
17
+ def make_random_move(file_name: str, class_name: str | None, func_node: ast.FunctionDef, test_cases: List[TestCase]) -> List[TestCase]:
17
18
  random_choice = random.randint(1, 4)
18
19
  func_name = func_node.name
19
20
  # new random test case
20
21
  if random_choice == 1:
21
- test_cases.append(new_random_test_case(file_name, func_node))
22
+ test_cases.append(new_random_test_case(file_name, class_name, func_node))
22
23
  # combine test cases
23
24
  if random_choice == 2:
24
25
  test_cases.append(combine_cases(test_cases))
@@ -31,7 +32,7 @@ def make_random_move(file_name: str, func_node: ast.FunctionDef, test_cases: Lis
31
32
  function_test_cases = [tc for tc in test_cases if tc.func_name == func_name]
32
33
 
33
34
  if function_test_cases:
34
- uncovered_lines = testgen.util.coverage_utils.get_uncovered_lines_for_func(file_name, func_name)
35
+ uncovered_lines = testgen.util.coverage_utils.get_uncovered_lines_for_func(file_name, class_name, func_name)
35
36
 
36
37
  if len(uncovered_lines) > 0:
37
38
  z3_test_cases = solve_branch_condition(file_name, func_node, uncovered_lines)
@@ -39,18 +40,32 @@ def make_random_move(file_name: str, func_node: ast.FunctionDef, test_cases: Lis
39
40
 
40
41
  return test_cases
41
42
 
42
def new_random_test_case(file_name: str, class_name: str | None, func_node: ast.FunctionDef) -> TestCase:
    """Build one TestCase for ``func_node`` by calling it with random inputs.

    Parameter types come from ``utils.extract_parameter_types``. For a method
    (``class_name`` given), the ``self`` entry is removed before inputs are
    generated, and the method is invoked on a fresh instance created with no
    constructor arguments. The observed return value becomes the expected
    output of the TestCase.

    Raises:
        Exception: Errors from resolving or calling the target are printed
            and then re-raised.
    """
    name = func_node.name
    types_by_param: dict = utils.extract_parameter_types(func_node)
    if class_name is not None:
        # 'self' is bound by the instance created below, so it must not get a random value.
        types_by_param.pop('self', None)

    generated: dict = utils.generate_random_inputs(types_by_param)
    call_args = [*generated.values()]

    module = testgen.util.file_utils.load_module(file_name)

    try:
        target = (getattr(getattr(module, class_name)(), name)
                  if class_name is not None
                  else getattr(module, name))
        result = target(*call_args)
        return TestCase(name, tuple(call_args), result)
    except Exception as err:
        print(f"Error generating test case for {name}: {err}")
        raise
54
69
 
55
70
  # Should combining test cases preserve the parent cases or entirely replace them?
56
71
  def combine_cases(test_cases: List[TestCase]) -> TestCase:
@@ -87,25 +102,25 @@ def mix_inputs(test_case1: TestCase, test_case2: TestCase) -> tuple:
87
102
 
88
103
  return new_inputs
89
104
 
90
- def get_z3_test_cases(file_name: str, func_node: ast.FunctionDef, test_cases: List[TestCase]) -> List[TestCase]:
105
+ def get_z3_test_cases(file_name: str, class_name: str | None, func_node: ast.FunctionDef, test_cases: List[TestCase]) -> List[TestCase]:
91
106
  func_name = func_node.name
92
107
 
93
108
  # Filter test cases for this specific function
94
109
  function_test_cases = [tc for tc in test_cases if tc.func_name == func_name]
95
110
 
96
111
  if not function_test_cases:
97
- initial_case = new_random_test_case(file_name, func_node)
112
+ initial_case = new_random_test_case(file_name, class_name, func_node)
98
113
  test_cases.append(initial_case)
99
114
  function_test_cases = [initial_case]
100
115
 
101
116
  try:
102
117
  # Get uncovered lines
103
- uncovered_lines = testgen.util.coverage_utils.get_uncovered_lines_for_func(file_name, func_node, function_test_cases)
118
+ uncovered_lines = testgen.util.coverage_utils.get_uncovered_lines_for_func(file_name, class_name, func_node, function_test_cases)
104
119
 
105
120
  if uncovered_lines:
106
121
  if solve_branch_condition:
107
122
  # Call the Z3 solver with uncovered lines
108
- z3_cases = solve_branch_condition(file_name, func_node, uncovered_lines)
123
+ z3_cases = solve_branch_condition(file_name, class_name, func_node, uncovered_lines)
109
124
  if z3_cases:
110
125
  test_cases.extend(z3_cases)
111
126
  else:
@@ -9,7 +9,7 @@ from testgen.util.z3_utils import ast_to_z3
9
9
  from testgen.util.z3_utils.constraint_extractor import extract_branch_conditions
10
10
 
11
11
 
12
- def solve_branch_condition(file_name: str, func_node: ast.FunctionDef, uncovered_lines: List[int]) -> List[TestCase]:
12
+ def solve_branch_condition(file_name: str, class_name: str | None, func_node: ast.FunctionDef, uncovered_lines: List[int]) -> List[TestCase]:
13
13
  branch_conditions, param_types = extract_branch_conditions(func_node)
14
14
  uncovered_conditions = [bc for bc in branch_conditions if bc.line_number in uncovered_lines]
15
15
  test_cases = []
@@ -25,6 +25,10 @@ def solve_branch_condition(file_name: str, func_node: ast.FunctionDef, uncovered
25
25
  # Create default values for all parameters
26
26
  param_values = {}
27
27
  for param_name in param_types:
28
+ # Skip 'self' parameter for class methods
29
+ if param_name == 'self':
30
+ continue
31
+
28
32
  # Set default values based on type
29
33
  if param_types[param_name] == "int":
30
34
  param_values[param_name] = 0
@@ -39,32 +43,38 @@ def solve_branch_condition(file_name: str, func_node: ast.FunctionDef, uncovered
39
43
 
40
44
  # Update with model values where available
41
45
  for var_name, z3_var in z3_vars.items():
42
- if var_name in param_types and z3_var in model:
46
+ if var_name in param_types:
43
47
  try:
48
+
49
+ model_value = model.evaluate(z3_var)
50
+
44
51
  if param_types[var_name] == "int":
45
- param_values[var_name] = model[z3_var].as_long()
52
+ param_values[var_name] = model_value.as_long()
46
53
  elif param_types[var_name] == "float":
47
- param_values[var_name] = float(model[z3_var].as_decimal())
54
+ param_values[var_name] = float(model_value.as_decimal(10))
48
55
  elif param_types[var_name] == "bool":
49
- param_values[var_name] = z3.is_true(model[z3_var])
56
+ param_values[var_name] = z3.is_true(model_value)
50
57
  elif param_types[var_name] == "str":
51
- param_values[var_name] = str(model[z3_var])
58
+ str_val = str(model_value)
59
+ if str_val.startswith('"') and str_val.endswith('"'):
60
+ str_val = str_val[1:-1]
61
+ param_values[var_name] = str_val
52
62
  else:
53
63
  param_values[var_name] = model[z3_var].as_long()
54
64
  except Exception as e:
55
- print(f"Error converting Z3 model value for {var_name}: {e}")
65
+ print(f"Couldn't get {var_name} from model: {e}")
66
+ # Keep the default value we already set
56
67
 
57
68
  # Ensure all parameters are included in correct order
58
69
  ordered_params = []
59
70
  for arg in func_node.args.args:
60
71
  arg_name = arg.arg
61
- if arg_name == 'self': # Skip self parameter for class methods
72
+ if arg_name == 'self': # Skip self parameter
62
73
  continue
63
74
  if arg_name in param_values:
64
75
  ordered_params.append(param_values[arg_name])
65
76
  else:
66
- print(f"Warning: Missing value for parameter {arg_name}")
67
- # Provide default values based on annotation if available
77
+ # Default value handling if parameter not in solution
68
78
  if hasattr(arg, 'annotation') and arg.annotation:
69
79
  if isinstance(arg.annotation, ast.Name):
70
80
  if arg.annotation.id == 'int':
@@ -85,7 +95,12 @@ def solve_branch_condition(file_name: str, func_node: ast.FunctionDef, uncovered
85
95
  func_name = func_node.name
86
96
  try:
87
97
  module = testgen.util.file_utils.load_module(file_name)
88
- func = getattr(module, func_name)
98
+ if class_name is not None:
99
+ class_obj = getattr(module, class_name)
100
+ instance = class_obj()
101
+ func = getattr(instance, func_name)
102
+ else:
103
+ func = getattr(module, func_name)
89
104
  result = func(*ordered_params)
90
105
  test_cases.append(TestCase(func_name, tuple(ordered_params), result))
91
106
  except Exception as e:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: testgenie-py
3
- Version: 0.1.6
3
+ Version: 0.1.8
4
4
  Summary:
5
5
  Author: cjseitz
6
6
  Author-email: charlesjseitz@gmail.com