testgenie-py 0.3.7__tar.gz → 0.3.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/PKG-INFO +1 -1
  2. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/pyproject.toml +1 -1
  3. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/analyzer/ast_analyzer.py +2 -11
  4. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/analyzer/fuzz_analyzer.py +1 -6
  5. testgenie_py-0.3.8/testgen/analyzer/random_feedback_analyzer.py +248 -0
  6. testgenie_py-0.3.8/testgen/analyzer/reinforcement_analyzer.py +77 -0
  7. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/analyzer/test_case_analyzer_context.py +0 -6
  8. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/controller/cli_controller.py +35 -29
  9. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/controller/docker_controller.py +1 -0
  10. testgenie_py-0.3.8/testgen/db/dao.py +68 -0
  11. testgenie_py-0.3.8/testgen/db/dao_impl.py +226 -0
  12. {testgenie_py-0.3.7/testgen/sqlite → testgenie_py-0.3.8/testgen/db}/db.py +15 -6
  13. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/generator/pytest_generator.py +2 -10
  14. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/generator/unit_test_generator.py +2 -11
  15. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/main.py +1 -3
  16. testgenie_py-0.3.8/testgen/models/coverage_data.py +56 -0
  17. testgenie_py-0.3.8/testgen/models/db_test_case.py +65 -0
  18. testgenie_py-0.3.8/testgen/models/function.py +56 -0
  19. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/models/function_metadata.py +11 -1
  20. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/models/generator_context.py +32 -2
  21. testgenie_py-0.3.8/testgen/models/source_file.py +29 -0
  22. testgenie_py-0.3.8/testgen/models/test_result.py +38 -0
  23. testgenie_py-0.3.8/testgen/models/test_suite.py +20 -0
  24. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/reinforcement/agent.py +1 -27
  25. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/reinforcement/environment.py +11 -93
  26. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/reinforcement/statement_coverage_state.py +5 -4
  27. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/service/analysis_service.py +31 -22
  28. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/service/cfg_service.py +3 -1
  29. testgenie_py-0.3.8/testgen/service/coverage_service.py +115 -0
  30. testgenie_py-0.3.8/testgen/service/db_service.py +140 -0
  31. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/service/generator_service.py +77 -20
  32. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/service/logging_service.py +2 -2
  33. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/service/service.py +62 -231
  34. testgenie_py-0.3.8/testgen/service/test_executor_service.py +145 -0
  35. testgenie_py-0.3.8/testgen/util/coverage_utils.py +152 -0
  36. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/util/coverage_visualizer.py +10 -9
  37. testgenie_py-0.3.8/testgen/util/file_utils.py +85 -0
  38. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/util/randomizer.py +0 -26
  39. testgenie_py-0.3.8/testgen/util/utils.py +302 -0
  40. testgenie_py-0.3.7/testgen/analyzer/random_feedback_analyzer.py +0 -521
  41. testgenie_py-0.3.7/testgen/analyzer/reinforcement_analyzer.py +0 -75
  42. testgenie_py-0.3.7/testgen/inspector/inspector.py +0 -59
  43. testgenie_py-0.3.7/testgen/presentation/cli_view.py +0 -12
  44. testgenie_py-0.3.7/testgen/sqlite/db_service.py +0 -239
  45. testgenie_py-0.3.7/testgen/testgen.db +0 -0
  46. testgenie_py-0.3.7/testgen/util/__init__.py +0 -0
  47. testgenie_py-0.3.7/testgen/util/coverage_utils.py +0 -230
  48. testgenie_py-0.3.7/testgen/util/file_utils.py +0 -186
  49. testgenie_py-0.3.7/testgen/util/utils.py +0 -143
  50. testgenie_py-0.3.7/testgen/util/z3_utils/__init__.py +0 -0
  51. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/README.md +0 -0
  52. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/__init__.py +0 -0
  53. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/analyzer/__init__.py +0 -0
  54. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/analyzer/contracts/__init__.py +0 -0
  55. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/analyzer/contracts/contract.py +0 -0
  56. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/analyzer/contracts/no_exception_contract.py +0 -0
  57. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/analyzer/contracts/nonnull_contract.py +0 -0
  58. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/analyzer/test_case_analyzer.py +0 -0
  59. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/controller/__init__.py +0 -0
  60. {testgenie_py-0.3.7/testgen/generator → testgenie_py-0.3.8/testgen/db}/__init__.py +0 -0
  61. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/docker/Dockerfile +0 -0
  62. {testgenie_py-0.3.7/testgen/inspector → testgenie_py-0.3.8/testgen/generator}/__init__.py +0 -0
  63. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/generator/code_generator.py +0 -0
  64. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/generator/doctest_generator.py +0 -0
  65. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/generator/generator.py +0 -0
  66. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/generator/test_generator.py +0 -0
  67. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/models/__init__.py +0 -0
  68. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/models/analysis_context.py +0 -0
  69. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/models/test_case.py +0 -0
  70. {testgenie_py-0.3.7/testgen/presentation → testgenie_py-0.3.8/testgen/reinforcement}/__init__.py +0 -0
  71. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/reinforcement/abstract_state.py +0 -0
  72. {testgenie_py-0.3.7/testgen/reinforcement → testgenie_py-0.3.8/testgen/service}/__init__.py +0 -0
  73. {testgenie_py-0.3.7/testgen/service → testgenie_py-0.3.8/testgen/tree}/__init__.py +0 -0
  74. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/tree/node.py +0 -0
  75. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/tree/tree_utils.py +0 -0
  76. {testgenie_py-0.3.7/testgen/sqlite → testgenie_py-0.3.8/testgen/util}/__init__.py +0 -0
  77. {testgenie_py-0.3.7/testgen/tree → testgenie_py-0.3.8/testgen/util/z3_utils}/__init__.py +0 -0
  78. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/util/z3_utils/ast_to_z3.py +0 -0
  79. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/util/z3_utils/branch_condition.py +0 -0
  80. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/util/z3_utils/constraint_extractor.py +0 -0
  81. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/util/z3_utils/variable_finder.py +0 -0
  82. {testgenie_py-0.3.7 → testgenie_py-0.3.8}/testgen/util/z3_utils/z3_test_case.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: testgenie-py
3
- Version: 0.3.7
3
+ Version: 0.3.8
4
4
  Summary: Automated unit test generation tool for Python.
5
5
  Author: cjseitz
6
6
  Author-email: charlesjseitz@gmail.com
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "testgenie-py"
3
- version = "0.3.7"
3
+ version = "0.3.8"
4
4
  description = "Automated unit test generation tool for Python."
5
5
  authors = ["cjseitz <charlesjseitz@gmail.com>"]
6
6
  readme = "README.md"
@@ -53,17 +53,8 @@ class ASTAnalyzer(TestCaseAnalyzerStrategy, ABC):
53
53
 
54
54
  if not input_exists:
55
55
  try:
56
- generated_file = function_metadata.filename
57
- module = testgen.util.file_utils.load_module(generated_file)
58
-
59
- if function_metadata.class_name:
60
- cls = getattr(module, function_metadata.class_name)
61
- instance = cls()
62
- func = getattr(instance, func_name)
63
- output = func(*inputs)
64
- else:
65
- func = getattr(module, func_name)
66
- output = func(*inputs)
56
+ func = function_metadata.func
57
+ output = func(*inputs)
67
58
 
68
59
  except Exception as e:
69
60
  print(f"Error executing function: {e}")
@@ -29,13 +29,8 @@ class FuzzAnalyzer(TestCaseAnalyzerStrategy, ABC):
29
29
  else:
30
30
  raise ValueError("Module not set in function metadata. Cannot perform fuzzing without a module.")
31
31
 
32
- class_name = function_metadata.class_name if function_metadata.class_name else None
33
32
  try:
34
- if not class_name is None:
35
- cls = getattr(module, class_name, None)
36
- func = getattr(cls(), function_metadata.function_name, None) if cls else None
37
- else:
38
- func = getattr(module, function_metadata.function_name, None)
33
+ func = function_metadata.func
39
34
  if func:
40
35
  return self.run_fuzzing(func, function_metadata.function_name, function_metadata.params, module, 10)
41
36
  except Exception as e:
@@ -0,0 +1,248 @@
1
+ import random
2
+ import time
3
+ import traceback
4
+ from typing import List, Dict, Set
5
+ import testgen.util.coverage_utils as coverage_utils
6
+ from testgen.analyzer.contracts.contract import Contract
7
+ from testgen.analyzer.contracts.no_exception_contract import NoExceptionContract
8
+ from testgen.analyzer.contracts.nonnull_contract import NonNullContract
9
+ from testgen.models.test_case import TestCase
10
+ from testgen.analyzer.test_case_analyzer import TestCaseAnalyzerStrategy
11
+ from abc import ABC
12
+
13
+ from testgen.models.function_metadata import FunctionMetadata
14
+
15
+
16
+ # Citation in which this method and algorithm were taken from:
17
+ # C. Pacheco, S. K. Lahiri, M. D. Ernst and T. Ball, "Feedback-Directed Random Test Generation," 29th International
18
+ # Conference on Software Engineering (ICSE'07), Minneapolis, MN, USA, 2007, pp. 75-84, doi: 10.1109/ICSE.2007.37.
19
+ # keywords: {System testing;Contracts;Object oriented modeling;Law;Legal factors;Open source software;Software
20
+ # testing;Feedback;Filters;Error correction codes},
21
+
22
class RandomFeedbackAnalyzer(TestCaseAnalyzerStrategy, ABC):
    """Feedback-directed random test generation (Pacheco et al., ICSE'07).

    Inputs are drawn at random from each function's declared parameter
    types and the function is executed on them. The outcome feeds back into
    the search: repeated inputs are discarded, contract-violating sequences
    are kept apart from passing ones, and generation for a function stops
    once all of its executable statements have been covered.
    """

    def __init__(self, analysis_context=None):
        super().__init__(analysis_context)
        self.test_cases = []
        # Function name -> statement line numbers covered by generated test cases.
        self.covered_lines: Dict[str, Set[int]] = {}
        # Currently not read or written elsewhere in this class.
        self.covered_functions: Set[str] = set()

    def collect_test_cases(self, function_metadata: FunctionMetadata, time_limit: int = 5) -> List[TestCase]:
        """Randomly generate test cases for one function.

        Generation stops when the function is fully covered, when
        ``time_limit`` seconds have elapsed, or immediately after the first
        attempt if the function takes no parameters (there is only one
        possible input, so further attempts could only produce duplicates).

        Args:
            function_metadata: Metadata of the function under test. ``params``
                maps parameter names to type names; ``func`` is the callable.
            time_limit: Wall-clock budget in seconds.

        Returns:
            The unique test cases generated (also stored in ``self.test_cases``).
        """
        self.test_cases = []
        func_name = function_metadata.function_name
        params = function_metadata.params or {}
        param_names = list(params)
        start_time = time.time()

        while (time.time() - start_time) < time_limit:
            try:
                param_values = self.generate_random_inputs(params)
                # Order arguments by the declared parameter order.
                ordered_args = [param_values.get(name, None) for name in param_names]

                result = function_metadata.func(*ordered_args)
                test_case = TestCase(func_name, tuple(ordered_args), result)

                if self.is_duplicate_test_case(test_case):
                    self.logger.debug(f"Skipping duplicate test case: {func_name}{test_case.inputs}")
                else:
                    self.test_cases.append(test_case)
                    if self.covered(function_metadata):
                        break
            except Exception as e:
                # Best effort: report and keep sampling other inputs.
                print(f"Error testing {func_name}: {e}")

            if not param_names:
                # Without parameters every attempt reuses the same empty input.
                break

        return self.test_cases

    def is_duplicate_test_case(self, new_test_case: TestCase) -> bool:
        """Return True if a stored test case has the same function name and inputs."""
        return any(
            existing.func_name == new_test_case.func_name
            and existing.inputs == new_test_case.inputs
            for existing in self.test_cases
        )

    def covered(self, func: FunctionMetadata) -> bool:
        """Accumulate coverage for ``func`` and report whether it is fully covered.

        Every stored test case for ``func`` is run through
        ``coverage_utils.get_coverage_analysis`` and the covered statement
        lines are merged into ``self.covered_lines[func.function_name]``.

        Returns:
            True when every executable statement reported for ``func`` has
            been covered. If no executable statements are reported, it returns
            True only if nothing has been recorded as covered either.
        """
        name = func.function_name
        filepath = self._analysis_context.filepath
        covered_so_far = self.covered_lines.setdefault(name, set())
        test_cases = [tc for tc in self.test_cases if tc.func_name == name]

        for test_case in test_cases:
            analysis = coverage_utils.get_coverage_analysis(filepath, func, test_case.inputs)
            covered_so_far.update(coverage_utils.get_list_of_covered_statements(analysis))
            self.logger.debug(f"Covered lines for {name}: {covered_so_far}")

        executable_statements = set(coverage_utils.get_all_executable_statements(filepath, func, test_cases))
        self.logger.debug(f"Executable statements for {name}: {executable_statements}")

        if not executable_statements:
            return not covered_so_far
        # Subset rather than equality, so covered lines outside the function's
        # executable set cannot prevent it from ever counting as covered.
        return executable_statements <= covered_so_far

    def execute_sequence(self, sequence, contracts: List[Contract]):
        """Execute a ``(func_name, args_dict)`` sequence and check contracts.

        Arguments are ordered by the function's declared parameters; names
        missing from ``args_dict`` are passed as None. Exceptions raised while
        resolving or calling the function are caught and handed to the
        postcondition checks.

        Returns:
            ``(output, violated)``. ``output`` is the function's return value
            (None if it raised or a precondition failed). ``violated`` is True
            when any contract's pre- or postcondition check fails.
        """
        func_name, args_dict = sequence
        # Bound up front so the postcondition checks below can always run,
        # even when the metadata lookup itself is what failed.
        ordered_args = []
        output = None
        exception = None

        try:
            function_metadata = self.get_function_metadata(func_name)
            if function_metadata is None:
                raise LookupError(f"No function metadata found for {func_name}")
            func = function_metadata.func
            ordered_args = [args_dict.get(name, None) for name in function_metadata.params.keys()]

            for contract in contracts:
                if not contract.check_preconditions(tuple(ordered_args)):
                    print(f"Preconditions failed for {func_name} with {tuple(ordered_args)}")
                    return None, True

            output = func(*ordered_args)
        except Exception as e:
            print(f"EXCEPTION IN RANDOM FEEDBACK: {e}")
            print(traceback.format_exc())
            output = None
            exception = e

        for contract in contracts:
            if not contract.check_postconditions(tuple(ordered_args), output, exception):
                print(f"Postcondition failed for {func_name} with {tuple(ordered_args)}")
                return output, True

        return output, False

    def get_function_metadata(self, func_name: str) -> FunctionMetadata | None:
        """Return the analysis context's metadata for ``func_name``, or None if absent."""
        return next(
            (data for data in self._analysis_context.function_data if data.function_name == func_name),
            None,
        )

    # TODO: Currently only getting random vals of primitives, extend to sequences
    def random_seqs_and_vals(self, param_types, non_error_seqs=None):
        """Return random argument values for ``param_types`` (``non_error_seqs`` is unused)."""
        return self.generate_random_inputs(param_types)

    @staticmethod
    def generate_random_inputs(param_types):
        """Generate a random value for each parameter from its type name.

        Supported type names are ``"int"``, ``"bool"``, ``"float"`` and
        ``"str"``; any other type name maps to None.

        Returns:
            Dict mapping parameter name to generated value, in the iteration
            order of ``param_types``.
        """
        inputs = {}
        for param, param_type in param_types.items():
            if param_type == "int":
                inputs[param] = random.randint(-500, 500)  # Wider range for better edge cases
            elif param_type == "bool":
                inputs[param] = random.choice([True, False])
            elif param_type == "float":
                inputs[param] = random.uniform(-500.0, 500.0)  # Wider range for better edge cases
            elif param_type == "str":
                inputs[param] = RandomFeedbackAnalyzer._random_string()
            else:
                # No generator for this type yet.
                inputs[param] = None
        return inputs

    @staticmethod
    def _random_string() -> str:
        """Return a random string from a randomly chosen shape category."""
        string_type = random.choice([
            "empty", "short", "medium", "long", "special", "numeric", "whitespace"
        ])
        if string_type == "empty":
            return ""
        if string_type == "short":
            return ''.join(random.choices('abcdefghijklmnopqrstuvwxyz', k=random.randint(1, 3)))
        if string_type == "medium":
            return ''.join(random.choices('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ', k=random.randint(4, 10)))
        if string_type == "long":
            return ''.join(random.choices('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789', k=random.randint(11, 30)))
        if string_type == "special":
            return ''.join(random.choices('!@#$%^&*()_+-=[]{}|;:,./<>?', k=random.randint(1, 8)))
        if string_type == "numeric":
            return ''.join(random.choices('0123456789', k=random.randint(1, 10)))
        # whitespace
        return ' ' * random.randint(1, 5)

    @staticmethod
    def _is_known_sequence(new_seq, *seq_lists) -> bool:
        """Return True if ``new_seq`` already appears in any list of ``(seq, output)`` pairs."""
        return any(seq == new_seq for seqs in seq_lists for seq, _ in seqs)

    # Algorithm described in above article
    # Classes is the classes for which we want to generate sequences
    # Contracts express invariant properties that hold both at entry and exit from a call
    # Contract takes as input the current state of the system (runtime values created in the sequence so far, and any exception thrown by the last call), and returns satisfied or violated
    # Output is the runtime values and boolean flag violated
    # Filters determine which values of a sequence are extensible and should be used as inputs
    def generate_sequences(self, function_metadata: List[FunctionMetadata], classes=None, contracts: List[Contract] | None = None, filters=None, time_limit=20):
        """Run the feedback-directed loop over all context functions for ``time_limit`` seconds.

        ``function_metadata``, ``classes`` and ``filters`` are currently unused;
        functions are taken from the analysis context. When ``contracts`` is
        empty or None, NonNull and NoException contracts are used.

        Returns:
            ``(error_seqs, non_error_seqs)``: lists of
            ``((func_name, args_dict), output)`` for sequences that did / did
            not violate a contract.
        """
        if not contracts:
            contracts = [NonNullContract(), NoExceptionContract()]
        error_seqs = []  # execution violates a contract
        non_error_seqs = []  # execution does not violate a contract

        functions = self._analysis_context.function_data
        if not functions:
            return error_seqs, non_error_seqs

        start_time = time.time()
        while (time.time() - start_time) < time_limit:
            func = random.choice(functions)
            vals: dict = self.random_seqs_and_vals(func.params)
            new_seq = (func.function_name, vals)
            if self._is_known_sequence(new_seq, error_seqs, non_error_seqs):
                continue

            output, violated = self.execute_sequence(new_seq, contracts)
            if violated:
                error_seqs.append((new_seq, output))
            else:
                # TODO: consider using the violated contract as the unit-test assertion.
                non_error_seqs.append((new_seq, output))
        return error_seqs, non_error_seqs

    def generate_sequences_new(self, contracts: List[Contract] | None = None, filters=None, time_limit=20):
        """Feedback-directed generation that records test cases until coverage or timeout.

        Each new, non-duplicate sequence is executed, recorded as a TestCase in
        ``self.test_cases``, and its function is dropped from the pool once
        fully covered. ``filters`` is currently unused. When ``contracts`` is
        empty or None, NonNull and NoException contracts are used.

        Returns:
            ``(error_seqs, non_error_seqs)``: lists of
            ``((func_name, args_dict), output)`` for sequences that did / did
            not violate a contract. ``self.test_cases`` is left sorted by
            function name.
        """
        if not contracts:
            contracts = [NonNullContract(), NoExceptionContract()]
        error_seqs = []  # execution violates a contract
        non_error_seqs = []  # execution does not violate a contract

        functions = self._analysis_context.function_data.copy()
        start_time = time.time()

        while functions and (time.time() - start_time) < time_limit:
            func = random.choice(functions)
            vals: dict = self.random_seqs_and_vals(func.params)
            new_seq = (func.function_name, vals)

            if self._is_known_sequence(new_seq, error_seqs, non_error_seqs):
                continue

            output, violated = self.execute_sequence(new_seq, contracts)
            if violated:
                error_seqs.append((new_seq, output))
            else:
                non_error_seqs.append((new_seq, output))

            # vals preserves declared parameter order, matching execute_sequence.
            self.test_cases.append(TestCase(func.function_name, tuple(vals.values()), output))

            if self.covered(func):
                print(f"Function {func.function_name} is fully covered")
                functions.remove(func)
                if not functions:
                    print("All functions covered")

        self.test_cases.sort(key=lambda tc: tc.func_name)
        return error_seqs, non_error_seqs
@@ -0,0 +1,77 @@
1
+
2
+ import ast
3
+ from abc import ABC
4
+ import random
5
+ from typing import List
6
+
7
+ import testgen.util.randomizer
8
+ from testgen.models.function_metadata import FunctionMetadata
9
+ from testgen.models.test_case import TestCase
10
+ from testgen.analyzer.test_case_analyzer import TestCaseAnalyzerStrategy
11
+ from testgen.reinforcement.environment import ReinforcementEnvironment
12
+
13
+
14
+ # Goal: Learn a policy to generate a set of test cases with optimal code coverage, and minimum number of test cases
15
+ # Environment: FUT (Function Under Test)
16
+ # Agent: System
17
+ # Actions: Create new test case, combine test cases, delete test cases
18
+ # Rewards:
19
+
20
+ from typing import List, Optional
21
+ from testgen.models.test_case import TestCase
22
+ from testgen.models.analysis_context import AnalysisContext
23
+ from testgen.analyzer.test_case_analyzer import TestCaseAnalyzerStrategy
24
+ from testgen.reinforcement.agent import ReinforcementAgent
25
+ from testgen.reinforcement.environment import ReinforcementEnvironment
26
+ from testgen.reinforcement.statement_coverage_state import StatementCoverageState
27
+
28
+
29
class ReinforcementAnalyzer(TestCaseAnalyzerStrategy):
    """Collect test cases with a Q-learning agent over a statement-coverage state.

    In ``"train"`` mode the agent runs ``TRAIN_EPISODES`` Q-learning
    episodes. In any other mode it runs a single ``collect_test_cases``
    pass. The Q-table is loaded and saved through ``AnalysisService``.
    """

    # Number of episodes run per function in "train" mode.
    TRAIN_EPISODES = 10

    def __init__(self, analysis_context: AnalysisContext, mode: str = "train"):
        super().__init__(analysis_context)
        self.analysis_context = analysis_context
        self.mode = mode

    def collect_test_cases(self, function_metadata: FunctionMetadata):
        """Strategy entry point; delegates to :meth:`analyze`."""
        return self.analyze(function_metadata)

    def analyze(self, function_metadata: FunctionMetadata) -> List[TestCase]:
        """Run the agent for one function and return its de-duplicated test cases.

        Args:
            function_metadata: Metadata of the function under test.

        Returns:
            Test cases in generation order, keeping only the first test case
            for each ``(func_name, inputs)`` pair.
        """
        # Local import, presumably to avoid an import cycle with the service layer.
        from testgen.service.analysis_service import AnalysisService

        training = self.mode == "train"
        q_table = AnalysisService._load_q_table()
        function_test_cases: List[TestCase] = []

        environment = ReinforcementEnvironment(
            self.analysis_context.filepath,
            function_metadata,
            function_test_cases,
            state=StatementCoverageState(None),
        )
        # Replace the placeholder state with one bound to this environment.
        environment.state = StatementCoverageState(environment)
        agent = ReinforcementAgent(
            self.analysis_context.filepath,
            environment,
            function_test_cases,
            q_table,
        )

        episodes = self.TRAIN_EPISODES if training else 1
        for _ in range(episodes):
            new_test_cases = agent.do_q_learning() if training else agent.collect_test_cases()
            # Guard against a None result, and against the agent handing back
            # the shared list itself (extending a list with itself doubles it).
            if new_test_cases and new_test_cases is not function_test_cases:
                function_test_cases.extend(new_test_cases)

        unique_test_cases = self._dedupe_test_cases(function_test_cases)

        # NOTE(review): the Q-table is saved in every mode, not only "train".
        AnalysisService._save_q_table(q_table)
        return unique_test_cases

    @staticmethod
    def _dedupe_test_cases(test_cases: List[TestCase]) -> List[TestCase]:
        """Keep the first test case per (func_name, inputs), preserving order."""
        seen = set()
        unique_test_cases = []
        for case in test_cases:
            case_inputs = tuple(case.inputs) if isinstance(case.inputs, list) else case.inputs
            case_key = (case.func_name, case_inputs)
            try:
                hash(case_key)
            except TypeError:
                # Inputs holding unhashable values (lists, dicts, ...) are keyed
                # by repr; the 3-tuple shape cannot collide with hashable keys.
                case_key = (case.func_name, "repr", repr(case_inputs))
            if case_key not in seen:
                seen.add(case_key)
                unique_test_cases.append(case)
        return unique_test_cases
@@ -11,12 +11,6 @@ class TestCaseAnalyzerContext:
11
11
  self._test_case_analyzer = test_case_analyzer
12
12
  self._analysis_context = analysis_context
13
13
  self._test_cases = []
14
-
15
- # TODO: GET RID OF THIS STUPID METHOD IT IS POINTLESS
16
- # JUST CALL INSIDE ANALYZER_SERVICE
17
- def do_logic(self) -> List[TestCase]:
18
- """Run the analysis process"""
19
- self.do_strategy(20)
20
14
 
21
15
  def do_strategy(self, time_limit=None) -> List[TestCase]:
22
16
  """Execute the analysis strategy for all functions with an optional time limit"""
@@ -1,5 +1,4 @@
1
1
  import argparse
2
- import inspect
3
2
  import os
4
3
  import sys
5
4
 
@@ -10,8 +9,7 @@ from testgen.service.logging_service import LoggingService, get_logger
10
9
  from testgen.util.file_utils import adjust_file_path_for_docker, get_project_root_in_docker
11
10
  from testgen.controller.docker_controller import DockerController
12
11
  from testgen.service.service import Service
13
- from testgen.presentation.cli_view import CLIView
14
- from testgen.sqlite.db_service import DBService
12
+ from testgen.service.db_service import DBService
15
13
 
16
14
  AST_STRAT = 1
17
15
  FUZZ_STRAT = 2
@@ -24,9 +22,9 @@ DOCTEST_FORMAT = 3
24
22
 
25
23
  class CLIController:
26
24
  #TODO: Possibly create a view 'interface' and use dependency injection to extend other views
27
- def __init__(self, service: Service, view: CLIView):
25
+ def __init__(self, service: Service):
28
26
  self.service = service
29
- self.view = view
27
+ self.logger = None
30
28
 
31
29
  def run(self):
32
30
 
@@ -35,12 +33,16 @@ class CLIController:
35
33
  args = parser.parse_args()
36
34
 
37
35
  LoggingService.get_instance().initialize(
38
- debug_mode=args.debug if hasattr(args, 'debug') else False,
36
+ debug_mode=args.debug if args.debug else False,
39
37
  log_file=args.log_file if hasattr(args, 'log_file') else None,
40
38
  console_output=True
41
39
  )
42
40
 
43
- logger = get_logger()
41
+ self.logger = get_logger()
42
+
43
+ if args.functions:
44
+ self.service.get_all_functions(args.file_path)
45
+ return
44
46
 
45
47
  if args.query:
46
48
  print(f"Querying database for file: {args.file_path}")
@@ -48,7 +50,7 @@ class CLIController:
48
50
  return
49
51
 
50
52
  if args.coverage:
51
- self.service.get_coverage(args.file_path)
53
+ self.service.run_coverage(args.file_path)
52
54
  return
53
55
 
54
56
  running_in_docker = os.environ.get("RUNNING_IN_DOCKER") is not None
@@ -59,7 +61,7 @@ class CLIController:
59
61
  client = self.docker_available()
60
62
  # Skip Docker-dependent operations if client is None
61
63
  if client is None and args.safe:
62
- self.view.display_message("Running with --safe flag requires Docker. Continuing without safe mode.")
64
+ self.logger.debug("Running with --safe flag requires Docker. Continuing without safe mode.")
63
65
  args.safe = False
64
66
  self.execute_generation(args)
65
67
  else:
@@ -69,15 +71,13 @@ class CLIController:
69
71
  if not successful:
70
72
  if hasattr(args, 'db') and args.db:
71
73
  self.service.db_service = DBService(args.db)
72
- self.view.display_message(f"Using database: {args.db}")
74
+ self.logger.debug(f"Using database: {args.db}")
73
75
  self.execute_generation(args)
74
- # Else successful, do nothing - we're done
75
76
  else:
76
- # Initialize database service with specified path
77
77
  if hasattr(args, 'db') and args.db:
78
78
  self.service.db_service = DBService(args.db)
79
- self.view.display_message(f"Using database: {args.db}")
80
- self.view.display_message("Running in local mode...")
79
+ self.logger.debug(f"Using database: {args.db}")
80
+ self.logger.debug("Running in local mode...")
81
81
  self.execute_generation(args)
82
82
 
83
83
  def execute_generation(self, args: argparse.Namespace, running_in_docker: bool = False):
@@ -85,22 +85,23 @@ class CLIController:
85
85
  self.set_service_args(args)
86
86
 
87
87
  if running_in_docker:
88
- self.view.display_message("Running in Docker mode...")
88
+ self.logger.debug("Running in Docker mode...")
89
89
  self.service.generate_test_cases()
90
90
 
91
91
  else:
92
92
  test_file = self.service.generate_tests(args.output)
93
- self.view.display_message(f"Unit tests saved to: {test_file}")
94
- self.view.display_message("Running coverage...")
93
+ self.logger.debug(f"Unit tests saved to: {test_file}")
94
+ print("Executing tests...")
95
+ self.service.run_tests(test_file)
96
+ print("Running coverage...")
95
97
  self.service.run_coverage(test_file)
96
- self.view.display_message("Tests and coverage data saved to database.")
98
+ self.logger.debug("Tests and coverage data saved to database.")
97
99
 
98
100
  if args.visualize:
99
101
  self.service.visualize_test_coverage()
100
102
 
101
103
  except Exception as e:
102
- self.view.display_error(f"An error occurred: {e}")
103
- # Make sure to close the DB connection on error
104
+ self.logger.error(f"An error occurred: {e}")
104
105
  if hasattr(self.service, 'db_service'):
105
106
  self.service.db_service.close()
106
107
 
@@ -168,6 +169,11 @@ class CLIController:
168
169
  action="store_true",
169
170
  help="Run coverage analysis on the generated tests"
170
171
  )
172
+ parser.add_argument(
173
+ "-f", "--functions",
174
+ action="store_true",
175
+ help="List all functions in file"
176
+ )
171
177
  return parser
172
178
 
173
179
  def set_test_format(self, args: argparse.Namespace):
@@ -180,32 +186,32 @@ class CLIController:
180
186
 
181
187
  def set_test_strategy(self, args: argparse.Namespace):
182
188
  if args.test_mode == "random":
183
- self.view.display_message("Using Random Feedback-Directed Test Generation Strategy.")
189
+ print("Using Random Feedback-Directed Test Generation Strategy.")
184
190
  self.service.set_test_analysis_strategy(RANDOM_STRAT)
185
191
  elif args.test_mode == "fuzz":
186
- self.view.display_message("Using Fuzz Test Generation Strategy...")
192
+ print("Using Fuzz Test Generation Strategy...")
187
193
  self.service.set_test_analysis_strategy(FUZZ_STRAT)
188
194
  elif args.test_mode == "reinforce":
189
- self.view.display_message("Using Reinforcement Learning Test Generation Strategy...")
195
+ print("Using Reinforcement Learning Test Generation Strategy...")
190
196
  if args.reinforce_mode == "train":
191
- self.view.display_message("Training mode enabled - will update Q-table")
197
+ print("Training mode enabled - will update Q-table")
192
198
  else:
193
- self.view.display_message("Training mode disabled - will use existing Q-table")
199
+ print("Training mode disabled - will use existing Q-table")
194
200
  self.service.set_test_analysis_strategy(REINFORCE_STRAT)
195
201
  self.service.set_reinforcement_mode(args.reinforce_mode)
196
202
  else:
197
- self.view.display_message("Generating function code using AST analysis...")
203
+ print("Generating function code using AST analysis...")
198
204
  generated_file_path = self.service.generate_function_code()
199
- self.view.display_message(f"Generated code saved to: {generated_file_path}")
205
+ print(f"Generated code saved to: {generated_file_path}")
200
206
  if not args.generate_only:
201
- self.view.display_message("Using Simple AST Traversal Test Generation Strategy...")
207
+ print("Using Simple AST Traversal Test Generation Strategy...")
202
208
  self.service.set_test_analysis_strategy(AST_STRAT)
203
209
 
204
210
  def docker_available(self) -> DockerClient | None:
205
211
  try:
206
212
  client = docker.from_env()
207
213
  client.ping()
208
- self.view.display_message("Docker daemon is running and connected.")
214
+ print("Docker daemon is running and connected.")
209
215
  return client
210
216
  except docker.errors.DockerException as err:
211
217
  print(f"Docker is not available: {err}")
@@ -98,6 +98,7 @@ class DockerController:
98
98
 
99
99
  if not args.generate_only:
100
100
  print("Running coverage...")
101
+ self.service.run_tests(test_file)
101
102
  self.service.run_coverage(test_file)
102
103
 
103
104
  # Add explicit return True here
@@ -0,0 +1,68 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import List, Tuple, Any
3
+
4
+ from testgen.models.function import Function
5
+
6
+
7
+ class Dao(ABC):
8
+ @abstractmethod
9
+ def insert_test_suite(self, name: str) -> int:
10
+ pass
11
+
12
+ @abstractmethod
13
+ def insert_source_file(self, path: str, lines_of_code: int, last_modified) -> int:
14
+ pass
15
+
16
+ @abstractmethod
17
+ def insert_function(self, name: str, params, start_line: int, end_line: int, source_file_id: int) -> int:
18
+ pass
19
+
20
+ @abstractmethod
21
+ def insert_test_case(self, test_case: Any, test_suite_id: int, function_id: int, test_method_type: int) -> int:
22
+ pass
23
+
24
+ @abstractmethod
25
+ def insert_test_result(self, test_case_id: int, status: bool, error: str = None) -> int:
26
+ pass
27
+
28
+ @abstractmethod
29
+ def insert_coverage_data(self, file_name: str, executed_lines: int, missed_lines: int,
30
+ branch_coverage: float, source_file_id: int, function_id: int = None) -> int:
31
+ pass
32
+
33
+ @abstractmethod
34
+ def get_test_suites(self) -> List[Any]:
35
+ pass
36
+
37
+ @abstractmethod
38
+ def get_test_cases_by_function(self, function_name: str) -> List[Any]:
39
+ pass
40
+
41
+ @abstractmethod
42
+ def get_source_file_id_by_path(self, filepath: str) -> int:
43
+ pass
44
+
45
+ @abstractmethod
46
+ def get_coverage_by_file(self, file_path: str) -> List[Any]:
47
+ pass
48
+
49
+ @abstractmethod
50
+ def get_test_file_data(self, file_path: str) -> List[Any]:
51
+ pass
52
+
53
+ @abstractmethod
54
+ def get_function_by_name_file_id_start(self, name: str, source_file_id: int, start_line: int)-> int:
55
+ pass
56
+
57
+ @abstractmethod
58
+ def get_functions_by_file(self, filepath: str) -> List[Function]:
59
+ pass
60
+
61
+ @abstractmethod
62
+ def get_test_suite_id_by_name(self, name: str) -> int:
63
+ pass
64
+
65
+ @abstractmethod
66
+ def get_test_case_id_by_func_id_input_expected(self, function_id: int, inputs: str, expected: str) -> int:
67
+ pass
68
+