testgenie-py 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -113,7 +113,7 @@ class RandomFeedbackAnalyzer(TestCaseAnalyzerStrategy, ABC):
113
113
  self.covered_lines[func.function_name] = set()
114
114
 
115
115
  for test_case in [tc for tc in self.test_cases if tc.func_name == func.function_name]:
116
- analysis = coverage_utils.get_coverage_analysis(self._analysis_context.filepath,
116
+ analysis = coverage_utils.get_coverage_analysis(self._analysis_context.filepath, self._analysis_context.class_name,
117
117
  func.function_name, test_case.inputs)
118
118
  covered = coverage_utils.get_list_of_covered_statements(analysis)
119
119
  self.covered_lines[func.function_name].update(covered)
@@ -249,10 +249,10 @@ class RandomFeedbackAnalyzer(TestCaseAnalyzerStrategy, ABC):
249
249
  print("Warning: No test cases available to determine executable statements")
250
250
  from testgen.util.randomizer import new_random_test_case
251
251
  temp_case = new_random_test_case(self._analysis_context.filepath, func.func_def)
252
- analysis = coverage_utils.get_coverage_analysis(self._analysis_context.filepath, func.function_name,
252
+ analysis = coverage_utils.get_coverage_analysis(self._analysis_context.filepath, self._analysis_context.class_name, func.function_name,
253
253
  temp_case.inputs)
254
254
  else:
255
- analysis = coverage_utils.get_coverage_analysis(self._analysis_context.filepath, func.function_name, test_cases[0].inputs)
255
+ analysis = coverage_utils.get_coverage_analysis(self._analysis_context.filepath, self._analysis_context.class_name, func.function_name, test_cases[0].inputs)
256
256
 
257
257
  # Get standard executable lines from coverage.py
258
258
  executable_lines = list(analysis[1])
@@ -6,7 +6,8 @@ import sys
6
6
  import docker
7
7
  from docker import DockerClient
8
8
  from docker import errors
9
-
9
+ from testgen.service.logging_service import LoggingService, get_logger
10
+ from testgen.util.file_utils import adjust_file_path_for_docker, get_project_root_in_docker
10
11
  from testgen.controller.docker_controller import DockerController
11
12
  from testgen.service.service import Service
12
13
  from testgen.presentation.cli_view import CLIView
@@ -28,6 +29,85 @@ class CLIController:
28
29
  self.view = view
29
30
 
30
31
  def run(self):
32
+
33
+ parser = self.add_arguments()
34
+
35
+ args = parser.parse_args()
36
+
37
+ LoggingService.get_instance().initialize(
38
+ debug_mode=args.debug if hasattr(args, 'debug') else False,
39
+ log_file=args.log_file if hasattr(args, 'log_file') else None,
40
+ console_output=True
41
+ )
42
+
43
+ logger = get_logger()
44
+
45
+ if args.select_all:
46
+ self.view.display_message("Selecting all from SQLite database...")
47
+ # Assuming you have a method in your service to handle this
48
+ self.service.select_all_from_db()
49
+ return
50
+
51
+ running_in_docker = os.environ.get("RUNNING_IN_DOCKER") is not None
52
+ if running_in_docker:
53
+ args.file_path = adjust_file_path_for_docker(args.file_path)
54
+ self.execute_generation(args, True)
55
+ elif args.safe and not running_in_docker:
56
+ client = self.docker_available()
57
+ # Skip Docker-dependent operations if client is None
58
+ if client is None and args.safe:
59
+ self.view.display_message("Running with --safe flag requires Docker. Continuing without safe mode.")
60
+ args.safe = False
61
+ self.execute_generation(args)
62
+ else:
63
+ docker_controller = DockerController()
64
+ project_root = get_project_root_in_docker(args.file_path)
65
+ successful: bool = docker_controller.run_in_docker(project_root, client, args)
66
+ if not successful:
67
+ if hasattr(args, 'db') and args.db:
68
+ self.service.db_service = DBService(args.db)
69
+ self.view.display_message(f"Using database: {args.db}")
70
+ self.execute_generation(args)
71
+ # Else successful, do nothing - we're done
72
+ else:
73
+ # Initialize database service with specified path
74
+ if hasattr(args, 'db') and args.db:
75
+ self.service.db_service = DBService(args.db)
76
+ self.view.display_message(f"Using database: {args.db}")
77
+ self.view.display_message("Running in local mode...")
78
+ self.execute_generation(args)
79
+
80
+ def execute_generation(self, args: argparse.Namespace, running_in_docker: bool = False):
81
+ try:
82
+ self.set_service_args(args)
83
+
84
+ if running_in_docker:
85
+ self.view.display_message("Running in Docker mode...")
86
+ self.service.generate_test_cases()
87
+
88
+ else:
89
+ test_file = self.service.generate_tests(args.output)
90
+ self.view.display_message(f"Unit tests saved to: {test_file}")
91
+ self.view.display_message("Running coverage...")
92
+ self.service.run_coverage(test_file)
93
+ self.view.display_message("Tests and coverage data saved to database.")
94
+
95
+ if args.visualize:
96
+ self.service.visualize_test_coverage()
97
+
98
+ except Exception as e:
99
+ self.view.display_error(f"An error occurred: {e}")
100
+ # Make sure to close the DB connection on error
101
+ if hasattr(self.service, 'db_service'):
102
+ self.service.db_service.close()
103
+
104
+ def set_service_args(self, args: argparse.Namespace):
105
+ self.service.set_file_path(args.file_path)
106
+ self.service.set_debug_mode(args.debug)
107
+ self.set_test_format(args)
108
+ self.set_test_strategy(args)
109
+
110
+ def add_arguments(self) -> argparse.ArgumentParser:
31
111
  parser = argparse.ArgumentParser(description="A CLI tool for generating unit tests.")
32
112
  parser.add_argument("file_path", type=str, help="Path to the Python file.")
33
113
  parser.add_argument("--output", "-o", type=str, help="Path to output directory.")
@@ -75,108 +155,48 @@ class CLIController:
75
155
  action="store_true",
76
156
  help = "Visualize the tests with graphviz"
77
157
  )
158
+ parser.add_argument(
159
+ "--debug",
160
+ action="store_true",
161
+ help="Enable debug logging"
162
+ )
163
+ parser.add_argument(
164
+ "--log-file",
165
+ type=str,
166
+ help="Path to log file (if not specified, logs will only go to console)"
167
+ )
168
+ return parser
78
169
 
79
- args = parser.parse_args()
80
-
81
- if args.select_all:
82
- self.view.display_message("Selecting all from SQLite database...")
83
- # Assuming you have a method in your service to handle this
84
- self.service.select_all_from_db()
85
- return
86
-
87
- # Initialize database service with specified path
88
- if hasattr(args, 'db') and args.db:
89
- self.service.db_service = DBService(args.db)
90
- self.view.display_message(f"Using database: {args.db}")
91
-
92
- running_in_docker = os.environ.get("RUNNING_IN_DOCKER") is not None
93
- if running_in_docker:
94
- args.file_path = self.adjust_file_path_for_docker(args.file_path)
95
- self.execute_generation(args)
96
- elif args.safe and not running_in_docker:
97
- client = self.docker_available()
98
- # Skip Docker-dependent operations if client is None
99
- if client is None and args.safe:
100
- self.view.display_message("Running with --safe flag requires Docker. Continuing without safe mode.")
101
- args.safe = False
102
- docker_controller = DockerController()
103
- project_root = self.get_project_root_in_docker(args.file_path)
104
- successful: bool = docker_controller.run_in_docker(project_root, client, args)
105
- if not successful:
106
- self.execute_generation(args)
170
+ def set_test_format(self, args: argparse.Namespace):
171
+ if args.test_format == "pytest":
172
+ self.service.set_test_generator_format(PYTEST_FORMAT)
173
+ elif args.test_format == "doctest":
174
+ self.service.set_test_generator_format(DOCTEST_FORMAT)
107
175
  else:
108
- self.view.display_message("Running in local mode...")
109
- self.execute_generation(args)
176
+ self.service.set_test_generator_format(UNITTEST_FORMAT)
110
177
 
111
- def execute_generation(self, args: argparse.Namespace):
112
- try:
113
- self.service.set_file_path(args.file_path)
114
- if args.test_format == "pytest":
115
- self.service.set_test_generator_format(PYTEST_FORMAT)
116
- elif args.test_format == "doctest":
117
- self.service.set_test_generator_format(DOCTEST_FORMAT)
118
- else:
119
- self.service.set_test_generator_format(UNITTEST_FORMAT)
120
- if args.test_mode == "random":
121
- self.view.display_message("Using Random Feedback-Directed Test Generation Strategy.")
122
- self.service.set_test_analysis_strategy(RANDOM_STRAT)
123
- elif args.test_mode == "fuzz":
124
- self.view.display_message("Using Fuzz Test Generation Strategy...")
125
- self.service.set_test_analysis_strategy(FUZZ_STRAT)
126
- elif args.test_mode == "reinforce":
127
- self.view.display_message("Using Reinforcement Learning Test Generation Strategy...")
128
- if args.reinforce_mode == "train":
129
- self.view.display_message("Training mode enabled - will update Q-table")
130
- else:
131
- self.view.display_message("Training mode disabled - will use existing Q-table")
132
- self.service.set_test_analysis_strategy(REINFORCE_STRAT)
133
- self.service.set_reinforcement_mode(args.reinforce_mode)
178
+ def set_test_strategy(self, args: argparse.Namespace):
179
+ if args.test_mode == "random":
180
+ self.view.display_message("Using Random Feedback-Directed Test Generation Strategy.")
181
+ self.service.set_test_analysis_strategy(RANDOM_STRAT)
182
+ elif args.test_mode == "fuzz":
183
+ self.view.display_message("Using Fuzz Test Generation Strategy...")
184
+ self.service.set_test_analysis_strategy(FUZZ_STRAT)
185
+ elif args.test_mode == "reinforce":
186
+ self.view.display_message("Using Reinforcement Learning Test Generation Strategy...")
187
+ if args.reinforce_mode == "train":
188
+ self.view.display_message("Training mode enabled - will update Q-table")
134
189
  else:
135
- self.view.display_message("Generating function code using AST analysis...")
136
- generated_file_path = self.service.generate_function_code()
137
- self.view.display_message(f"Generated code saved to: {generated_file_path}")
138
- if not args.generate_only:
139
- self.view.display_message("Using Simple AST Traversal Test Generation Strategy...")
140
- self.service.set_test_analysis_strategy(AST_STRAT)
141
-
142
- test_file = self.service.generate_tests(args.output)
143
- self.view.display_message(f"Unit tests saved to: {test_file}")
144
- self.view.display_message("Running coverage...")
145
- self.service.run_coverage(test_file)
146
- self.view.display_message("Tests and coverage data saved to database.")
147
-
148
- if args.visualize:
149
- self.service.visualize_test_coverage()
150
-
151
- except Exception as e:
152
- self.view.display_error(f"An error occurred: {e}")
153
- # Make sure to close the DB connection on error
154
- if hasattr(self.service, 'db_service'):
155
- self.service.db_service.close()
156
-
157
- def adjust_file_path_for_docker(self, file_path) -> str:
158
- file_dir = os.path.abspath(os.path.dirname(file_path))
159
- sys.path.append(file_dir)
160
- sys.path.append('/controller')
161
- file_abs_path = os.path.abspath(file_path)
162
- if not os.path.exists(file_abs_path):
163
- testgen_path = os.path.join('/controller/testgen', os.path.basename(file_path))
164
- if os.path.exists(testgen_path):
165
- file_path = testgen_path
166
- else:
167
- app_path = os.path.join('/controller', os.path.basename(file_path))
168
- if os.path.exists(app_path):
169
- file_path = app_path
170
- return file_path
171
-
172
- def get_project_root_in_docker(self, script_path) -> str:
173
- script_path = os.path.abspath(sys.argv[0])
174
- print(f"Script path: {script_path}")
175
- script_dir = os.path.dirname(script_path)
176
- print(f"Script directory: {script_dir}")
177
- project_root = os.path.dirname(script_dir)
178
- print(f"Project root directory: {project_root}")
179
- return project_root
190
+ self.view.display_message("Training mode disabled - will use existing Q-table")
191
+ self.service.set_test_analysis_strategy(REINFORCE_STRAT)
192
+ self.service.set_reinforcement_mode(args.reinforce_mode)
193
+ else:
194
+ self.view.display_message("Generating function code using AST analysis...")
195
+ generated_file_path = self.service.generate_function_code()
196
+ self.view.display_message(f"Generated code saved to: {generated_file_path}")
197
+ if not args.generate_only:
198
+ self.view.display_message("Using Simple AST Traversal Test Generation Strategy...")
199
+ self.service.set_test_analysis_strategy(AST_STRAT)
180
200
 
181
201
  def docker_available(self) -> DockerClient | None:
182
202
  try:
@@ -6,6 +6,7 @@ from docker import DockerClient, client
6
6
  from docker import errors
7
7
  from docker.models.containers import Container
8
8
 
9
+ from testgen.service.logging_service import get_logger
9
10
  from testgen.service.service import Service
10
11
 
11
12
  AST_STRAT = 1
@@ -19,10 +20,13 @@ DOCTEST_FORMAT = 3
19
20
  class DockerController:
20
21
  def __init__(self):
21
22
  self.service = Service()
23
+ self.debug_mode = False
22
24
  self.args = None
25
+ self.logger = get_logger()
23
26
 
24
27
  def run_in_docker(self, project_root: str, docker_client: DockerClient, args: Namespace) -> bool:
25
28
  self.args = args
29
+ self.debug_mode = True if args.debug else False
26
30
  os.environ["RUNNING_IN_DOCKER"] = "1"
27
31
 
28
32
  # Check if Docker image exists, build it if not
@@ -30,18 +34,19 @@ class DockerController:
30
34
  # If args.safe is set to false it means the image was not found and the system will try to run_locally
31
35
  self.get_image(docker_client, image_name, project_root)
32
36
  if not self.args.safe:
33
- print("Docker image not found. Running locally...")
37
+ self.logger.info("Docker image not found. Running locally...")
34
38
  return False
35
39
 
36
- docker_args = [os.path.basename(args.file_path)] + [arg for arg in sys.argv[2:] if arg != "--safe"]
40
+ docker_args = [args.file_path] + [arg for arg in sys.argv[2:] if arg != "--safe"]
37
41
 
38
42
  # Run the container with the same arguments
39
43
  try:
44
+ self.debug(f"project_root: {project_root}")
40
45
  container = self.run_container(docker_client, image_name, docker_args, project_root)
41
46
 
42
47
  # Stream the logs to the console
43
48
  logs_output = self.get_logs(container)
44
- print(logs_output)
49
+ self.debug(logs_output)
45
50
 
46
51
  try:
47
52
  # Create the target directory if it doesn't exist
@@ -51,14 +56,14 @@ class DockerController:
51
56
  target_path = args.output
52
57
  os.makedirs(target_path, exist_ok=True)
53
58
 
54
- print(f"SERVICE target path after logs: {target_path}")
59
+ self.debug(f"SERVICE target path after logs: {target_path}")
55
60
 
56
61
  test_cases = self.service.parse_test_cases_from_logs(logs_output)
57
62
 
58
63
  print(f"Extracted {len(test_cases)} test cases from container.")
59
64
 
60
65
  file_path = os.path.abspath(args.file_path)
61
- print(f"Filepath in CLI CONTROLLER: {file_path}")
66
+ self.debug(f"Filepath in CLI CONTROLLER: {file_path}")
62
67
  self.service.set_file_path(file_path)
63
68
 
64
69
  if args.test_format == "pytest":
@@ -73,9 +78,10 @@ class DockerController:
73
78
 
74
79
  if not args.generate_only:
75
80
  print("Running coverage...")
76
- import traceback
77
- print(traceback.format_exc())
78
81
  self.service.run_coverage(test_file)
82
+
83
+ # Add explicit return True here
84
+ return True
79
85
 
80
86
  except Exception as e:
81
87
  print(f"Error running container: {e}")
@@ -99,7 +105,7 @@ class DockerController:
99
105
  print(f"Dockerfile not found at {dockerfile_path}")
100
106
  sys.exit(1)
101
107
 
102
- print(f"Using Dockerfile at: {dockerfile_path}")
108
+ self.debug(f"Using Dockerfile at: {dockerfile_path}")
103
109
 
104
110
  if not self.build_docker_image(docker_client, image_name, dockerfile_path, project_root):
105
111
  print("Failed to build Docker image. Continuing without safe mode.")
@@ -108,7 +114,6 @@ class DockerController:
108
114
  @staticmethod
109
115
  def get_logs(container) -> str:
110
116
  # Stream the logs to the console
111
- print("Running in Docker container...")
112
117
  logs = container.logs(stream=True)
113
118
  logs_output = ""
114
119
  for log in logs:
@@ -119,25 +124,32 @@ class DockerController:
119
124
 
120
125
  @staticmethod
121
126
  def run_container(docker_client: DockerClient, image_name: str, docker_args: list, project_root: str) -> Container:
127
+ # Create Docker-specific environment variables
128
+ docker_env = {
129
+ "RUNNING_IN_DOCKER": "1",
130
+ "PYTHONPATH": "/controller",
131
+ "COVERAGE_FILE": "/tmp/.coverage", # Move coverage file to /tmp
132
+ "DB_PATH": "/tmp/testgen.db" # Move DB to /tmp
133
+ }
134
+
122
135
  return docker_client.containers.run(
123
136
  image=image_name,
124
137
  command=["poetry", "run", "python", "-m", "testgen.main"] + docker_args,
125
- volumes={project_root: {"bind": "/controller", "mode": "rw"}}, # Mount current dir
126
- environment={"RUNNING_IN_DOCKER": "1"},
138
+ volumes={project_root: {"bind": "/controller", "mode": "rw"}},
139
+ environment=docker_env,
127
140
  detach=True,
128
141
  remove=True,
129
142
  stdout=True,
130
143
  stderr=True
131
144
  )
132
145
 
133
- @staticmethod
134
- def build_docker_image(docker_client, image_name, dockerfile_path, project_root):
146
+ def build_docker_image(self, docker_client, image_name, dockerfile_path, project_root):
135
147
  try:
136
148
  print(f"Starting Docker build for image: {image_name}")
137
149
  dockerfile_rel_path = os.path.relpath(dockerfile_path, project_root)
138
- print(f"Project root {project_root}")
139
- print(f"Docker directory: {os.path.dirname(dockerfile_path)}")
140
- print(f"Docker rel path: {dockerfile_rel_path}")
150
+ self.debug(f"Project root {project_root}")
151
+ self.debug(f"Docker directory: {os.path.dirname(dockerfile_path)}")
152
+ self.debug(f"Docker rel path: {dockerfile_rel_path}")
141
153
  build_progress = docker_client.api.build(
142
154
  path=os.path.join(project_root, "testgen", "docker"),
143
155
  dockerfile=os.path.join(project_root, "testgen", "docker", "Dockerfile"),
@@ -147,13 +159,13 @@ class DockerController:
147
159
  )
148
160
 
149
161
  for chunk in build_progress:
150
- print(f"CHUNK: {chunk}")
162
+ self.debug(f"CHUNK: {chunk}")
151
163
  if 'stream' in chunk:
152
164
  for line in chunk['stream'].splitlines():
153
165
  if line.strip():
154
166
  print(f"Docker: {line.strip()}")
155
167
  elif 'error' in chunk:
156
- print(f"Docker build error: {chunk['error']}")
168
+ self.debug(f"Docker build error: {chunk['error']}")
157
169
  return False
158
170
  print(f"Docker image built successfully: {image_name}")
159
171
  return True
@@ -166,4 +178,7 @@ class DockerController:
166
178
  print(f"Unexpected error during Docker build: {str(e)}")
167
179
  return False
168
180
 
169
-
181
+ def debug(self, message: str):
182
+ """Log debug message"""
183
+ if self.debug_mode:
184
+ self.logger.debug(message)
testgen/docker/Dockerfile CHANGED
@@ -9,12 +9,12 @@ ENV POETRY_VIRTUALENVS_CREATE=false \
9
9
  PYTHONUNBUFFERED=1 \
10
10
  RUNNING_IN_DOCKER=true
11
11
 
12
- WORKDIR /app
12
+ WORKDIR /controller
13
13
 
14
14
  # Copy poetry files
15
15
  COPY . .
16
16
 
17
- ENV PYTHONPATH=/app:/app/testgen
17
+ ENV PYTHONPATH=/controller:/controller/testgen
18
18
 
19
19
  RUN poetry install --no-root
20
20
 
@@ -1,10 +1,16 @@
1
1
  [tool.poetry]
2
- name = "testgen"
3
- version = "0.1.0"
2
+ name = "testgenie-py"
3
+ version = "0.1.6"
4
4
  description = ""
5
5
  authors = ["cjseitz <charlesjseitz@gmail.com>"]
6
6
  readme = "README.md"
7
7
 
8
+ [[tool.poetry.packages]]
9
+ include = "testgen"
10
+
11
+ [tool.poetry.scripts]
12
+ testgenie = "testgen.main:main"
13
+
8
14
  [tool.poetry.dependencies]
9
15
  python = "^3.10"
10
16
  astor = "0.8.1"
@@ -16,6 +22,7 @@ typed-ast = "1.5.5"
16
22
  z3-solver = "4.13.3.0"
17
23
  staticfg = "^0.9.5"
18
24
  pytest = "^8.3.5"
25
+ docker = "^7.1.0"
19
26
 
20
27
  [build-system]
21
28
  requires = ["poetry-core"]
@@ -12,10 +12,11 @@ from testgen.models.test_case import TestCase
12
12
 
13
13
 
14
14
  class ReinforcementEnvironment:
15
- def __init__(self, file_name, fut: ast.FunctionDef, module, initial_test_cases: List[TestCase], state: AbstractState):
15
+ def __init__(self, file_name, fut: ast.FunctionDef, module, class_name: str | None, initial_test_cases: List[TestCase], state: AbstractState):
16
16
  self.file_name = file_name
17
17
  self.fut = fut
18
18
  self.module = module
19
+ self.class_name = class_name
19
20
  self.initial_test_cases = initial_test_cases
20
21
  self.test_cases = initial_test_cases.copy()
21
22
  self.state = state
@@ -32,13 +33,13 @@ class ReinforcementEnvironment:
32
33
 
33
34
  # Execute action
34
35
  if action == "add":
35
- self.test_cases.append(randomizer.new_random_test_case(self.file_name, self.fut))
36
+ self.test_cases.append(randomizer.new_random_test_case(self.file_name, self.class_name, self.fut))
36
37
  elif action == "merge" and len(self.test_cases) > 1:
37
38
  self.test_cases.append(randomizer.combine_cases(self.test_cases))
38
39
  elif action == "remove" and len(self.test_cases) > 1:
39
40
  self.test_cases = randomizer.remove_case(self.test_cases)
40
41
  elif action == "z3":
41
- self.test_cases = randomizer.get_z3_test_cases(self.file_name, self.fut, self.test_cases)
42
+ self.test_cases = randomizer.get_z3_test_cases(self.file_name, self.class_name, self.fut, self.test_cases)
42
43
  else:
43
44
  raise ValueError("Invalid action")
44
45
 
@@ -97,26 +98,52 @@ class ReinforcementEnvironment:
97
98
 
98
99
  test_cases = [tc for tc in self.test_cases if tc.func_name == self.fut.name]
99
100
 
101
+ executable_lines = set()
100
102
  if not test_cases:
101
103
  print("Warning: No test cases available to determine executable statements")
102
104
  from testgen.util.randomizer import new_random_test_case
103
- temp_case = new_random_test_case(self.file_name, self.fut)
104
- analysis = testgen.util.coverage_utils.get_coverage_analysis(self.file_name, self.fut.name, temp_case.inputs)
105
+ temp_case = new_random_test_case(self.file_name, self.class_name, self.fut)
106
+ analysis = testgen.util.coverage_utils.get_coverage_analysis(self.file_name, self.class_name, self.fut.name, temp_case.inputs)
107
+ executable_lines.update(analysis[1]) # Add executable lines from coverage analysis
105
108
  else:
106
- analysis = testgen.util.coverage_utils.get_coverage_analysis(self.file_name, self.fut.name, test_cases[0].inputs)
109
+ analysis = testgen.util.coverage_utils.get_coverage_analysis(self.file_name, self.class_name, self.fut.name, test_cases[0].inputs)
107
110
 
111
+ executable_lines.update(analysis[1]) # Add executable lines from coverage analysis
108
112
  # Get standard executable lines from coverage.py
109
- executable_lines = list(analysis[1])
113
+ executable_lines = list(executable_lines)
110
114
 
111
115
  # Parse the source file to find else branches
112
116
  with open(self.file_name, 'r') as f:
113
117
  source = f.read()
114
118
 
115
119
  # Parse the code
116
- tree = ast.parse(source)
117
-
120
+ tree = ast.parse(source)
118
121
  # Find our specific function
119
122
  for node in ast.walk(tree):
123
+ if isinstance(node, ast.ClassDef) and node.name == self.class_name:
124
+ # If we have a class, find the method
125
+ for method in node.body:
126
+ if isinstance(method, ast.FunctionDef) and method.name == self.fut.name:
127
+ # Find all if statements in this method
128
+ for if_node in ast.walk(method):
129
+ if isinstance(if_node, ast.If) and if_node.orelse:
130
+ # There's an else branch
131
+ if isinstance(if_node.orelse[0], ast.If):
132
+ # This is an elif - already counted
133
+ continue
134
+
135
+ # Get the line number of the first statement in the else block
136
+ # and subtract 1 to get the 'else:' line
137
+ else_line = if_node.orelse[0].lineno - 1
138
+
139
+ # Check if this is actually an else line (not a nested if)
140
+ with open(self.file_name, 'r') as f:
141
+ lines = f.readlines()
142
+ if else_line <= len(lines):
143
+ line_content = lines[else_line - 1].strip()
144
+ if line_content == "else:":
145
+ if else_line not in executable_lines:
146
+ executable_lines.append(else_line)
120
147
  if isinstance(node, ast.FunctionDef) and node.name == self.fut.name:
121
148
  # Find all if statements in this function
122
149
  for if_node in ast.walk(node):
@@ -153,7 +180,12 @@ class ReinforcementEnvironment:
153
180
  for test_case in self.test_cases:
154
181
  try:
155
182
  module = testgen.util.file_utils.load_module(self.file_name)
156
- func = getattr(module, self.fut.name)
183
+ if self.class_name:
184
+ class_obj = getattr(module, self.class_name)
185
+ instance = class_obj()
186
+ func = getattr(instance, self.fut.name)
187
+ else:
188
+ func = getattr(module, self.fut.name)
157
189
  _ = func(*test_case.inputs)
158
190
  except Exception as e:
159
191
  import traceback
@@ -12,7 +12,7 @@ class StatementCoverageState(AbstractState):
12
12
  """Returns calculated coverage and length of test cases in a tuple"""
13
13
  all_covered_statements = set()
14
14
  for test_case in self.environment.test_cases:
15
- analysis = testgen.util.coverage_utils.get_coverage_analysis(self.environment.file_name, self.environment.fut.name, test_case.inputs)
15
+ analysis = testgen.util.coverage_utils.get_coverage_analysis(self.environment.file_name, self.environment.class_name, self.environment.fut.name, test_case.inputs)
16
16
  covered = testgen.util.coverage_utils.get_list_of_covered_statements(analysis)
17
17
  all_covered_statements.update(covered)
18
18
 
@@ -43,10 +43,15 @@ class AnalysisService:
43
43
 
44
44
  def create_analysis_context(self, filepath: str) -> AnalysisContext:
45
45
  """Create an analysis context for the given file."""
46
+ print(f"Creating analysis context for {filepath}")
46
47
  filename = file_utils.get_filename(filepath)
48
+ print(f"Filename: {filename}")
47
49
  module = file_utils.load_module(filepath)
50
+ print(f"Module: {module}")
48
51
  class_name = self.get_class_name(module)
52
+ print(f"Class name: {class_name}")
49
53
  function_data = self.get_function_data(filename, module, class_name)
54
+ print(f"Function data: {function_data}")
50
55
  return AnalysisContext(filepath, filename, class_name, module, function_data)
51
56
 
52
57
  def get_function_data(self, filename: str, module: ModuleType, class_name: str | None) -> List[FunctionMetadata]:
@@ -77,11 +82,12 @@ class AnalysisService:
77
82
 
78
83
  return function_metadata_list
79
84
 
80
- def do_reinforcement_learning(self, filepath: str, mode: str = None) -> List[TestCase]:
85
+ def do_reinforcement_learning(self, filepath: str, class_name: str | None, mode: str = None) -> List[TestCase]:
81
86
  mode = mode or self.reinforcement_mode
82
87
  module: ModuleType = testgen.util.file_utils.load_module(filepath)
83
88
  tree: ast.Module = testgen.util.file_utils.load_and_parse_file_for_tree(filepath)
84
89
  functions: List[ast.FunctionDef] = testgen.util.utils.get_functions(tree)
90
+ self.class_name = class_name
85
91
  time_limit: int = 30
86
92
  all_test_cases: List[TestCase] = []
87
93
 
@@ -94,7 +100,7 @@ class AnalysisService:
94
100
  best_coverage: float = 0.0
95
101
 
96
102
  # Create environment and agent once per function
97
- environment = ReinforcementEnvironment(filepath, function, module, function_test_cases, state=StatementCoverageState(None))
103
+ environment = ReinforcementEnvironment(filepath, function, module, self.class_name, function_test_cases, state=StatementCoverageState(None))
98
104
  environment.state = StatementCoverageState(environment)
99
105
 
100
106
  # Create agent with existing Q-table
@@ -44,7 +44,7 @@ class CFGService:
44
44
  filename = os.path.basename(file_path).replace('.py', '')
45
45
 
46
46
  for func in analysis_context.function_data:
47
- self.visualizer.get_covered_lines(file_path, func.func_def, test_cases)
47
+ self.visualizer.get_covered_lines(file_path, analysis_context.class_name, func.func_def, test_cases)
48
48
 
49
49
  base_filename = f"{filename}_{func.function_name}_coverage"
50
50
  output_filepath = self.get_versioned_filename(visualization_dir, base_filename)
@@ -1,4 +1,5 @@
1
1
  import os
2
+ import inspect
2
3
  from types import ModuleType
3
4
  from typing import List
4
5
 
@@ -43,15 +44,22 @@ class GeneratorService:
43
44
 
44
45
  def generate_test_file(self, module: ModuleType, class_name: str | None, test_cases: List[TestCase], output_path=None) -> str:
45
46
  """Generate a test file for the given test cases."""
46
-
47
47
  filename = self.get_filename(self.filepath)
48
-
49
48
  output_path = self.get_test_file_path(module.__name__, output_path)
50
49
 
50
+ # Determine the actual class name used in the module
51
+ actual_class_name = class_name
52
+ if 'generated_' in self.filepath and class_name:
53
+ # For generated classes, find the actual class name in the module
54
+ for name, obj in inspect.getmembers(module):
55
+ if inspect.isclass(obj):
56
+ actual_class_name = name
57
+ break
58
+
51
59
  context = GeneratorContext(
52
60
  filepath=self.filepath,
53
61
  filename=filename,
54
- class_name=class_name,
62
+ class_name=actual_class_name, # Use the actual class name
55
63
  module=module,
56
64
  output_path=output_path,
57
65
  test_cases=test_cases