testgenie-py 0.3.7__py3-none-any.whl → 0.3.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. testgen/analyzer/ast_analyzer.py +2 -11
  2. testgen/analyzer/fuzz_analyzer.py +1 -6
  3. testgen/analyzer/random_feedback_analyzer.py +20 -293
  4. testgen/analyzer/reinforcement_analyzer.py +59 -57
  5. testgen/analyzer/test_case_analyzer_context.py +0 -6
  6. testgen/controller/cli_controller.py +35 -29
  7. testgen/controller/docker_controller.py +1 -0
  8. testgen/db/dao.py +68 -0
  9. testgen/db/dao_impl.py +226 -0
  10. testgen/{sqlite → db}/db.py +15 -6
  11. testgen/generator/pytest_generator.py +2 -10
  12. testgen/generator/unit_test_generator.py +2 -11
  13. testgen/main.py +1 -3
  14. testgen/models/coverage_data.py +56 -0
  15. testgen/models/db_test_case.py +65 -0
  16. testgen/models/function.py +56 -0
  17. testgen/models/function_metadata.py +11 -1
  18. testgen/models/generator_context.py +32 -2
  19. testgen/models/source_file.py +29 -0
  20. testgen/models/test_result.py +38 -0
  21. testgen/models/test_suite.py +20 -0
  22. testgen/reinforcement/agent.py +1 -27
  23. testgen/reinforcement/environment.py +11 -93
  24. testgen/reinforcement/statement_coverage_state.py +5 -4
  25. testgen/service/analysis_service.py +31 -22
  26. testgen/service/cfg_service.py +3 -1
  27. testgen/service/coverage_service.py +115 -0
  28. testgen/service/db_service.py +140 -0
  29. testgen/service/generator_service.py +77 -20
  30. testgen/service/logging_service.py +2 -2
  31. testgen/service/service.py +62 -231
  32. testgen/service/test_executor_service.py +145 -0
  33. testgen/util/coverage_utils.py +38 -116
  34. testgen/util/coverage_visualizer.py +10 -9
  35. testgen/util/file_utils.py +10 -111
  36. testgen/util/randomizer.py +0 -26
  37. testgen/util/utils.py +197 -38
  38. {testgenie_py-0.3.7.dist-info → testgenie_py-0.3.8.dist-info}/METADATA +1 -1
  39. testgenie_py-0.3.8.dist-info/RECORD +72 -0
  40. testgen/inspector/inspector.py +0 -59
  41. testgen/presentation/__init__.py +0 -0
  42. testgen/presentation/cli_view.py +0 -12
  43. testgen/sqlite/__init__.py +0 -0
  44. testgen/sqlite/db_service.py +0 -239
  45. testgen/testgen.db +0 -0
  46. testgenie_py-0.3.7.dist-info/RECORD +0 -67
  47. /testgen/{inspector → db}/__init__.py +0 -0
  48. {testgenie_py-0.3.7.dist-info → testgenie_py-0.3.8.dist-info}/WHEEL +0 -0
  49. {testgenie_py-0.3.7.dist-info → testgenie_py-0.3.8.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,65 @@
1
from dataclasses import dataclass
from typing import Any


@dataclass(eq=False)
class DBTestCase:
    """A generated test case held as a plain data record.

    The former hand-written property/setter pairs only forwarded to private
    attributes, so they are replaced by dataclass fields. The constructor keeps
    the original positional order and parameter names. ``eq=False`` keeps the
    original identity-based ``==`` and ``hash``, so instances remain usable as
    dict keys and set members. A readable ``__repr__`` is generated.

    Attributes:
        expected_output: Expected result recorded for this test case.
        inputs: Inputs recorded for this test case (container type is not
            fixed here; TODO confirm with callers).
        test_function: Test-function identifier, stored as a string.
        last_run_time: Last-run marker; its type is not fixed here.
        test_method_type: Integer code for the kind of test. The meaning of
            each value is defined outside this module.
        test_suite_id: Id of the owning test suite (looks like a DB foreign
            key; confirm against the schema).
        function_id: Id of the function under test (looks like a DB foreign
            key; confirm against the schema).
    """

    expected_output: Any
    inputs: Any
    test_function: str
    last_run_time: Any
    test_method_type: int
    test_suite_id: int
    function_id: int
@@ -0,0 +1,56 @@
1
from dataclasses import dataclass


@dataclass(eq=False)
class Function:
    """A function discovered in a source file, held as a plain data record.

    The former hand-written property/setter pairs only forwarded to private
    attributes, so they are replaced by dataclass fields. The constructor keeps
    the original positional order and parameter names. ``eq=False`` keeps the
    original identity-based ``==`` and ``hash``. A readable ``__repr__`` is
    generated.

    Attributes:
        name: Function name.
        params: Parameter description, typed ``str`` as in the original
            property annotation (format not fixed here; TODO confirm).
        start_line: First line of the function.
        end_line: Last line of the function.
        num_lines: Line count of the function.
        source_file_id: Id of the containing source file (looks like a DB
            foreign key; confirm against the schema).
    """

    name: str
    params: str
    start_line: int
    end_line: int
    num_lines: int
    source_file_id: int
@@ -1,12 +1,14 @@
1
1
  import ast
2
2
  from types import ModuleType
3
+ from typing import Any
3
4
 
4
5
 
5
6
  class FunctionMetadata:
6
- def __init__(self, filename: str, module: ModuleType, class_name: str, function_name: str, func_def: ast.FunctionDef, params: dict):
7
+ def __init__(self, filename: str, module: ModuleType, class_name: str, func: Any, function_name: str, func_def: ast.FunctionDef, params: dict):
7
8
  self._filename: str = filename
8
9
  self._module: ModuleType = module
9
10
  self._class_name: str = class_name
11
+ self._func: Any = func
10
12
  self._function_name: str = function_name
11
13
  self._func_def: ast.FunctionDef = func_def
12
14
  self._params: dict = params
@@ -34,6 +36,14 @@ class FunctionMetadata:
34
36
  @class_name.setter
35
37
  def class_name(self, class_name: str):
36
38
  self._class_name = class_name
39
+
40
+ @property
41
+ def func(self) -> Any:
42
+ return self._func
43
+
44
+ @func.setter
45
+ def func(self, func: Any):
46
+ self._func = func
37
47
 
38
48
  @property
39
49
  def function_name(self) -> str:
@@ -1,18 +1,24 @@
1
1
  from types import ModuleType
2
2
  from typing import List
3
3
 
4
+ from pyasn1_modules.rfc6031 import id_ct_KP_sKeyPackage
5
+
4
6
  from testgen.models.test_case import TestCase
5
7
 
6
8
 
7
9
  class GeneratorContext:
8
- def __init__(self, filepath: str, filename: str, class_name:str | None, module: ModuleType, output_path: str, test_cases: List[TestCase]):
10
+ def __init__(self, filepath: str, filename: str, class_name:str | None, module: ModuleType, output_path: str,
11
+ test_cases: List[TestCase], is_package: bool, package_name: str, import_path: str):
9
12
  self._filepath: str = filepath
10
13
  self._filename: str = filename
11
14
  self._class_name: str = class_name
12
15
  self._module: ModuleType = module
13
16
  self._output_path: str = output_path
14
17
  self._test_cases: List[TestCase] = test_cases
15
-
18
+ self._is_package: bool = is_package
19
+ self._package_name: str = package_name
20
+ self._import_path: str = import_path
21
+
16
22
  @property
17
23
  def filepath(self) -> str:
18
24
  return self._filepath
@@ -61,3 +67,27 @@ class GeneratorContext:
61
67
  def test_cases(self, value: List[TestCase]) -> None:
62
68
  self._test_cases = value
63
69
 
70
+ @property
71
+ def is_package(self) -> bool:
72
+ return self._is_package
73
+
74
+ @is_package.setter
75
+ def is_package(self, value: bool) -> None:
76
+ self._is_package = value
77
+
78
+ @property
79
+ def package_name(self) -> str:
80
+ return self._package_name
81
+
82
+ @package_name.setter
83
+ def package_name(self, value: str) -> None:
84
+ self._package_name = value
85
+
86
+ @property
87
+ def import_path(self) -> str:
88
+ return self._import_path
89
+
90
+ @import_path.setter
91
+ def import_path(self, value: str) -> None:
92
+ self._import_path = value
93
+
@@ -0,0 +1,29 @@
1
from dataclasses import dataclass
from typing import Any


@dataclass(eq=False)
class SourceFile:
    """A source file tracked by testgen, held as a plain data record.

    The former hand-written property/setter pairs only forwarded to private
    attributes, so they are replaced by dataclass fields. The constructor keeps
    the original positional order and parameter names. ``eq=False`` keeps the
    original identity-based ``==`` and ``hash``. A readable ``__repr__`` is
    generated.

    Attributes:
        path: Path of the source file.
        lines_of_code: Line count recorded for the file.
        last_modified: Last-modification marker. Its type (timestamp,
            datetime, string) is not fixed here; TODO confirm.
    """

    path: str
    lines_of_code: int
    last_modified: Any
@@ -0,0 +1,38 @@
1
from dataclasses import dataclass
from typing import Any


@dataclass(eq=False)
class TestResult:
    """Result record for one test-case execution, as a plain data record.

    The former hand-written property/setter pairs only forwarded to private
    attributes, so they are replaced by dataclass fields. The constructor keeps
    the original positional order and parameter names. ``eq=False`` keeps the
    original identity-based ``==`` and ``hash``. A readable ``__repr__`` is
    generated.

    Attributes:
        test_case_id: Id of the test case this result belongs to.
        status: Boolean outcome flag (presumably True means passed; confirm).
        error: Error text recorded for the run.
        execution_time: Run timing; its type and units are not fixed here
            (TODO confirm).
    """

    # The class name matches pytest's default ``Test*`` collection pattern.
    # Opt out so that importing this model into a test module does not make
    # pytest try (and fail) to collect it as a test class.
    __test__ = False

    test_case_id: int
    status: bool
    error: str
    execution_time: Any
@@ -0,0 +1,20 @@
1
from dataclasses import dataclass
from typing import Any


@dataclass(eq=False)
class TestSuite:
    """A named test suite, held as a plain data record.

    The former hand-written property/setter pairs only forwarded to private
    attributes, so they are replaced by dataclass fields. The constructor keeps
    the original positional order and parameter names. ``eq=False`` keeps the
    original identity-based ``==`` and ``hash``. A readable ``__repr__`` is
    generated.

    Attributes:
        name: Suite name.
        creation_date: Creation marker; its type is not fixed here
            (TODO confirm).
    """

    # The class name matches pytest's default ``Test*`` collection pattern.
    # Opt out so that importing this model into a test module does not make
    # pytest try (and fail) to collect it as a test class.
    __test__ = False

    name: str
    creation_date: Any
@@ -128,30 +128,4 @@ class ReinforcementAgent:
128
128
  print(f"UPDATING Q TABLE FOR STATE: {state}, ACTION: {action} WITH REWARD: {reward}")
129
129
  new_q = (1 - self.learning_rate) * current_q + self.learning_rate * (reward + max_next_q)
130
130
 
131
- self.q_table[(state, action)] = new_q
132
-
133
- """def optimize_test_suit(self, current_state, executable_statements):
134
- # Try to optimize test cases by repeatedly performing remove actions if reached full coverage
135
- test_case_count = current_state[1]
136
- optimization_attempts = min(10, test_case_count - 1)
137
-
138
- for _ in range(optimization_attempts):
139
- if test_case_count <= 1:
140
- break
141
-
142
- action = "remove"
143
- next_state, reward = self.env.step(action)
144
-
145
- new_covered = next_state[0]
146
- new_uncovered = [stmt for stmt in executable_statements if stmt not in new_covered]
147
-
148
- if len(new_uncovered) == 0:
149
- current_state = next_state
150
- test_case_count = current_state[2]
151
- print(f"Optimized to {test_case_count} test cases.")
152
- else:
153
- # Add a test case back if removing broke coverage
154
- self.env.step("add")
155
- break
156
-
157
- return current_state"""
131
+ self.q_table[(state, action)] = new_q
@@ -4,6 +4,7 @@ from typing import List, Tuple
4
4
 
5
5
  import coverage
6
6
 
7
+ from testgen.models.function_metadata import FunctionMetadata
7
8
  from testgen.service.logging_service import get_logger
8
9
  import testgen.util.coverage_utils
9
10
  import testgen.util.file_utils
@@ -13,11 +14,9 @@ from testgen.models.test_case import TestCase
13
14
 
14
15
 
15
16
  class ReinforcementEnvironment:
16
- def __init__(self, file_name, fut: ast.FunctionDef, module, class_name: str | None, initial_test_cases: List[TestCase], state: AbstractState):
17
- self.file_name = file_name
18
- self.fut = fut
19
- self.module = module
20
- self.class_name = class_name
17
+ def __init__(self, filepath: str, function_data: FunctionMetadata, initial_test_cases: List[TestCase], state: AbstractState):
18
+ self.filepath = filepath
19
+ self.function_data = function_data
21
20
  self.initial_test_cases = initial_test_cases
22
21
  self.test_cases = initial_test_cases.copy()
23
22
  self.state = state
@@ -35,13 +34,13 @@ class ReinforcementEnvironment:
35
34
 
36
35
  # Execute action
37
36
  if action == "add":
38
- self.test_cases.append(randomizer.new_random_test_case(self.file_name, self.class_name, self.fut))
37
+ self.test_cases.append(randomizer.new_random_test_case(self.filepath, self.function_data.class_name, self.function_data.func_def))
39
38
  elif action == "merge" and len(self.test_cases) > 1:
40
- self.test_cases.append(randomizer.combine_cases(self.module, self.test_cases))
39
+ self.test_cases.append(randomizer.combine_cases(self.function_data.module, self.test_cases))
41
40
  elif action == "remove" and len(self.test_cases) > 1:
42
41
  self.test_cases = randomizer.remove_case(self.test_cases)
43
42
  elif action == "z3":
44
- self.test_cases = randomizer.get_z3_test_cases(self.file_name, self.class_name, self.fut, self.test_cases)
43
+ self.test_cases = randomizer.get_z3_test_cases(self.filepath, self.function_data.class_name, self.function_data.func_def, self.test_cases)
45
44
  else:
46
45
  raise ValueError("Invalid action")
47
46
 
@@ -91,81 +90,6 @@ class ReinforcementEnvironment:
91
90
 
92
91
  print(f"Final reward {reward}")
93
92
  return reward
94
-
95
-
96
- def get_all_executable_statements(self):
97
- """Get all executable statements including else branches"""
98
- import ast
99
-
100
- test_cases = [tc for tc in self.test_cases if tc.func_name == self.fut.name]
101
-
102
- executable_lines = set()
103
- if not test_cases:
104
- self.logger.debug("Warning: No test cases available to determine executable statements")
105
- from testgen.util.randomizer import new_random_test_case
106
- temp_case = new_random_test_case(self.file_name, self.class_name, self.fut)
107
- analysis = testgen.util.coverage_utils.get_coverage_analysis(self.file_name, self.class_name, self.fut.name, temp_case.inputs)
108
- executable_lines.update(analysis[1]) # Add executable lines from coverage analysis
109
- else:
110
- analysis = testgen.util.coverage_utils.get_coverage_analysis(self.file_name, self.class_name, self.fut.name, test_cases[0].inputs)
111
-
112
- executable_lines.update(analysis[1]) # Add executable lines from coverage analysis
113
- # Get standard executable lines from coverage.py
114
- executable_lines = list(executable_lines)
115
-
116
- # Parse the source file to find else branches
117
- with open(self.file_name, 'r') as f:
118
- source = f.read()
119
-
120
- # Parse the code
121
- tree = ast.parse(source)
122
- # Find our specific function
123
- for node in ast.walk(tree):
124
- if isinstance(node, ast.ClassDef) and node.name == self.class_name:
125
- # If we have a class, find the method
126
- for method in node.body:
127
- if isinstance(method, ast.FunctionDef) and method.name == self.fut.name:
128
- # Find all if statements in this method
129
- for if_node in ast.walk(method):
130
- if isinstance(if_node, ast.If) and if_node.orelse:
131
- # There's an else branch
132
- if isinstance(if_node.orelse[0], ast.If):
133
- # This is an elif - already counted
134
- continue
135
-
136
- # Get the line number of the first statement in the else block
137
- # and subtract 1 to get the 'else:' line
138
- else_line = if_node.orelse[0].lineno - 1
139
-
140
- # Check if this is actually an else line (not a nested if)
141
- with open(self.file_name, 'r') as f:
142
- lines = f.readlines()
143
- if else_line <= len(lines):
144
- line_content = lines[else_line - 1].strip()
145
- if line_content == "else:":
146
- if else_line not in executable_lines:
147
- executable_lines.append(else_line)
148
- if isinstance(node, ast.FunctionDef) and node.name == self.fut.name:
149
- # Find all if statements in this function
150
- for if_node in ast.walk(node):
151
- if isinstance(if_node, ast.If) and if_node.orelse:
152
- # There's an else branch
153
- if isinstance(if_node.orelse[0], ast.If):
154
- # This is an elif - already counted
155
- continue
156
-
157
- # Get the line number of the first statement in the else block
158
- # and subtract 1 to get the 'else:' line
159
- else_line = if_node.orelse[0].lineno - 1
160
-
161
- # Check if this is actually an else line (not a nested if)
162
- with open(self.file_name, 'r') as f:
163
- lines = f.readlines()
164
- if else_line <= len(lines):
165
- line_content = lines[else_line - 1].strip()
166
- if line_content == "else:":
167
- if else_line not in executable_lines:
168
- executable_lines.append(else_line)
169
93
 
170
94
  return sorted(executable_lines)
171
95
 
@@ -180,13 +104,7 @@ class ReinforcementEnvironment:
180
104
  # Execute all test cases
181
105
  for test_case in self.test_cases:
182
106
  try:
183
- module = testgen.util.file_utils.load_module(self.file_name)
184
- if self.class_name:
185
- class_obj = getattr(module, self.class_name)
186
- instance = class_obj()
187
- func = getattr(instance, self.fut.name)
188
- else:
189
- func = getattr(module, self.fut.name)
107
+ func = self.function_data.func
190
108
  _ = func(*test_case.inputs)
191
109
  except Exception as e:
192
110
  import traceback
@@ -195,7 +113,7 @@ class ReinforcementEnvironment:
195
113
  self.cov.stop()
196
114
 
197
115
  # Get detailed coverage data including branches
198
- file_path = os.path.abspath(self.file_name)
116
+ file_path = os.path.abspath(self.filepath)
199
117
  data = self.cov.get_data()
200
118
 
201
119
  # Extract function-specific coverage
@@ -225,13 +143,13 @@ class ReinforcementEnvironment:
225
143
  import ast
226
144
 
227
145
  try:
228
- with open(self.file_name, 'r') as f:
146
+ with open(self.filepath, 'r') as f:
229
147
  source = f.read()
230
148
 
231
149
  tree = ast.parse(source)
232
150
 
233
151
  for node in ast.walk(tree):
234
- if isinstance(node, ast.FunctionDef) and node.name == self.fut.name:
152
+ if isinstance(node, ast.FunctionDef) and node.name == self.function_data.function_name:
235
153
  # Find the first line of the function
236
154
  start_line = node.lineno
237
155
 
@@ -14,11 +14,11 @@ class StatementCoverageState(AbstractState):
14
14
  """Returns calculated coverage and length of test cases in a tuple"""
15
15
  all_covered_statements = set()
16
16
  for test_case in self.environment.test_cases:
17
- analysis = testgen.util.coverage_utils.get_coverage_analysis(self.environment.file_name, self.environment.class_name, self.environment.fut.name, test_case.inputs)
17
+ analysis = testgen.util.coverage_utils.get_coverage_analysis(self.environment.filepath, self.environment.function_data, test_case.inputs)
18
18
  covered = testgen.util.coverage_utils.get_list_of_covered_statements(analysis)
19
19
  all_covered_statements.update(covered)
20
20
 
21
- executable_statements = self.environment.get_all_executable_statements()
21
+ executable_statements = testgen.util.coverage_utils.get_all_executable_statements(self.environment.filepath, self.environment.function_data, self.environment.test_cases)
22
22
 
23
23
  if not executable_statements or executable_statements == 0:
24
24
  calc_coverage = 0.0
@@ -26,10 +26,11 @@ class StatementCoverageState(AbstractState):
26
26
  calc_coverage: float = (len(all_covered_statements) / len(executable_statements)) * 100
27
27
 
28
28
  self.logger.debug(f"GET STATE ALL COVERED STATEMENTS: {all_covered_statements}")
29
- self.logger.debug(f"GET STATE ALL EXECUTABLE STATEMENTS: {self.environment.get_all_executable_statements()}")
29
+ self.logger.debug(f"GET STATE ALL EXECUTABLE STATEMENTS: {executable_statements}")
30
30
  self.logger.debug(f"GET STATE FLOAT COVERAGE: {calc_coverage}")
31
31
 
32
32
  if calc_coverage >= 100:
33
- print(f"!!!!!!!!FULLY COVERED FUNCTION: {self.environment.fut.name}!!!!!!!!")
33
+ print(f"!!!!!!!!FULLY COVERED FUNCTION: {self.environment.function_data.function_name}!!!!!!!!")
34
+
34
35
  return calc_coverage, len(self.environment.test_cases)
35
36
 
@@ -2,9 +2,10 @@ import inspect
2
2
  import ast
3
3
  import time
4
4
  from types import ModuleType
5
- from typing import Dict, List
5
+ from typing import Dict, List, Any
6
6
 
7
7
  import testgen
8
+ from testgen.analyzer.reinforcement_analyzer import ReinforcementAnalyzer
8
9
  from testgen.service.logging_service import get_logger
9
10
  import testgen.util.file_utils
10
11
  import testgen.util.file_utils as file_utils
@@ -37,11 +38,8 @@ class AnalysisService:
37
38
 
38
39
  def generate_test_cases(self) -> List[TestCase]:
39
40
  """Generate test cases using the current strategy."""
40
- if self.test_strategy == REINFORCE_STRAT:
41
- return self.do_reinforcement_learning(self.file_path)
42
- else:
43
- self.test_case_analyzer_context.do_logic()
44
- return self.test_case_analyzer_context.test_cases
41
+ self.test_case_analyzer_context.do_strategy(30)
42
+ return self.test_case_analyzer_context.test_cases
45
43
 
46
44
  def create_analysis_context(self, filepath: str) -> AnalysisContext:
47
45
  """Create an analysis context for the given file."""
@@ -88,6 +86,7 @@ class AnalysisService:
88
86
  mode = mode or self.reinforcement_mode
89
87
  module: ModuleType = testgen.util.file_utils.load_module(filepath)
90
88
  tree: ast.Module = testgen.util.file_utils.load_and_parse_file_for_tree(filepath)
89
+ list_of_function_data: List[FunctionMetadata] = self.get_function_data(filepath, module, class_name)
91
90
  functions: List[ast.FunctionDef] = testgen.util.utils.get_functions(tree)
92
91
  self.class_name = class_name
93
92
  time_limit: int = 30
@@ -95,14 +94,14 @@ class AnalysisService:
95
94
 
96
95
  q_table = self._load_q_table()
97
96
 
98
- for function in functions:
99
- print(f"\nStarting reinforcement learning for function {function.name}")
97
+ for function in list_of_function_data:
98
+ print(f"\nStarting reinforcement learning for function {function.function_name}")
100
99
  start_time = time.time()
101
100
  function_test_cases: List[TestCase] = []
102
101
  best_coverage: float = 0.0
103
102
 
104
103
  # Create environment and agent once per function
105
- environment = ReinforcementEnvironment(filepath, function, module, self.class_name, function_test_cases, state=StatementCoverageState(None))
104
+ environment = ReinforcementEnvironment(filepath, function, function_test_cases, state=StatementCoverageState(None))
106
105
  environment.state = StatementCoverageState(environment)
107
106
 
108
107
  # Create agent with existing Q-table
@@ -115,10 +114,10 @@ class AnalysisService:
115
114
  new_test_cases = agent.collect_test_cases()
116
115
  function_test_cases.extend(new_test_cases)
117
116
 
118
- print(f"\nNumber of test cases for {function.name}: {len(function_test_cases)}")
117
+ print(f"\nNumber of test cases for {function.function_name}: {len(function_test_cases)}")
119
118
 
120
119
  current_coverage: float = environment.run_tests()
121
- print(f"Current coverage: {function.name}: {current_coverage}")
120
+ print(f"Current coverage: {function.function_name}: {current_coverage}")
122
121
 
123
122
  q_table.update(agent.q_table)
124
123
 
@@ -134,8 +133,8 @@ class AnalysisService:
134
133
  unique_test_cases.append(case)
135
134
 
136
135
  all_test_cases.extend(unique_test_cases)
137
- print(f"Final coverage for {function.name}: {best_coverage}%")
138
- print(f"Final test cases for {function.name}: {len(unique_test_cases)}")
136
+ print(f"Final coverage for {function.function_name}: {best_coverage}%")
137
+ print(f"Final test cases for {function.function_name}: {len(unique_test_cases)}")
139
138
 
140
139
  self._save_q_table(q_table)
141
140
 
@@ -146,12 +145,23 @@ class AnalysisService:
146
145
  def _create_function_metadata(self, filename: str, module: ModuleType, class_name: str | None,
147
146
  func_node: ast.FunctionDef) -> FunctionMetadata:
148
147
  function_name = func_node.name
149
-
148
+ func = self._get_func_attr(function_name, module, class_name)
150
149
  param_types = self._get_params(func_node)
151
150
 
152
- return FunctionMetadata(filename, module, class_name, function_name, func_node, param_types)
153
-
154
- def _get_params(self, func_node: ast.FunctionDef) -> Dict[str, str]:
151
+ return FunctionMetadata(filename, module, class_name, func, function_name, func_node, param_types)
152
+
153
+ @staticmethod
154
+ def _get_func_attr(function_name: str, module: ModuleType, class_name: str | None) -> Any:
155
+ if class_name:
156
+ cls = getattr(module, class_name)
157
+ instance = cls()
158
+ func = getattr(instance, function_name)
159
+ else:
160
+ func = getattr(module, function_name)
161
+ return func
162
+
163
+ @staticmethod
164
+ def _get_params(func_node: ast.FunctionDef) -> Dict[str, str]:
155
165
  # Extract parameter types
156
166
  param_types = {}
157
167
  for arg in func_node.args.args:
@@ -173,18 +183,17 @@ class AnalysisService:
173
183
 
174
184
  if strategy == AST_STRAT:
175
185
  analyzer = ASTAnalyzer(analysis_context)
176
- self.test_case_analyzer_context = TestCaseAnalyzerContext(analysis_context, analyzer)
177
186
  elif strategy == FUZZ_STRAT:
178
187
  analyzer = FuzzAnalyzer(analysis_context)
179
- self.test_case_analyzer_context = TestCaseAnalyzerContext(analysis_context, analyzer)
180
188
  elif strategy == RANDOM_STRAT:
181
189
  analyzer = RandomFeedbackAnalyzer(analysis_context)
182
- self.test_case_analyzer_context = TestCaseAnalyzerContext(analysis_context, analyzer)
183
190
  elif strategy == REINFORCE_STRAT:
184
- pass
191
+ analyzer = ReinforcementAnalyzer(analysis_context, mode=self.reinforcement_mode)
185
192
  else:
186
193
  raise NotImplementedError(f"Test strategy {strategy} not implemented")
187
-
194
+
195
+ self.test_case_analyzer_context = TestCaseAnalyzerContext(analysis_context, analyzer)
196
+
188
197
  def set_file_path(self, path: str):
189
198
  """Set the file path for analysis."""
190
199
  self.file_path = path
@@ -1,5 +1,7 @@
1
1
  import os
2
2
  from typing import List
3
+
4
+ from testgen.models.function_metadata import FunctionMetadata
3
5
  from testgen.models.test_case import TestCase
4
6
  from testgen.service.logging_service import get_logger
5
7
  from testgen.util.coverage_visualizer import CoverageVisualizer
@@ -46,7 +48,7 @@ class CFGService:
46
48
  filename = os.path.basename(file_path).replace('.py', '')
47
49
 
48
50
  for func in analysis_context.function_data:
49
- self.visualizer.get_covered_lines(file_path, analysis_context.class_name, func.func_def, test_cases)
51
+ self.visualizer.get_covered_lines(file_path, func, test_cases)
50
52
 
51
53
  base_filename = f"{filename}_{func.function_name}_coverage"
52
54
  output_filepath = self.get_versioned_filename(visualization_dir, base_filename)