testgenie-py 0.3.7__py3-none-any.whl → 0.3.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. testgen/analyzer/ast_analyzer.py +2 -11
  2. testgen/analyzer/fuzz_analyzer.py +1 -6
  3. testgen/analyzer/random_feedback_analyzer.py +20 -293
  4. testgen/analyzer/reinforcement_analyzer.py +59 -57
  5. testgen/analyzer/test_case_analyzer_context.py +0 -6
  6. testgen/controller/cli_controller.py +35 -29
  7. testgen/controller/docker_controller.py +1 -0
  8. testgen/db/dao.py +68 -0
  9. testgen/db/dao_impl.py +226 -0
  10. testgen/{sqlite → db}/db.py +15 -6
  11. testgen/generator/pytest_generator.py +2 -10
  12. testgen/generator/unit_test_generator.py +2 -11
  13. testgen/main.py +1 -3
  14. testgen/models/coverage_data.py +56 -0
  15. testgen/models/db_test_case.py +65 -0
  16. testgen/models/function.py +56 -0
  17. testgen/models/function_metadata.py +11 -1
  18. testgen/models/generator_context.py +30 -3
  19. testgen/models/source_file.py +29 -0
  20. testgen/models/test_result.py +38 -0
  21. testgen/models/test_suite.py +20 -0
  22. testgen/reinforcement/agent.py +1 -27
  23. testgen/reinforcement/environment.py +11 -93
  24. testgen/reinforcement/statement_coverage_state.py +5 -4
  25. testgen/service/analysis_service.py +31 -22
  26. testgen/service/cfg_service.py +3 -1
  27. testgen/service/coverage_service.py +115 -0
  28. testgen/service/db_service.py +140 -0
  29. testgen/service/generator_service.py +77 -20
  30. testgen/service/logging_service.py +2 -2
  31. testgen/service/service.py +62 -231
  32. testgen/service/test_executor_service.py +145 -0
  33. testgen/util/coverage_utils.py +38 -116
  34. testgen/util/coverage_visualizer.py +10 -9
  35. testgen/util/file_utils.py +10 -111
  36. testgen/util/randomizer.py +0 -26
  37. testgen/util/utils.py +197 -38
  38. {testgenie_py-0.3.7.dist-info → testgenie_py-0.3.9.dist-info}/METADATA +1 -1
  39. testgenie_py-0.3.9.dist-info/RECORD +72 -0
  40. testgen/inspector/inspector.py +0 -59
  41. testgen/presentation/__init__.py +0 -0
  42. testgen/presentation/cli_view.py +0 -12
  43. testgen/sqlite/__init__.py +0 -0
  44. testgen/sqlite/db_service.py +0 -239
  45. testgen/testgen.db +0 -0
  46. testgenie_py-0.3.7.dist-info/RECORD +0 -67
  47. /testgen/{inspector → db}/__init__.py +0 -0
  48. {testgenie_py-0.3.7.dist-info → testgenie_py-0.3.9.dist-info}/WHEEL +0 -0
  49. {testgenie_py-0.3.7.dist-info → testgenie_py-0.3.9.dist-info}/entry_points.txt +0 -0
@@ -53,17 +53,8 @@ class ASTAnalyzer(TestCaseAnalyzerStrategy, ABC):
53
53
 
54
54
  if not input_exists:
55
55
  try:
56
- generated_file = function_metadata.filename
57
- module = testgen.util.file_utils.load_module(generated_file)
58
-
59
- if function_metadata.class_name:
60
- cls = getattr(module, function_metadata.class_name)
61
- instance = cls()
62
- func = getattr(instance, func_name)
63
- output = func(*inputs)
64
- else:
65
- func = getattr(module, func_name)
66
- output = func(*inputs)
56
+ func = function_metadata.func
57
+ output = func(*inputs)
67
58
 
68
59
  except Exception as e:
69
60
  print(f"Error executing function: {e}")
@@ -29,13 +29,8 @@ class FuzzAnalyzer(TestCaseAnalyzerStrategy, ABC):
29
29
  else:
30
30
  raise ValueError("Module not set in function metadata. Cannot perform fuzzing without a module.")
31
31
 
32
- class_name = function_metadata.class_name if function_metadata.class_name else None
33
32
  try:
34
- if not class_name is None:
35
- cls = getattr(module, class_name, None)
36
- func = getattr(cls(), function_metadata.function_name, None) if cls else None
37
- else:
38
- func = getattr(module, function_metadata.function_name, None)
33
+ func = function_metadata.func
39
34
  if func:
40
35
  return self.run_fuzzing(func, function_metadata.function_name, function_metadata.params, module, 10)
41
36
  except Exception as e:
@@ -1,13 +1,7 @@
1
- import ast
2
- import importlib
3
1
  import random
4
2
  import time
5
3
  import traceback
6
4
  from typing import List, Dict, Set
7
- import z3
8
-
9
- import testgen.util.randomizer
10
- import testgen.util.utils as utils
11
5
  import testgen.util.coverage_utils as coverage_utils
12
6
  from testgen.analyzer.contracts.contract import Contract
13
7
  from testgen.analyzer.contracts.no_exception_contract import NoExceptionContract
@@ -17,8 +11,6 @@ from testgen.analyzer.test_case_analyzer import TestCaseAnalyzerStrategy
17
11
  from abc import ABC
18
12
 
19
13
  from testgen.models.function_metadata import FunctionMetadata
20
- from testgen.util.z3_utils.constraint_extractor import extract_branch_conditions
21
- from testgen.util.z3_utils.ast_to_z3 import ast_to_z3_constraint
22
14
 
23
15
 
24
16
  # Citation in which this method and algorithm were taken from:
@@ -42,26 +34,12 @@ class RandomFeedbackAnalyzer(TestCaseAnalyzerStrategy, ABC):
42
34
 
43
35
  try:
44
36
  param_values = self.generate_random_inputs(function_metadata.params)
45
- module = self.analysis_context.module
46
37
  func_name = function_metadata.function_name
38
+ function = function_metadata.func
47
39
 
48
- if self._analysis_context.class_name:
49
- cls = getattr(module, self._analysis_context.class_name)
50
- obj = cls()
51
- function = getattr(obj, func_name)
52
- else:
53
- function = getattr(module, func_name)
54
-
55
- import inspect
56
- sig = inspect.signature(function)
57
- param_names = [p.name for p in sig.parameters.values() if p.name != 'self']
40
+ param_names = function_metadata.params.keys()
58
41
 
59
- ordered_args = []
60
- for name in param_names:
61
- if name in param_values:
62
- ordered_args.append(param_values[name])
63
- else:
64
- ordered_args.append(None)
42
+ ordered_args = [param_values.get(name, None) for name in param_names]
65
43
 
66
44
  result = function(*ordered_args)
67
45
  test_case = TestCase(func_name, tuple(ordered_args), result)
@@ -92,14 +70,16 @@ class RandomFeedbackAnalyzer(TestCaseAnalyzerStrategy, ABC):
92
70
  if func.function_name not in self.covered_lines:
93
71
  self.covered_lines[func.function_name] = set()
94
72
 
95
- for test_case in [tc for tc in self.test_cases if tc.func_name == func.function_name]:
96
- analysis = coverage_utils.get_coverage_analysis(self._analysis_context.filepath, self._analysis_context.class_name,
97
- func.function_name, test_case.inputs)
73
+ test_cases = [tc for tc in self.test_cases if tc.func_name == func.function_name]
74
+
75
+ for test_case in test_cases:
76
+ analysis = coverage_utils.get_coverage_analysis(self._analysis_context.filepath, func, test_case.inputs)
98
77
  covered = coverage_utils.get_list_of_covered_statements(analysis)
99
78
  self.covered_lines[func.function_name].update(covered)
100
79
  self.logger.debug(f"Covered lines for {func.function_name}: {self.covered_lines[func.function_name]}")
101
80
 
102
- executable_statements = set(self.get_all_executable_statements(func))
81
+
82
+ executable_statements = set(coverage_utils.get_all_executable_statements(self._analysis_context.filepath, func, test_cases))
103
83
  self.logger.debug(f"Executable statements for {func.function_name}: {executable_statements}")
104
84
 
105
85
  return self.covered_lines[func.function_name] == executable_statements
@@ -107,34 +87,14 @@ class RandomFeedbackAnalyzer(TestCaseAnalyzerStrategy, ABC):
107
87
  def execute_sequence(self, sequence, contracts: List[Contract]):
108
88
  """Execute a sequence and check contract violations"""
109
89
  func_name, args_dict = sequence
110
- args = tuple(args_dict.values()) # Convert dict values to tuple
111
90
 
112
91
  try:
113
92
  # Use module from analysis context if available
114
- module = self.analysis_context.module
115
-
116
- if self._analysis_context.class_name:
117
- cls = getattr(module, self._analysis_context.class_name, None)
118
- if cls is None:
119
- raise AttributeError(f"Class '{self._analysis_context.class_name}' not found")
120
- obj = cls() # Instantiate the class
121
- func = getattr(obj, func_name, None)
122
-
123
- import inspect
124
- sig = inspect.signature(func)
125
- param_names = [p.name for p in sig.parameters.values() if p.name != 'self']
126
- else:
127
- func = getattr(module, func_name, None)
128
-
129
- import inspect
130
- sig = inspect.signature(func)
131
- param_names = [p.name for p in sig.parameters.values()]
93
+ function_metadata = self.get_function_metadata(func_name)
94
+ func = function_metadata.func
95
+ param_names = function_metadata.params.keys()
132
96
 
133
- # Create ordered arguments based on function signature
134
- ordered_args = []
135
- for name in param_names:
136
- if name in args_dict:
137
- ordered_args.append(args_dict[name])
97
+ ordered_args = [args_dict.get(name, None) for name in param_names]
138
98
 
139
99
  # Check preconditions
140
100
  for contract in contracts:
@@ -159,26 +119,17 @@ class RandomFeedbackAnalyzer(TestCaseAnalyzerStrategy, ABC):
159
119
  return output, True
160
120
 
161
121
  return output, False
162
-
122
+
123
+ def get_function_metadata(self, func_name: str) -> FunctionMetadata | None:
124
+ for function_data in self._analysis_context.function_data:
125
+ if function_data.function_name == func_name:
126
+ return function_data
127
+ return None
163
128
 
164
129
  # TODO: Currently only getting random vals of primitives, extend to sequences
165
130
  def random_seqs_and_vals(self, param_types, non_error_seqs=None):
166
131
  return self.generate_random_inputs(param_types)
167
132
 
168
- @staticmethod
169
- def extract_parameter_types(func_node):
170
- """Extract parameter types from a function node."""
171
- param_types = {}
172
- for arg in func_node.args.args:
173
- param_name = arg.arg
174
- if arg.annotation:
175
- param_type = ast.unparse(arg.annotation)
176
- param_types[param_name] = param_type
177
- else:
178
- if param_name != 'self':
179
- param_types[param_name] = None
180
- return param_types
181
-
182
133
  @staticmethod
183
134
  def generate_random_inputs(param_types):
184
135
  """Generate inputs for fuzzing based on parameter types."""
@@ -294,228 +245,4 @@ class RandomFeedbackAnalyzer(TestCaseAnalyzerStrategy, ABC):
294
245
  break
295
246
 
296
247
  self.test_cases.sort(key=lambda tc: tc.func_name)
297
- return error_seqs, non_error_seqs
298
-
299
- def get_all_executable_statements(self, func: FunctionMetadata):
300
- import ast
301
-
302
- test_cases = [tc for tc in self.test_cases if tc.func_name == func.function_name]
303
-
304
- if not test_cases:
305
- print("Warning: No test cases available to determine executable statements")
306
- from testgen.util.randomizer import new_random_test_case
307
- temp_case = new_random_test_case(self._analysis_context.filepath, func.func_def)
308
- analysis = coverage_utils.get_coverage_analysis(self._analysis_context.filepath, self._analysis_context.class_name, func.function_name,
309
- temp_case.inputs)
310
- else:
311
- analysis = coverage_utils.get_coverage_analysis(self._analysis_context.filepath, self._analysis_context.class_name, func.function_name, test_cases[0].inputs)
312
-
313
- executable_lines = list(analysis[1])
314
-
315
- with open(self._analysis_context.filepath, 'r') as f:
316
- source = f.read()
317
-
318
- tree = ast.parse(source)
319
-
320
- for node in ast.walk(tree):
321
- if isinstance(node, ast.FunctionDef) and node.name == func.func_def.name:
322
- for if_node in ast.walk(node):
323
- if isinstance(if_node, ast.If) and if_node.orelse:
324
- if isinstance(if_node.orelse[0], ast.If):
325
- continue
326
- else_line = if_node.orelse[0].lineno - 1
327
-
328
- with open(self._analysis_context.filepath, 'r') as f:
329
- lines = f.readlines()
330
- if else_line <= len(lines):
331
- line_content = lines[else_line - 1].strip()
332
- if line_content == "else:":
333
- if else_line not in executable_lines:
334
- executable_lines.append(else_line)
335
-
336
- return sorted(executable_lines)
337
-
338
- """
339
- def collect_test_cases_with_z3(self, function_metadata: FunctionMetadata) -> List[TestCase]:
340
- test_cases = []
341
-
342
- z3_test_cases = self.generate_z3_test_cases(function_metadata)
343
- if z3_test_cases:
344
- test_cases.extend(z3_test_cases)
345
-
346
- if not test_cases:
347
- test_cases = self.generate_sequences_new()[1]
348
-
349
- self.test_cases = test_cases
350
- return test_cases
351
-
352
- def generate_z3_test_cases(self, function_metadata: FunctionMetadata) -> List[TestCase]:
353
- test_cases = []
354
-
355
- branch_conditions, param_types = extract_branch_conditions(function_metadata.func_def)
356
-
357
- if not branch_conditions:
358
- random_inputs = self.generate_random_inputs(function_metadata.params)
359
- try:
360
- module = self.analysis_context.module
361
- func_name = function_metadata.function_name
362
-
363
- if self._analysis_context.class_name:
364
- cls = getattr(module, self._analysis_context.class_name)
365
- obj = cls()
366
- func = getattr(obj, func_name)
367
- ordered_args = self._order_arguments(func, random_inputs)
368
- output = func(*ordered_args)
369
- else:
370
- func = getattr(module, func_name)
371
- ordered_args = self._order_arguments(func, random_inputs)
372
- output = func(*ordered_args)
373
-
374
- test_cases.append(TestCase(func_name, tuple(ordered_args), output))
375
- except Exception as e:
376
- print(f"Error executing function with random inputs: {e}")
377
-
378
- return test_cases
379
-
380
- for branch_condition in branch_conditions:
381
- try:
382
- z3_expr, z3_vars = ast_to_z3_constraint(branch_condition, function_metadata.params)
383
-
384
- solver = z3.Solver()
385
- solver.add(z3_expr)
386
-
387
- neg_solver = z3.Solver()
388
- neg_solver.add(z3.Not(z3_expr))
389
-
390
- for current_solver in [solver, neg_solver]:
391
- if current_solver.check() == z3.sat:
392
- model = current_solver.model()
393
-
394
- param_values = self._extract_z3_solution(model, z3_vars, function_metadata.params)
395
-
396
- ordered_params = self._order_parameters(function_metadata.func_def, param_values)
397
-
398
- try:
399
- module = self.analysis_context.module
400
- func_name = function_metadata.function_name
401
-
402
- if self._analysis_context.class_name:
403
- cls = getattr(module, self._analysis_context.class_name)
404
- obj = cls()
405
- func = getattr(obj, func_name)
406
- else:
407
- func = getattr(module, func_name)
408
-
409
- result = func(*ordered_params)
410
- test_cases.append(TestCase(func_name, tuple(ordered_params), result))
411
- except Exception as e:
412
- print(f"Error executing function with Z3 solution: {e}")
413
- self._add_random_test_case(function_metadata, test_cases)
414
- else:
415
- self._add_random_test_case(function_metadata, test_cases)
416
-
417
- except Exception as e:
418
- print(f"Error processing branch condition with Z3: {e}")
419
- self._add_random_test_case(function_metadata, test_cases)
420
-
421
- return test_cases
422
-
423
- def _extract_z3_solution(self, model, z3_vars, param_types):
424
- param_values = {}
425
-
426
- for var_name, z3_var in z3_vars.items():
427
- if var_name in param_types:
428
- try:
429
- model_value = model.evaluate(z3_var)
430
-
431
- if param_types[var_name] == "int":
432
- param_values[var_name] = model_value.as_long()
433
- elif param_types[var_name] == "float":
434
- param_values[var_name] = float(model_value.as_decimal(10))
435
- elif param_types[var_name] == "bool":
436
- param_values[var_name] = z3.is_true(model_value)
437
- elif param_types[var_name] == "str":
438
- str_val = str(model_value)
439
- if str_val.startswith('"') and str_val.endswith('"'):
440
- str_val = str_val[1:-1]
441
- param_values[var_name] = str_val
442
- else:
443
- # Default to int for unrecognized types
444
- param_values[var_name] = model_value.as_long()
445
- except Exception as e:
446
- print(f"Couldn't get {var_name} from model: {e}")
447
- # Use default values for parameters not in the model
448
- if param_types[var_name] == "int":
449
- param_values[var_name] = 0
450
- elif param_types[var_name] == "float":
451
- param_values[var_name] = 0.0
452
- elif param_types[var_name] == "bool":
453
- param_values[var_name] = False
454
- elif param_types[var_name] == "str":
455
- param_values[var_name] = ""
456
- else:
457
- param_values[var_name] = None
458
-
459
- return param_values
460
-
461
- def _order_parameters(self, func_node, param_values):
462
- ordered_params = []
463
-
464
- for arg in func_node.args.args:
465
- arg_name = arg.arg
466
- if arg_name == 'self': # Skip self parameter
467
- continue
468
- if arg_name in param_values:
469
- ordered_params.append(param_values[arg_name])
470
- else:
471
- # Default value handling if parameter not in solution
472
- if arg.annotation and hasattr(arg.annotation, 'id'):
473
- if arg.annotation.id == 'int':
474
- ordered_params.append(0)
475
- elif arg.annotation.id == 'float':
476
- ordered_params.append(0.0)
477
- elif arg.annotation.id == 'bool':
478
- ordered_params.append(False)
479
- elif arg.annotation.id == 'str':
480
- ordered_params.append('')
481
- else:
482
- ordered_params.append(None)
483
- else:
484
- ordered_params.append(None)
485
-
486
- return ordered_params
487
-
488
- def _order_arguments(self, func, args_dict):
489
- import inspect
490
- sig = inspect.signature(func)
491
- param_names = [p.name for p in sig.parameters.values() if p.name != 'self']
492
-
493
- ordered_args = []
494
- for name in param_names:
495
- if name in args_dict:
496
- ordered_args.append(args_dict[name])
497
- else:
498
- ordered_args.append(None) # Default to None if missing
499
-
500
- return ordered_args
501
-
502
- def _add_random_test_case(self, function_metadata, test_cases):
503
- random_inputs = self.generate_random_inputs(function_metadata.params)
504
- try:
505
- module = self.analysis_context.module
506
- func_name = function_metadata.function_name
507
-
508
- if self._analysis_context.class_name:
509
- cls = getattr(module, self._analysis_context.class_name)
510
- obj = cls()
511
- func = getattr(obj, func_name)
512
- else:
513
- func = getattr(module, func_name)
514
-
515
- ordered_args = self._order_arguments(func, random_inputs)
516
-
517
- output = func(*ordered_args)
518
- test_cases.append(TestCase(func_name, tuple(ordered_args), output))
519
- except Exception as e:
520
- print(f"Error executing function with random inputs: {e}")
521
- """
248
+ return error_seqs, non_error_seqs
@@ -5,6 +5,7 @@ import random
5
5
  from typing import List
6
6
 
7
7
  import testgen.util.randomizer
8
+ from testgen.models.function_metadata import FunctionMetadata
8
9
  from testgen.models.test_case import TestCase
9
10
  from testgen.analyzer.test_case_analyzer import TestCaseAnalyzerStrategy
10
11
  from testgen.reinforcement.environment import ReinforcementEnvironment
@@ -16,60 +17,61 @@ from testgen.reinforcement.environment import ReinforcementEnvironment
16
17
  # Actions: Create new test case, combine test cases, delete test cases
17
18
  # Rewards:
18
19
 
19
- # Maybe consider the state as number of branches covered in a function possibly considering it the same state
20
-
21
- class ReinforcementAnalyzer(TestCaseAnalyzerStrategy, ABC):
22
- def __init__(self, module, class_name: str, env: ReinforcementEnvironment):
23
- super().__init__(module, class_name)
24
- self.env = env # includes file name, module/fut, coverage, and test cases
25
- self.q_table = {} # Dictionary of key: state, action pairs and value: q-value
26
- self.actions = ["add", "merge", "remove"] # three possible actions
27
-
28
- def collect_test_cases(self, func_node: ast.FunctionDef) -> List[TestCase]:
29
- self.env.test_cases.append(testgen.util.randomizer.new_random_test_case(f"{self._module.__name__}.py", func_node))
30
- return self.env.test_cases
31
-
32
- def set_env(self, env: ReinforcementEnvironment):
33
- self.env = env
34
-
35
- def refine_test_cases(self, func: ast.FunctionDef) -> List[TestCase]:
36
- state = self.env.get_state() # Should return state as Tuple(List[TestCase], coverage)
37
-
38
- if not isinstance(state, tuple) or len(state) != 2:
39
- raise ValueError(f"Expected state to be a tuple (test_cases, coverage_score), but got: {state}")
40
-
41
- action = self.choose_action(state)
42
- new_state, reward = self.env.step(action)
43
-
44
- if not isinstance(new_state, tuple) or len(new_state) != 2:
45
- raise ValueError(f"Expected new_state to be a tuple (test_cases, coverage_score), but got: {new_state}")
46
-
47
- print(f"AFTER NEW STATE, REWARD: {reward}")
48
-
49
- self.update_q_table(state, action, reward, new_state)
50
-
51
- return self.env.test_cases
52
-
53
- def choose_action(self, state):
54
- choice = random.choice(["EXPLORATION", "EXPLOITATION"])
55
- test_cases, coverage_score = state # Unpack state properly
56
- state_key = (tuple((tc.func_name, tc.inputs, tc.expected) for tc in test_cases), coverage_score)
57
-
58
- if choice == "EXPLORATION":
59
- return random.choice(self.actions)
60
- else:
61
- # Is going to try to pick the highest value in the q_table with the state_key and action pair
62
- # Is probably always going to be 0 unless we have the same exact test cases and coverage as represented in the state key
63
- return max(self.actions, key=lambda action: self.q_table.get((state_key, action), 0), default=random.choice(self.actions))
64
-
65
- def update_q_table(self, state, action, reward, next_state):
66
- test_cases, coverage_score = state
67
- next_test_cases, next_coverage_score = next_state
68
-
69
- state_key = (tuple((tc.func_name, tc.inputs, tc.expected) for tc in test_cases), coverage_score)
70
- next_state_key = (tuple((tc.func_name, tc.inputs, tc.expected) for tc in next_test_cases), next_coverage_score)
71
-
72
- print("HERE UPDATE TABLE")
73
- old_q = self.q_table.get((state_key, action), 0)
74
- future_q = max(self.q_table.get((next_state_key, a), 0) for a in self.actions)
75
- self.q_table[(state_key, action)] = old_q * (reward + future_q)
20
+ from typing import List, Optional
21
+ from testgen.models.test_case import TestCase
22
+ from testgen.models.analysis_context import AnalysisContext
23
+ from testgen.analyzer.test_case_analyzer import TestCaseAnalyzerStrategy
24
+ from testgen.reinforcement.agent import ReinforcementAgent
25
+ from testgen.reinforcement.environment import ReinforcementEnvironment
26
+ from testgen.reinforcement.statement_coverage_state import StatementCoverageState
27
+
28
+
29
+ class ReinforcementAnalyzer(TestCaseAnalyzerStrategy):
30
+ def __init__(self, analysis_context: AnalysisContext, mode: str = "train"):
31
+ super().__init__(analysis_context)
32
+ self.analysis_context = analysis_context
33
+ self.mode = mode
34
+
35
+ def collect_test_cases(self, function_metadata: FunctionMetadata):
36
+ # Implement or delegate as needed
37
+ return self.analyze(function_metadata)
38
+
39
+ def analyze(self, function_metadata: FunctionMetadata) -> List[TestCase]:
40
+ from testgen.service.analysis_service import AnalysisService
41
+
42
+ q_table = AnalysisService._load_q_table()
43
+ function_test_cases: List[TestCase] = []
44
+
45
+
46
+ environment = ReinforcementEnvironment(
47
+ self.analysis_context.filepath,
48
+ function_metadata,
49
+ function_test_cases,
50
+ state=StatementCoverageState(None)
51
+ )
52
+ environment.state = StatementCoverageState(environment)
53
+ agent = ReinforcementAgent(
54
+ self.analysis_context.filepath,
55
+ environment,
56
+ function_test_cases,
57
+ q_table
58
+ )
59
+ episodes = 10 if self.mode == "train" else 1
60
+ for _ in range(episodes):
61
+ if self.mode == "train":
62
+ new_test_cases = agent.do_q_learning()
63
+ else:
64
+ new_test_cases = agent.collect_test_cases()
65
+ function_test_cases.extend(new_test_cases)
66
+
67
+ seen = set()
68
+ unique_test_cases = []
69
+ for case in function_test_cases:
70
+ case_inputs = tuple(case.inputs) if isinstance(case.inputs, list) else case.inputs
71
+ case_key = (case.func_name, case_inputs)
72
+ if case_key not in seen:
73
+ seen.add(case_key)
74
+ unique_test_cases.append(case)
75
+
76
+ AnalysisService._save_q_table(q_table)
77
+ return unique_test_cases
@@ -11,12 +11,6 @@ class TestCaseAnalyzerContext:
11
11
  self._test_case_analyzer = test_case_analyzer
12
12
  self._analysis_context = analysis_context
13
13
  self._test_cases = []
14
-
15
- # TODO: GET RID OF THIS STUPID METHOD IT IS POINTLESS
16
- # JUST CALL INSIDE ANALYZER_SERVICE
17
- def do_logic(self) -> List[TestCase]:
18
- """Run the analysis process"""
19
- self.do_strategy(20)
20
14
 
21
15
  def do_strategy(self, time_limit=None) -> List[TestCase]:
22
16
  """Execute the analysis strategy for all functions with an optional time limit"""