testgenie-py 0.3.7__py3-none-any.whl → 0.3.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- testgen/analyzer/ast_analyzer.py +2 -11
- testgen/analyzer/fuzz_analyzer.py +1 -6
- testgen/analyzer/random_feedback_analyzer.py +20 -293
- testgen/analyzer/reinforcement_analyzer.py +59 -57
- testgen/analyzer/test_case_analyzer_context.py +0 -6
- testgen/controller/cli_controller.py +35 -29
- testgen/controller/docker_controller.py +1 -0
- testgen/db/dao.py +68 -0
- testgen/db/dao_impl.py +226 -0
- testgen/{sqlite → db}/db.py +15 -6
- testgen/generator/pytest_generator.py +2 -10
- testgen/generator/unit_test_generator.py +2 -11
- testgen/main.py +1 -3
- testgen/models/coverage_data.py +56 -0
- testgen/models/db_test_case.py +65 -0
- testgen/models/function.py +56 -0
- testgen/models/function_metadata.py +11 -1
- testgen/models/generator_context.py +32 -2
- testgen/models/source_file.py +29 -0
- testgen/models/test_result.py +38 -0
- testgen/models/test_suite.py +20 -0
- testgen/reinforcement/agent.py +1 -27
- testgen/reinforcement/environment.py +11 -93
- testgen/reinforcement/statement_coverage_state.py +5 -4
- testgen/service/analysis_service.py +31 -22
- testgen/service/cfg_service.py +3 -1
- testgen/service/coverage_service.py +115 -0
- testgen/service/db_service.py +140 -0
- testgen/service/generator_service.py +77 -20
- testgen/service/logging_service.py +2 -2
- testgen/service/service.py +62 -231
- testgen/service/test_executor_service.py +145 -0
- testgen/util/coverage_utils.py +38 -116
- testgen/util/coverage_visualizer.py +10 -9
- testgen/util/file_utils.py +10 -111
- testgen/util/randomizer.py +0 -26
- testgen/util/utils.py +197 -38
- {testgenie_py-0.3.7.dist-info → testgenie_py-0.3.8.dist-info}/METADATA +1 -1
- testgenie_py-0.3.8.dist-info/RECORD +72 -0
- testgen/inspector/inspector.py +0 -59
- testgen/presentation/__init__.py +0 -0
- testgen/presentation/cli_view.py +0 -12
- testgen/sqlite/__init__.py +0 -0
- testgen/sqlite/db_service.py +0 -239
- testgen/testgen.db +0 -0
- testgenie_py-0.3.7.dist-info/RECORD +0 -67
- /testgen/{inspector → db}/__init__.py +0 -0
- {testgenie_py-0.3.7.dist-info → testgenie_py-0.3.8.dist-info}/WHEEL +0 -0
- {testgenie_py-0.3.7.dist-info → testgenie_py-0.3.8.dist-info}/entry_points.txt +0 -0
testgen/analyzer/ast_analyzer.py
CHANGED
@@ -53,17 +53,8 @@ class ASTAnalyzer(TestCaseAnalyzerStrategy, ABC):
|
|
53
53
|
|
54
54
|
if not input_exists:
|
55
55
|
try:
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
if function_metadata.class_name:
|
60
|
-
cls = getattr(module, function_metadata.class_name)
|
61
|
-
instance = cls()
|
62
|
-
func = getattr(instance, func_name)
|
63
|
-
output = func(*inputs)
|
64
|
-
else:
|
65
|
-
func = getattr(module, func_name)
|
66
|
-
output = func(*inputs)
|
56
|
+
func = function_metadata.func
|
57
|
+
output = func(*inputs)
|
67
58
|
|
68
59
|
except Exception as e:
|
69
60
|
print(f"Error executing function: {e}")
|
@@ -29,13 +29,8 @@ class FuzzAnalyzer(TestCaseAnalyzerStrategy, ABC):
|
|
29
29
|
else:
|
30
30
|
raise ValueError("Module not set in function metadata. Cannot perform fuzzing without a module.")
|
31
31
|
|
32
|
-
class_name = function_metadata.class_name if function_metadata.class_name else None
|
33
32
|
try:
|
34
|
-
|
35
|
-
cls = getattr(module, class_name, None)
|
36
|
-
func = getattr(cls(), function_metadata.function_name, None) if cls else None
|
37
|
-
else:
|
38
|
-
func = getattr(module, function_metadata.function_name, None)
|
33
|
+
func = function_metadata.func
|
39
34
|
if func:
|
40
35
|
return self.run_fuzzing(func, function_metadata.function_name, function_metadata.params, module, 10)
|
41
36
|
except Exception as e:
|
@@ -1,13 +1,7 @@
|
|
1
|
-
import ast
|
2
|
-
import importlib
|
3
1
|
import random
|
4
2
|
import time
|
5
3
|
import traceback
|
6
4
|
from typing import List, Dict, Set
|
7
|
-
import z3
|
8
|
-
|
9
|
-
import testgen.util.randomizer
|
10
|
-
import testgen.util.utils as utils
|
11
5
|
import testgen.util.coverage_utils as coverage_utils
|
12
6
|
from testgen.analyzer.contracts.contract import Contract
|
13
7
|
from testgen.analyzer.contracts.no_exception_contract import NoExceptionContract
|
@@ -17,8 +11,6 @@ from testgen.analyzer.test_case_analyzer import TestCaseAnalyzerStrategy
|
|
17
11
|
from abc import ABC
|
18
12
|
|
19
13
|
from testgen.models.function_metadata import FunctionMetadata
|
20
|
-
from testgen.util.z3_utils.constraint_extractor import extract_branch_conditions
|
21
|
-
from testgen.util.z3_utils.ast_to_z3 import ast_to_z3_constraint
|
22
14
|
|
23
15
|
|
24
16
|
# Citation in which this method and algorithm were taken from:
|
@@ -42,26 +34,12 @@ class RandomFeedbackAnalyzer(TestCaseAnalyzerStrategy, ABC):
|
|
42
34
|
|
43
35
|
try:
|
44
36
|
param_values = self.generate_random_inputs(function_metadata.params)
|
45
|
-
module = self.analysis_context.module
|
46
37
|
func_name = function_metadata.function_name
|
38
|
+
function = function_metadata.func
|
47
39
|
|
48
|
-
|
49
|
-
cls = getattr(module, self._analysis_context.class_name)
|
50
|
-
obj = cls()
|
51
|
-
function = getattr(obj, func_name)
|
52
|
-
else:
|
53
|
-
function = getattr(module, func_name)
|
54
|
-
|
55
|
-
import inspect
|
56
|
-
sig = inspect.signature(function)
|
57
|
-
param_names = [p.name for p in sig.parameters.values() if p.name != 'self']
|
40
|
+
param_names = function_metadata.params.keys()
|
58
41
|
|
59
|
-
ordered_args = []
|
60
|
-
for name in param_names:
|
61
|
-
if name in param_values:
|
62
|
-
ordered_args.append(param_values[name])
|
63
|
-
else:
|
64
|
-
ordered_args.append(None)
|
42
|
+
ordered_args = [param_values.get(name, None) for name in param_names]
|
65
43
|
|
66
44
|
result = function(*ordered_args)
|
67
45
|
test_case = TestCase(func_name, tuple(ordered_args), result)
|
@@ -92,14 +70,16 @@ class RandomFeedbackAnalyzer(TestCaseAnalyzerStrategy, ABC):
|
|
92
70
|
if func.function_name not in self.covered_lines:
|
93
71
|
self.covered_lines[func.function_name] = set()
|
94
72
|
|
95
|
-
|
96
|
-
|
97
|
-
|
73
|
+
test_cases = [tc for tc in self.test_cases if tc.func_name == func.function_name]
|
74
|
+
|
75
|
+
for test_case in test_cases:
|
76
|
+
analysis = coverage_utils.get_coverage_analysis(self._analysis_context.filepath, func, test_case.inputs)
|
98
77
|
covered = coverage_utils.get_list_of_covered_statements(analysis)
|
99
78
|
self.covered_lines[func.function_name].update(covered)
|
100
79
|
self.logger.debug(f"Covered lines for {func.function_name}: {self.covered_lines[func.function_name]}")
|
101
80
|
|
102
|
-
|
81
|
+
|
82
|
+
executable_statements = set(coverage_utils.get_all_executable_statements(self._analysis_context.filepath, func, test_cases))
|
103
83
|
self.logger.debug(f"Executable statements for {func.function_name}: {executable_statements}")
|
104
84
|
|
105
85
|
return self.covered_lines[func.function_name] == executable_statements
|
@@ -107,34 +87,14 @@ class RandomFeedbackAnalyzer(TestCaseAnalyzerStrategy, ABC):
|
|
107
87
|
def execute_sequence(self, sequence, contracts: List[Contract]):
|
108
88
|
"""Execute a sequence and check contract violations"""
|
109
89
|
func_name, args_dict = sequence
|
110
|
-
args = tuple(args_dict.values()) # Convert dict values to tuple
|
111
90
|
|
112
91
|
try:
|
113
92
|
# Use module from analysis context if available
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
cls = getattr(module, self._analysis_context.class_name, None)
|
118
|
-
if cls is None:
|
119
|
-
raise AttributeError(f"Class '{self._analysis_context.class_name}' not found")
|
120
|
-
obj = cls() # Instantiate the class
|
121
|
-
func = getattr(obj, func_name, None)
|
122
|
-
|
123
|
-
import inspect
|
124
|
-
sig = inspect.signature(func)
|
125
|
-
param_names = [p.name for p in sig.parameters.values() if p.name != 'self']
|
126
|
-
else:
|
127
|
-
func = getattr(module, func_name, None)
|
128
|
-
|
129
|
-
import inspect
|
130
|
-
sig = inspect.signature(func)
|
131
|
-
param_names = [p.name for p in sig.parameters.values()]
|
93
|
+
function_metadata = self.get_function_metadata(func_name)
|
94
|
+
func = function_metadata.func
|
95
|
+
param_names = function_metadata.params.keys()
|
132
96
|
|
133
|
-
|
134
|
-
ordered_args = []
|
135
|
-
for name in param_names:
|
136
|
-
if name in args_dict:
|
137
|
-
ordered_args.append(args_dict[name])
|
97
|
+
ordered_args = [args_dict.get(name, None) for name in param_names]
|
138
98
|
|
139
99
|
# Check preconditions
|
140
100
|
for contract in contracts:
|
@@ -159,26 +119,17 @@ class RandomFeedbackAnalyzer(TestCaseAnalyzerStrategy, ABC):
|
|
159
119
|
return output, True
|
160
120
|
|
161
121
|
return output, False
|
162
|
-
|
122
|
+
|
123
|
+
def get_function_metadata(self, func_name: str) -> FunctionMetadata | None:
|
124
|
+
for function_data in self._analysis_context.function_data:
|
125
|
+
if function_data.function_name == func_name:
|
126
|
+
return function_data
|
127
|
+
return None
|
163
128
|
|
164
129
|
# TODO: Currently only getting random vals of primitives, extend to sequences
|
165
130
|
def random_seqs_and_vals(self, param_types, non_error_seqs=None):
|
166
131
|
return self.generate_random_inputs(param_types)
|
167
132
|
|
168
|
-
@staticmethod
|
169
|
-
def extract_parameter_types(func_node):
|
170
|
-
"""Extract parameter types from a function node."""
|
171
|
-
param_types = {}
|
172
|
-
for arg in func_node.args.args:
|
173
|
-
param_name = arg.arg
|
174
|
-
if arg.annotation:
|
175
|
-
param_type = ast.unparse(arg.annotation)
|
176
|
-
param_types[param_name] = param_type
|
177
|
-
else:
|
178
|
-
if param_name != 'self':
|
179
|
-
param_types[param_name] = None
|
180
|
-
return param_types
|
181
|
-
|
182
133
|
@staticmethod
|
183
134
|
def generate_random_inputs(param_types):
|
184
135
|
"""Generate inputs for fuzzing based on parameter types."""
|
@@ -294,228 +245,4 @@ class RandomFeedbackAnalyzer(TestCaseAnalyzerStrategy, ABC):
|
|
294
245
|
break
|
295
246
|
|
296
247
|
self.test_cases.sort(key=lambda tc: tc.func_name)
|
297
|
-
return error_seqs, non_error_seqs
|
298
|
-
|
299
|
-
def get_all_executable_statements(self, func: FunctionMetadata):
|
300
|
-
import ast
|
301
|
-
|
302
|
-
test_cases = [tc for tc in self.test_cases if tc.func_name == func.function_name]
|
303
|
-
|
304
|
-
if not test_cases:
|
305
|
-
print("Warning: No test cases available to determine executable statements")
|
306
|
-
from testgen.util.randomizer import new_random_test_case
|
307
|
-
temp_case = new_random_test_case(self._analysis_context.filepath, func.func_def)
|
308
|
-
analysis = coverage_utils.get_coverage_analysis(self._analysis_context.filepath, self._analysis_context.class_name, func.function_name,
|
309
|
-
temp_case.inputs)
|
310
|
-
else:
|
311
|
-
analysis = coverage_utils.get_coverage_analysis(self._analysis_context.filepath, self._analysis_context.class_name, func.function_name, test_cases[0].inputs)
|
312
|
-
|
313
|
-
executable_lines = list(analysis[1])
|
314
|
-
|
315
|
-
with open(self._analysis_context.filepath, 'r') as f:
|
316
|
-
source = f.read()
|
317
|
-
|
318
|
-
tree = ast.parse(source)
|
319
|
-
|
320
|
-
for node in ast.walk(tree):
|
321
|
-
if isinstance(node, ast.FunctionDef) and node.name == func.func_def.name:
|
322
|
-
for if_node in ast.walk(node):
|
323
|
-
if isinstance(if_node, ast.If) and if_node.orelse:
|
324
|
-
if isinstance(if_node.orelse[0], ast.If):
|
325
|
-
continue
|
326
|
-
else_line = if_node.orelse[0].lineno - 1
|
327
|
-
|
328
|
-
with open(self._analysis_context.filepath, 'r') as f:
|
329
|
-
lines = f.readlines()
|
330
|
-
if else_line <= len(lines):
|
331
|
-
line_content = lines[else_line - 1].strip()
|
332
|
-
if line_content == "else:":
|
333
|
-
if else_line not in executable_lines:
|
334
|
-
executable_lines.append(else_line)
|
335
|
-
|
336
|
-
return sorted(executable_lines)
|
337
|
-
|
338
|
-
"""
|
339
|
-
def collect_test_cases_with_z3(self, function_metadata: FunctionMetadata) -> List[TestCase]:
|
340
|
-
test_cases = []
|
341
|
-
|
342
|
-
z3_test_cases = self.generate_z3_test_cases(function_metadata)
|
343
|
-
if z3_test_cases:
|
344
|
-
test_cases.extend(z3_test_cases)
|
345
|
-
|
346
|
-
if not test_cases:
|
347
|
-
test_cases = self.generate_sequences_new()[1]
|
348
|
-
|
349
|
-
self.test_cases = test_cases
|
350
|
-
return test_cases
|
351
|
-
|
352
|
-
def generate_z3_test_cases(self, function_metadata: FunctionMetadata) -> List[TestCase]:
|
353
|
-
test_cases = []
|
354
|
-
|
355
|
-
branch_conditions, param_types = extract_branch_conditions(function_metadata.func_def)
|
356
|
-
|
357
|
-
if not branch_conditions:
|
358
|
-
random_inputs = self.generate_random_inputs(function_metadata.params)
|
359
|
-
try:
|
360
|
-
module = self.analysis_context.module
|
361
|
-
func_name = function_metadata.function_name
|
362
|
-
|
363
|
-
if self._analysis_context.class_name:
|
364
|
-
cls = getattr(module, self._analysis_context.class_name)
|
365
|
-
obj = cls()
|
366
|
-
func = getattr(obj, func_name)
|
367
|
-
ordered_args = self._order_arguments(func, random_inputs)
|
368
|
-
output = func(*ordered_args)
|
369
|
-
else:
|
370
|
-
func = getattr(module, func_name)
|
371
|
-
ordered_args = self._order_arguments(func, random_inputs)
|
372
|
-
output = func(*ordered_args)
|
373
|
-
|
374
|
-
test_cases.append(TestCase(func_name, tuple(ordered_args), output))
|
375
|
-
except Exception as e:
|
376
|
-
print(f"Error executing function with random inputs: {e}")
|
377
|
-
|
378
|
-
return test_cases
|
379
|
-
|
380
|
-
for branch_condition in branch_conditions:
|
381
|
-
try:
|
382
|
-
z3_expr, z3_vars = ast_to_z3_constraint(branch_condition, function_metadata.params)
|
383
|
-
|
384
|
-
solver = z3.Solver()
|
385
|
-
solver.add(z3_expr)
|
386
|
-
|
387
|
-
neg_solver = z3.Solver()
|
388
|
-
neg_solver.add(z3.Not(z3_expr))
|
389
|
-
|
390
|
-
for current_solver in [solver, neg_solver]:
|
391
|
-
if current_solver.check() == z3.sat:
|
392
|
-
model = current_solver.model()
|
393
|
-
|
394
|
-
param_values = self._extract_z3_solution(model, z3_vars, function_metadata.params)
|
395
|
-
|
396
|
-
ordered_params = self._order_parameters(function_metadata.func_def, param_values)
|
397
|
-
|
398
|
-
try:
|
399
|
-
module = self.analysis_context.module
|
400
|
-
func_name = function_metadata.function_name
|
401
|
-
|
402
|
-
if self._analysis_context.class_name:
|
403
|
-
cls = getattr(module, self._analysis_context.class_name)
|
404
|
-
obj = cls()
|
405
|
-
func = getattr(obj, func_name)
|
406
|
-
else:
|
407
|
-
func = getattr(module, func_name)
|
408
|
-
|
409
|
-
result = func(*ordered_params)
|
410
|
-
test_cases.append(TestCase(func_name, tuple(ordered_params), result))
|
411
|
-
except Exception as e:
|
412
|
-
print(f"Error executing function with Z3 solution: {e}")
|
413
|
-
self._add_random_test_case(function_metadata, test_cases)
|
414
|
-
else:
|
415
|
-
self._add_random_test_case(function_metadata, test_cases)
|
416
|
-
|
417
|
-
except Exception as e:
|
418
|
-
print(f"Error processing branch condition with Z3: {e}")
|
419
|
-
self._add_random_test_case(function_metadata, test_cases)
|
420
|
-
|
421
|
-
return test_cases
|
422
|
-
|
423
|
-
def _extract_z3_solution(self, model, z3_vars, param_types):
|
424
|
-
param_values = {}
|
425
|
-
|
426
|
-
for var_name, z3_var in z3_vars.items():
|
427
|
-
if var_name in param_types:
|
428
|
-
try:
|
429
|
-
model_value = model.evaluate(z3_var)
|
430
|
-
|
431
|
-
if param_types[var_name] == "int":
|
432
|
-
param_values[var_name] = model_value.as_long()
|
433
|
-
elif param_types[var_name] == "float":
|
434
|
-
param_values[var_name] = float(model_value.as_decimal(10))
|
435
|
-
elif param_types[var_name] == "bool":
|
436
|
-
param_values[var_name] = z3.is_true(model_value)
|
437
|
-
elif param_types[var_name] == "str":
|
438
|
-
str_val = str(model_value)
|
439
|
-
if str_val.startswith('"') and str_val.endswith('"'):
|
440
|
-
str_val = str_val[1:-1]
|
441
|
-
param_values[var_name] = str_val
|
442
|
-
else:
|
443
|
-
# Default to int for unrecognized types
|
444
|
-
param_values[var_name] = model_value.as_long()
|
445
|
-
except Exception as e:
|
446
|
-
print(f"Couldn't get {var_name} from model: {e}")
|
447
|
-
# Use default values for parameters not in the model
|
448
|
-
if param_types[var_name] == "int":
|
449
|
-
param_values[var_name] = 0
|
450
|
-
elif param_types[var_name] == "float":
|
451
|
-
param_values[var_name] = 0.0
|
452
|
-
elif param_types[var_name] == "bool":
|
453
|
-
param_values[var_name] = False
|
454
|
-
elif param_types[var_name] == "str":
|
455
|
-
param_values[var_name] = ""
|
456
|
-
else:
|
457
|
-
param_values[var_name] = None
|
458
|
-
|
459
|
-
return param_values
|
460
|
-
|
461
|
-
def _order_parameters(self, func_node, param_values):
|
462
|
-
ordered_params = []
|
463
|
-
|
464
|
-
for arg in func_node.args.args:
|
465
|
-
arg_name = arg.arg
|
466
|
-
if arg_name == 'self': # Skip self parameter
|
467
|
-
continue
|
468
|
-
if arg_name in param_values:
|
469
|
-
ordered_params.append(param_values[arg_name])
|
470
|
-
else:
|
471
|
-
# Default value handling if parameter not in solution
|
472
|
-
if arg.annotation and hasattr(arg.annotation, 'id'):
|
473
|
-
if arg.annotation.id == 'int':
|
474
|
-
ordered_params.append(0)
|
475
|
-
elif arg.annotation.id == 'float':
|
476
|
-
ordered_params.append(0.0)
|
477
|
-
elif arg.annotation.id == 'bool':
|
478
|
-
ordered_params.append(False)
|
479
|
-
elif arg.annotation.id == 'str':
|
480
|
-
ordered_params.append('')
|
481
|
-
else:
|
482
|
-
ordered_params.append(None)
|
483
|
-
else:
|
484
|
-
ordered_params.append(None)
|
485
|
-
|
486
|
-
return ordered_params
|
487
|
-
|
488
|
-
def _order_arguments(self, func, args_dict):
|
489
|
-
import inspect
|
490
|
-
sig = inspect.signature(func)
|
491
|
-
param_names = [p.name for p in sig.parameters.values() if p.name != 'self']
|
492
|
-
|
493
|
-
ordered_args = []
|
494
|
-
for name in param_names:
|
495
|
-
if name in args_dict:
|
496
|
-
ordered_args.append(args_dict[name])
|
497
|
-
else:
|
498
|
-
ordered_args.append(None) # Default to None if missing
|
499
|
-
|
500
|
-
return ordered_args
|
501
|
-
|
502
|
-
def _add_random_test_case(self, function_metadata, test_cases):
|
503
|
-
random_inputs = self.generate_random_inputs(function_metadata.params)
|
504
|
-
try:
|
505
|
-
module = self.analysis_context.module
|
506
|
-
func_name = function_metadata.function_name
|
507
|
-
|
508
|
-
if self._analysis_context.class_name:
|
509
|
-
cls = getattr(module, self._analysis_context.class_name)
|
510
|
-
obj = cls()
|
511
|
-
func = getattr(obj, func_name)
|
512
|
-
else:
|
513
|
-
func = getattr(module, func_name)
|
514
|
-
|
515
|
-
ordered_args = self._order_arguments(func, random_inputs)
|
516
|
-
|
517
|
-
output = func(*ordered_args)
|
518
|
-
test_cases.append(TestCase(func_name, tuple(ordered_args), output))
|
519
|
-
except Exception as e:
|
520
|
-
print(f"Error executing function with random inputs: {e}")
|
521
|
-
"""
|
248
|
+
return error_seqs, non_error_seqs
|
@@ -5,6 +5,7 @@ import random
|
|
5
5
|
from typing import List
|
6
6
|
|
7
7
|
import testgen.util.randomizer
|
8
|
+
from testgen.models.function_metadata import FunctionMetadata
|
8
9
|
from testgen.models.test_case import TestCase
|
9
10
|
from testgen.analyzer.test_case_analyzer import TestCaseAnalyzerStrategy
|
10
11
|
from testgen.reinforcement.environment import ReinforcementEnvironment
|
@@ -16,60 +17,61 @@ from testgen.reinforcement.environment import ReinforcementEnvironment
|
|
16
17
|
# Actions: Create new test case, combine test cases, delete test cases
|
17
18
|
# Rewards:
|
18
19
|
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
if
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
20
|
+
from typing import List, Optional
|
21
|
+
from testgen.models.test_case import TestCase
|
22
|
+
from testgen.models.analysis_context import AnalysisContext
|
23
|
+
from testgen.analyzer.test_case_analyzer import TestCaseAnalyzerStrategy
|
24
|
+
from testgen.reinforcement.agent import ReinforcementAgent
|
25
|
+
from testgen.reinforcement.environment import ReinforcementEnvironment
|
26
|
+
from testgen.reinforcement.statement_coverage_state import StatementCoverageState
|
27
|
+
|
28
|
+
|
29
|
+
class ReinforcementAnalyzer(TestCaseAnalyzerStrategy):
|
30
|
+
def __init__(self, analysis_context: AnalysisContext, mode: str = "train"):
|
31
|
+
super().__init__(analysis_context)
|
32
|
+
self.analysis_context = analysis_context
|
33
|
+
self.mode = mode
|
34
|
+
|
35
|
+
def collect_test_cases(self, function_metadata: FunctionMetadata):
|
36
|
+
# Implement or delegate as needed
|
37
|
+
return self.analyze(function_metadata)
|
38
|
+
|
39
|
+
def analyze(self, function_metadata: FunctionMetadata) -> List[TestCase]:
|
40
|
+
from testgen.service.analysis_service import AnalysisService
|
41
|
+
|
42
|
+
q_table = AnalysisService._load_q_table()
|
43
|
+
function_test_cases: List[TestCase] = []
|
44
|
+
|
45
|
+
|
46
|
+
environment = ReinforcementEnvironment(
|
47
|
+
self.analysis_context.filepath,
|
48
|
+
function_metadata,
|
49
|
+
function_test_cases,
|
50
|
+
state=StatementCoverageState(None)
|
51
|
+
)
|
52
|
+
environment.state = StatementCoverageState(environment)
|
53
|
+
agent = ReinforcementAgent(
|
54
|
+
self.analysis_context.filepath,
|
55
|
+
environment,
|
56
|
+
function_test_cases,
|
57
|
+
q_table
|
58
|
+
)
|
59
|
+
episodes = 10 if self.mode == "train" else 1
|
60
|
+
for _ in range(episodes):
|
61
|
+
if self.mode == "train":
|
62
|
+
new_test_cases = agent.do_q_learning()
|
63
|
+
else:
|
64
|
+
new_test_cases = agent.collect_test_cases()
|
65
|
+
function_test_cases.extend(new_test_cases)
|
66
|
+
|
67
|
+
seen = set()
|
68
|
+
unique_test_cases = []
|
69
|
+
for case in function_test_cases:
|
70
|
+
case_inputs = tuple(case.inputs) if isinstance(case.inputs, list) else case.inputs
|
71
|
+
case_key = (case.func_name, case_inputs)
|
72
|
+
if case_key not in seen:
|
73
|
+
seen.add(case_key)
|
74
|
+
unique_test_cases.append(case)
|
75
|
+
|
76
|
+
AnalysisService._save_q_table(q_table)
|
77
|
+
return unique_test_cases
|
@@ -11,12 +11,6 @@ class TestCaseAnalyzerContext:
|
|
11
11
|
self._test_case_analyzer = test_case_analyzer
|
12
12
|
self._analysis_context = analysis_context
|
13
13
|
self._test_cases = []
|
14
|
-
|
15
|
-
# TODO: GET RID OF THIS STUPID METHOD IT IS POINTLESS
|
16
|
-
# JUST CALL INSIDE ANALYZER_SERVICE
|
17
|
-
def do_logic(self) -> List[TestCase]:
|
18
|
-
"""Run the analysis process"""
|
19
|
-
self.do_strategy(20)
|
20
14
|
|
21
15
|
def do_strategy(self, time_limit=None) -> List[TestCase]:
|
22
16
|
"""Execute the analysis strategy for all functions with an optional time limit"""
|