testgenie-py 0.3.7__py3-none-any.whl → 0.3.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- testgen/analyzer/ast_analyzer.py +2 -11
 - testgen/analyzer/fuzz_analyzer.py +1 -6
 - testgen/analyzer/random_feedback_analyzer.py +20 -293
 - testgen/analyzer/reinforcement_analyzer.py +59 -57
 - testgen/analyzer/test_case_analyzer_context.py +0 -6
 - testgen/controller/cli_controller.py +35 -29
 - testgen/controller/docker_controller.py +1 -0
 - testgen/db/dao.py +68 -0
 - testgen/db/dao_impl.py +226 -0
 - testgen/{sqlite → db}/db.py +15 -6
 - testgen/generator/pytest_generator.py +2 -10
 - testgen/generator/unit_test_generator.py +2 -11
 - testgen/main.py +1 -3
 - testgen/models/coverage_data.py +56 -0
 - testgen/models/db_test_case.py +65 -0
 - testgen/models/function.py +56 -0
 - testgen/models/function_metadata.py +11 -1
 - testgen/models/generator_context.py +30 -3
 - testgen/models/source_file.py +29 -0
 - testgen/models/test_result.py +38 -0
 - testgen/models/test_suite.py +20 -0
 - testgen/reinforcement/agent.py +1 -27
 - testgen/reinforcement/environment.py +11 -93
 - testgen/reinforcement/statement_coverage_state.py +5 -4
 - testgen/service/analysis_service.py +31 -22
 - testgen/service/cfg_service.py +3 -1
 - testgen/service/coverage_service.py +115 -0
 - testgen/service/db_service.py +140 -0
 - testgen/service/generator_service.py +77 -20
 - testgen/service/logging_service.py +2 -2
 - testgen/service/service.py +62 -231
 - testgen/service/test_executor_service.py +145 -0
 - testgen/util/coverage_utils.py +38 -116
 - testgen/util/coverage_visualizer.py +10 -9
 - testgen/util/file_utils.py +10 -111
 - testgen/util/randomizer.py +0 -26
 - testgen/util/utils.py +197 -38
 - {testgenie_py-0.3.7.dist-info → testgenie_py-0.3.9.dist-info}/METADATA +1 -1
 - testgenie_py-0.3.9.dist-info/RECORD +72 -0
 - testgen/inspector/inspector.py +0 -59
 - testgen/presentation/__init__.py +0 -0
 - testgen/presentation/cli_view.py +0 -12
 - testgen/sqlite/__init__.py +0 -0
 - testgen/sqlite/db_service.py +0 -239
 - testgen/testgen.db +0 -0
 - testgenie_py-0.3.7.dist-info/RECORD +0 -67
 - /testgen/{inspector → db}/__init__.py +0 -0
 - {testgenie_py-0.3.7.dist-info → testgenie_py-0.3.9.dist-info}/WHEEL +0 -0
 - {testgenie_py-0.3.7.dist-info → testgenie_py-0.3.9.dist-info}/entry_points.txt +0 -0
 
| 
         @@ -0,0 +1,65 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            class DBTestCase:
                """Model object for a generated test case (presumably one database row).

                Every constructor argument is kept on a private ``_<name>`` attribute and
                exposed through a read/write property of the same public name. No
                validation or type conversion happens on construction or assignment.
                """

                def __init__(self, expected_output, inputs, test_function: str, last_run_time, test_method_type: int, test_suite_id: int, function_id: int):
                    """Store the test-case fields exactly as given.

                    Args:
                        expected_output: Expected result of the test; stored unchanged.
                        inputs: Inputs for the test; container type is not fixed here
                            (TODO confirm against callers).
                        test_function: Name of the test function.
                        last_run_time: When the test last ran; type not fixed here.
                        test_method_type: Integer code for the test-method kind (the
                            meaning of each value is defined elsewhere).
                        test_suite_id: Id of the owning test suite (presumably a
                            foreign key — verify against the DB schema).
                        function_id: Id of the function under test (presumably a
                            foreign key — verify against the DB schema).
                    """
                    self._expected_output = expected_output
                    self._inputs = inputs
                    self._test_function = test_function
                    self._last_run_time = last_run_time
                    self._test_method_type = test_method_type
                    self._test_suite_id = test_suite_id
                    self._function_id = function_id

                def __repr__(self) -> str:
                    """Return a debug representation listing every stored field."""
                    return (
                        f"{type(self).__name__}("
                        f"expected_output={self._expected_output!r}, "
                        f"inputs={self._inputs!r}, "
                        f"test_function={self._test_function!r}, "
                        f"last_run_time={self._last_run_time!r}, "
                        f"test_method_type={self._test_method_type!r}, "
                        f"test_suite_id={self._test_suite_id!r}, "
                        f"function_id={self._function_id!r})"
                    )

                @property
                def expected_output(self):
                    return self._expected_output

                @expected_output.setter
                def expected_output(self, value):
                    self._expected_output = value

                @property
                def inputs(self):
                    return self._inputs

                @inputs.setter
                def inputs(self, value):
                    self._inputs = value

                @property
                def test_function(self) -> str:
                    return self._test_function

                @test_function.setter
                def test_function(self, value: str) -> None:
                    self._test_function = value

                @property
                def last_run_time(self):
                    return self._last_run_time

                @last_run_time.setter
                def last_run_time(self, value) -> None:
                    self._last_run_time = value

                @property
                def test_method_type(self) -> int:
                    return self._test_method_type

                @test_method_type.setter
                def test_method_type(self, value: int) -> None:
                    self._test_method_type = value

                @property
                def test_suite_id(self) -> int:
                    return self._test_suite_id

                @test_suite_id.setter
                def test_suite_id(self, value: int) -> None:
                    self._test_suite_id = value

                @property
                def function_id(self) -> int:
                    return self._function_id

                @function_id.setter
                def function_id(self, value: int) -> None:
                    self._function_id = value
         
     | 
| 
         @@ -0,0 +1,56 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            class Function:
                """Model object describing one function found in a source file.

                Every constructor argument is kept on a private ``_<name>`` attribute and
                exposed through a read/write property of the same public name. No
                validation or type conversion happens on construction or assignment.
                """

                def __init__(self, name: str, params, start_line: int, end_line: int, num_lines: int, source_file_id: int):
                    """Store the function fields exactly as given.

                    Args:
                        name: Function name.
                        params: Parameter description. The ``params`` property is annotated
                            ``str``, but nothing here enforces that (TODO confirm format).
                        start_line: First line of the function.
                        end_line: Last line of the function.
                        num_lines: Line count. NOTE: stored as given, not derived from
                            ``start_line``/``end_line``, so callers must keep them consistent.
                        source_file_id: Id of the containing source file (presumably a
                            foreign key — verify against the DB schema).
                    """
                    self._name = name
                    self._params = params
                    self._start_line = start_line
                    self._end_line = end_line
                    self._num_lines = num_lines
                    self._source_file_id = source_file_id

                def __repr__(self) -> str:
                    """Return a debug representation listing every stored field."""
                    return (
                        f"{type(self).__name__}("
                        f"name={self._name!r}, "
                        f"params={self._params!r}, "
                        f"start_line={self._start_line!r}, "
                        f"end_line={self._end_line!r}, "
                        f"num_lines={self._num_lines!r}, "
                        f"source_file_id={self._source_file_id!r})"
                    )

                @property
                def name(self) -> str:
                    return self._name

                @name.setter
                def name(self, value: str) -> None:
                    self._name = value

                @property
                def params(self) -> str:
                    return self._params

                @params.setter
                def params(self, value: str) -> None:
                    self._params = value

                @property
                def start_line(self) -> int:
                    return self._start_line

                @start_line.setter
                def start_line(self, value: int) -> None:
                    self._start_line = value

                @property
                def end_line(self) -> int:
                    return self._end_line

                @end_line.setter
                def end_line(self, value: int) -> None:
                    self._end_line = value

                @property
                def num_lines(self) -> int:
                    return self._num_lines

                @num_lines.setter
                def num_lines(self, value: int) -> None:
                    self._num_lines = value

                @property
                def source_file_id(self) -> int:
                    return self._source_file_id

                @source_file_id.setter
                def source_file_id(self, value: int) -> None:
                    self._source_file_id = value
         
     | 
| 
         @@ -1,12 +1,14 @@ 
     | 
|
| 
       1 
1 
     | 
    
         
             
            import ast
         
     | 
| 
       2 
2 
     | 
    
         
             
            from types import ModuleType
         
     | 
| 
      
 3 
     | 
    
         
            +
            from typing import Any
         
     | 
| 
       3 
4 
     | 
    
         | 
| 
       4 
5 
     | 
    
         | 
| 
       5 
6 
     | 
    
         
             
            class FunctionMetadata:
         
     | 
| 
       6 
     | 
    
         
            -
                def __init__(self, filename: str, module: ModuleType, class_name: str, function_name: str, func_def: ast.FunctionDef, params: dict):
         
     | 
| 
      
 7 
     | 
    
         
            +
                def __init__(self, filename: str, module: ModuleType, class_name: str, func: Any, function_name: str, func_def: ast.FunctionDef, params: dict):
         
     | 
| 
       7 
8 
     | 
    
         
             
                    self._filename: str = filename
         
     | 
| 
       8 
9 
     | 
    
         
             
                    self._module: ModuleType = module
         
     | 
| 
       9 
10 
     | 
    
         
             
                    self._class_name: str = class_name
         
     | 
| 
      
 11 
     | 
    
         
            +
                    self._func: Any = func
         
     | 
| 
       10 
12 
     | 
    
         
             
                    self._function_name: str = function_name
         
     | 
| 
       11 
13 
     | 
    
         
             
                    self._func_def: ast.FunctionDef = func_def
         
     | 
| 
       12 
14 
     | 
    
         
             
                    self._params: dict = params
         
     | 
| 
         @@ -34,6 +36,14 @@ class FunctionMetadata: 
     | 
|
| 
       34 
36 
     | 
    
         
             
                @class_name.setter
         
     | 
| 
       35 
37 
     | 
    
         
             
                def class_name(self, class_name: str):
         
     | 
| 
       36 
38 
     | 
    
         
             
                    self._class_name = class_name
         
     | 
| 
      
 39 
     | 
    
         
            +
             
     | 
| 
      
 40 
     | 
    
         
            +
                @property
         
     | 
| 
      
 41 
     | 
    
         
            +
                def func(self) -> Any:
         
     | 
| 
      
 42 
     | 
    
         
            +
                    return self._func
         
     | 
| 
      
 43 
     | 
    
         
            +
             
     | 
| 
      
 44 
     | 
    
         
            +
                @func.setter
         
     | 
| 
      
 45 
     | 
    
         
            +
                def func(self, func: Any):
         
     | 
| 
      
 46 
     | 
    
         
            +
                    self._func = func
         
     | 
| 
       37 
47 
     | 
    
         | 
| 
       38 
48 
     | 
    
         
             
                @property
         
     | 
| 
       39 
49 
     | 
    
         
             
                def function_name(self) -> str:
         
     | 
| 
         @@ -1,18 +1,21 @@ 
     | 
|
| 
       1 
1 
     | 
    
         
             
            from types import ModuleType
         
     | 
| 
       2 
2 
     | 
    
         
             
            from typing import List
         
     | 
| 
       3 
     | 
    
         
            -
             
     | 
| 
       4 
3 
     | 
    
         
             
            from testgen.models.test_case import TestCase
         
     | 
| 
       5 
4 
     | 
    
         | 
| 
       6 
5 
     | 
    
         | 
| 
       7 
6 
     | 
    
         
             
            class GeneratorContext:
         
     | 
| 
       8 
     | 
    
         
            -
                def __init__(self, filepath: str, filename: str, class_name:str | None, module: ModuleType, output_path: str, 
     | 
| 
      
 7 
     | 
    
         
            +
                def __init__(self, filepath: str, filename: str, class_name:str | None, module: ModuleType, output_path: str,
         
     | 
| 
      
 8 
     | 
    
         
            +
                             test_cases: List[TestCase], is_package: bool, package_name: str, import_path: str):
         
     | 
| 
       9 
9 
     | 
    
         
             
                    self._filepath: str = filepath
         
     | 
| 
       10 
10 
     | 
    
         
             
                    self._filename: str = filename
         
     | 
| 
       11 
11 
     | 
    
         
             
                    self._class_name: str = class_name
         
     | 
| 
       12 
12 
     | 
    
         
             
                    self._module: ModuleType = module
         
     | 
| 
       13 
13 
     | 
    
         
             
                    self._output_path: str = output_path
         
     | 
| 
       14 
14 
     | 
    
         
             
                    self._test_cases: List[TestCase] = test_cases
         
     | 
| 
       15 
     | 
    
         
            -
             
     | 
| 
      
 15 
     | 
    
         
            +
                    self._is_package: bool = is_package
         
     | 
| 
      
 16 
     | 
    
         
            +
                    self._package_name: str = package_name
         
     | 
| 
      
 17 
     | 
    
         
            +
                    self._import_path: str = import_path
         
     | 
| 
      
 18 
     | 
    
         
            +
             
     | 
| 
       16 
19 
     | 
    
         
             
                @property
         
     | 
| 
       17 
20 
     | 
    
         
             
                def filepath(self) -> str:
         
     | 
| 
       18 
21 
     | 
    
         
             
                    return self._filepath
         
     | 
| 
         @@ -61,3 +64,27 @@ class GeneratorContext: 
     | 
|
| 
       61 
64 
     | 
    
         
             
                def test_cases(self, value: List[TestCase]) -> None:
         
     | 
| 
       62 
65 
     | 
    
         
             
                    self._test_cases = value
         
     | 
| 
       63 
66 
     | 
    
         | 
| 
      
 67 
     | 
    
         
            +
                @property
         
     | 
| 
      
 68 
     | 
    
         
            +
                def is_package(self) -> bool:
         
     | 
| 
      
 69 
     | 
    
         
            +
                    return self._is_package
         
     | 
| 
      
 70 
     | 
    
         
            +
             
     | 
| 
      
 71 
     | 
    
         
            +
                @is_package.setter
         
     | 
| 
      
 72 
     | 
    
         
            +
                def is_package(self, value: bool) -> None:
         
     | 
| 
      
 73 
     | 
    
         
            +
                    self._is_package = value
         
     | 
| 
      
 74 
     | 
    
         
            +
             
     | 
| 
      
 75 
     | 
    
         
            +
                @property
         
     | 
| 
      
 76 
     | 
    
         
            +
                def package_name(self) -> str:
         
     | 
| 
      
 77 
     | 
    
         
            +
                    return self._package_name
         
     | 
| 
      
 78 
     | 
    
         
            +
             
     | 
| 
      
 79 
     | 
    
         
            +
                @package_name.setter
         
     | 
| 
      
 80 
     | 
    
         
            +
                def package_name(self, value: str) -> None:
         
     | 
| 
      
 81 
     | 
    
         
            +
                    self._package_name = value
         
     | 
| 
      
 82 
     | 
    
         
            +
             
     | 
| 
      
 83 
     | 
    
         
            +
                @property
         
     | 
| 
      
 84 
     | 
    
         
            +
                def import_path(self) -> str:
         
     | 
| 
      
 85 
     | 
    
         
            +
                    return self._import_path
         
     | 
| 
      
 86 
     | 
    
         
            +
             
     | 
| 
      
 87 
     | 
    
         
            +
                @import_path.setter
         
     | 
| 
      
 88 
     | 
    
         
            +
                def import_path(self, value: str) -> None:
         
     | 
| 
      
 89 
     | 
    
         
            +
                    self._import_path = value
         
     | 
| 
      
 90 
     | 
    
         
            +
             
     | 
| 
         @@ -0,0 +1,29 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            class SourceFile:
                """Model object describing a source file.

                Every constructor argument is kept on a private ``_<name>`` attribute and
                exposed through a read/write property of the same public name. No
                validation or type conversion happens on construction or assignment.
                """

                def __init__(self, path: str, lines_of_code: int, last_modified):
                    """Store the source-file fields exactly as given.

                    Args:
                        path: Path of the file.
                        lines_of_code: Line count for the file.
                        last_modified: Last-modification marker; type not fixed here
                            (TODO confirm whether it is a timestamp or datetime).
                    """
                    self._path = path
                    self._lines_of_code = lines_of_code
                    self._last_modified = last_modified

                def __repr__(self) -> str:
                    """Return a debug representation listing every stored field."""
                    return (
                        f"{type(self).__name__}("
                        f"path={self._path!r}, "
                        f"lines_of_code={self._lines_of_code!r}, "
                        f"last_modified={self._last_modified!r})"
                    )

                @property
                def path(self) -> str:
                    return self._path

                @path.setter
                def path(self, value: str) -> None:
                    self._path = value

                @property
                def lines_of_code(self) -> int:
                    return self._lines_of_code

                @lines_of_code.setter
                def lines_of_code(self, value: int) -> None:
                    self._lines_of_code = value

                @property
                def last_modified(self):
                    return self._last_modified

                @last_modified.setter
                def last_modified(self, value) -> None:
                    self._last_modified = value
         
     | 
| 
         @@ -0,0 +1,38 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            class TestResult:
                """Model object for the outcome of running one test case.

                Every constructor argument is kept on a private ``_<name>`` attribute and
                exposed through a read/write property of the same public name. No
                validation or type conversion happens on construction or assignment.
                """

                def __init__(self, test_case_id: int, status: bool, error: str, execution_time):
                    """Store the result fields exactly as given.

                    Args:
                        test_case_id: Id of the test case this result belongs to
                            (presumably a foreign key — verify against the DB schema).
                        status: Outcome flag (presumably True means passed — confirm).
                        error: Error text for the run.
                        execution_time: Duration of the run; type and unit not fixed here.
                    """
                    self._test_case_id = test_case_id
                    self._status = status
                    self._error = error
                    self._execution_time = execution_time

                def __repr__(self) -> str:
                    """Return a debug representation listing every stored field."""
                    return (
                        f"{type(self).__name__}("
                        f"test_case_id={self._test_case_id!r}, "
                        f"status={self._status!r}, "
                        f"error={self._error!r}, "
                        f"execution_time={self._execution_time!r})"
                    )

                @property
                def test_case_id(self) -> int:
                    return self._test_case_id

                @test_case_id.setter
                def test_case_id(self, value: int) -> None:
                    self._test_case_id = value

                @property
                def status(self) -> bool:
                    return self._status

                @status.setter
                def status(self, value: bool) -> None:
                    self._status = value

                @property
                def error(self) -> str:
                    return self._error

                @error.setter
                def error(self, value: str) -> None:
                    self._error = value

                @property
                def execution_time(self):
                    return self._execution_time

                @execution_time.setter
                def execution_time(self, value) -> None:
                    self._execution_time = value
         
     | 
| 
         @@ -0,0 +1,20 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            class TestSuite:
                """Model object for a named test suite.

                Every constructor argument is kept on a private ``_<name>`` attribute and
                exposed through a read/write property of the same public name. No
                validation or type conversion happens on construction or assignment.
                """

                def __init__(self, name: str, creation_date):
                    """Store the suite fields exactly as given.

                    Args:
                        name: Suite name.
                        creation_date: When the suite was created; type not fixed here
                            (TODO confirm whether it is a string, timestamp or datetime).
                    """
                    self._name = name
                    self._creation_date = creation_date

                def __repr__(self) -> str:
                    """Return a debug representation listing every stored field."""
                    return (
                        f"{type(self).__name__}("
                        f"name={self._name!r}, "
                        f"creation_date={self._creation_date!r})"
                    )

                @property
                def name(self) -> str:
                    return self._name

                @name.setter
                def name(self, value: str) -> None:
                    self._name = value

                @property
                def creation_date(self):
                    return self._creation_date

                @creation_date.setter
                def creation_date(self, value) -> None:
                    self._creation_date = value
         
     | 
    
        testgen/reinforcement/agent.py
    CHANGED
    
    | 
         @@ -128,30 +128,4 @@ class ReinforcementAgent: 
     | 
|
| 
       128 
128 
     | 
    
         
             
                    print(f"UPDATING Q TABLE FOR STATE: {state}, ACTION: {action} WITH REWARD: {reward}")
         
     | 
| 
       129 
129 
     | 
    
         
             
                    new_q = (1 - self.learning_rate) * current_q + self.learning_rate * (reward + max_next_q)
         
     | 
| 
       130 
130 
     | 
    
         | 
| 
       131 
     | 
    
         
            -
                    self.q_table[(state, action)] = new_q
         
     | 
| 
       132 
     | 
    
         
            -
                    
         
     | 
| 
       133 
     | 
    
         
            -
                    """def optimize_test_suit(self, current_state, executable_statements):
         
     | 
| 
       134 
     | 
    
         
            -
                    # Try to optimize test cases by repeatedly performing remove actions if reached full coverage
         
     | 
| 
       135 
     | 
    
         
            -
                    test_case_count = current_state[1]
         
     | 
| 
       136 
     | 
    
         
            -
                    optimization_attempts = min(10, test_case_count - 1)
         
     | 
| 
       137 
     | 
    
         
            -
             
     | 
| 
       138 
     | 
    
         
            -
                    for _ in range(optimization_attempts):
         
     | 
| 
       139 
     | 
    
         
            -
                        if test_case_count <= 1:
         
     | 
| 
       140 
     | 
    
         
            -
                            break
         
     | 
| 
       141 
     | 
    
         
            -
             
     | 
| 
       142 
     | 
    
         
            -
                        action = "remove"
         
     | 
| 
       143 
     | 
    
         
            -
                        next_state, reward = self.env.step(action)
         
     | 
| 
       144 
     | 
    
         
            -
             
     | 
| 
       145 
     | 
    
         
            -
                        new_covered = next_state[0]
         
     | 
| 
       146 
     | 
    
         
            -
                        new_uncovered = [stmt for stmt in executable_statements if stmt not in new_covered]
         
     | 
| 
       147 
     | 
    
         
            -
             
     | 
| 
       148 
     | 
    
         
            -
                        if len(new_uncovered) == 0:
         
     | 
| 
       149 
     | 
    
         
            -
                            current_state = next_state
         
     | 
| 
       150 
     | 
    
         
            -
                            test_case_count = current_state[2]
         
     | 
| 
       151 
     | 
    
         
            -
                            print(f"Optimized to {test_case_count} test cases.")
         
     | 
| 
       152 
     | 
    
         
            -
                        else:
         
     | 
| 
       153 
     | 
    
         
            -
                            # Add a test case back if removing broke coverage
         
     | 
| 
       154 
     | 
    
         
            -
                            self.env.step("add")
         
     | 
| 
       155 
     | 
    
         
            -
                            break
         
     | 
| 
       156 
     | 
    
         
            -
             
     | 
| 
       157 
     | 
    
         
            -
                    return current_state"""
         
     | 
| 
      
 131 
     | 
    
         
            +
                    self.q_table[(state, action)] = new_q
         
     | 
| 
         @@ -4,6 +4,7 @@ from typing import List, Tuple 
     | 
|
| 
       4 
4 
     | 
    
         | 
| 
       5 
5 
     | 
    
         
             
            import coverage
         
     | 
| 
       6 
6 
     | 
    
         | 
| 
      
 7 
     | 
    
         
            +
            from testgen.models.function_metadata import FunctionMetadata
         
     | 
| 
       7 
8 
     | 
    
         
             
            from testgen.service.logging_service import get_logger
         
     | 
| 
       8 
9 
     | 
    
         
             
            import testgen.util.coverage_utils
         
     | 
| 
       9 
10 
     | 
    
         
             
            import testgen.util.file_utils
         
     | 
| 
         @@ -13,11 +14,9 @@ from testgen.models.test_case import TestCase 
     | 
|
| 
       13 
14 
     | 
    
         | 
| 
       14 
15 
     | 
    
         | 
| 
       15 
16 
     | 
    
         
             
            class ReinforcementEnvironment:
         
     | 
| 
       16 
     | 
    
         
            -
                def __init__(self,  
     | 
| 
       17 
     | 
    
         
            -
                    self. 
     | 
| 
       18 
     | 
    
         
            -
                    self. 
     | 
| 
       19 
     | 
    
         
            -
                    self.module = module
         
     | 
| 
       20 
     | 
    
         
            -
                    self.class_name  = class_name
         
     | 
| 
      
 17 
     | 
    
         
            +
                def __init__(self, filepath: str, function_data: FunctionMetadata, initial_test_cases: List[TestCase], state: AbstractState):
         
     | 
| 
      
 18 
     | 
    
         
            +
                    self.filepath = filepath
         
     | 
| 
      
 19 
     | 
    
         
            +
                    self.function_data = function_data
         
     | 
| 
       21 
20 
     | 
    
         
             
                    self.initial_test_cases = initial_test_cases
         
     | 
| 
       22 
21 
     | 
    
         
             
                    self.test_cases = initial_test_cases.copy()
         
     | 
| 
       23 
22 
     | 
    
         
             
                    self.state = state
         
     | 
| 
         @@ -35,13 +34,13 @@ class ReinforcementEnvironment: 
     | 
|
| 
       35 
34 
     | 
    
         | 
| 
       36 
35 
     | 
    
         
             
                    # Execute action
         
     | 
| 
       37 
36 
     | 
    
         
             
                    if action == "add":
         
     | 
| 
       38 
     | 
    
         
            -
                        self.test_cases.append(randomizer.new_random_test_case(self. 
     | 
| 
      
 37 
     | 
    
         
            +
                        self.test_cases.append(randomizer.new_random_test_case(self.filepath, self.function_data.class_name, self.function_data.func_def))
         
     | 
| 
       39 
38 
     | 
    
         
             
                    elif action == "merge" and len(self.test_cases) > 1:
         
     | 
| 
       40 
     | 
    
         
            -
                        self.test_cases.append(randomizer.combine_cases(self.module, self.test_cases))
         
     | 
| 
      
 39 
     | 
    
         
            +
                        self.test_cases.append(randomizer.combine_cases(self.function_data.module, self.test_cases))
         
     | 
| 
       41 
40 
     | 
    
         
             
                    elif action == "remove" and len(self.test_cases) > 1:
         
     | 
| 
       42 
41 
     | 
    
         
             
                        self.test_cases = randomizer.remove_case(self.test_cases)
         
     | 
| 
       43 
42 
     | 
    
         
             
                    elif action == "z3":
         
     | 
| 
       44 
     | 
    
         
            -
                        self.test_cases = randomizer.get_z3_test_cases(self. 
     | 
| 
      
 43 
     | 
    
         
            +
                        self.test_cases = randomizer.get_z3_test_cases(self.filepath, self.function_data.class_name, self.function_data.func_def, self.test_cases)
         
     | 
| 
       45 
44 
     | 
    
         
             
                    else:
         
     | 
| 
       46 
45 
     | 
    
         
             
                        raise ValueError("Invalid action")
         
     | 
| 
       47 
46 
     | 
    
         | 
| 
         @@ -91,81 +90,6 @@ class ReinforcementEnvironment: 
     | 
|
| 
       91 
90 
     | 
    
         | 
| 
       92 
91 
     | 
    
         
             
                    print(f"Final reward {reward}")
         
     | 
| 
       93 
92 
     | 
    
         
             
                    return reward
         
     | 
| 
       94 
     | 
    
         
            -
             
     | 
| 
       95 
     | 
    
         
            -
                
         
     | 
| 
       96 
     | 
    
         
            -
                def get_all_executable_statements(self):
         
     | 
| 
       97 
     | 
    
         
            -
                    """Get all executable statements including else branches"""
         
     | 
| 
       98 
     | 
    
         
            -
                    import ast
         
     | 
| 
       99 
     | 
    
         
            -
             
     | 
| 
       100 
     | 
    
         
            -
                    test_cases = [tc for tc in self.test_cases if tc.func_name == self.fut.name]
         
     | 
| 
       101 
     | 
    
         
            -
             
     | 
| 
       102 
     | 
    
         
            -
                    executable_lines = set()
         
     | 
| 
       103 
     | 
    
         
            -
                    if not test_cases:
         
     | 
| 
       104 
     | 
    
         
            -
                        self.logger.debug("Warning: No test cases available to determine executable statements")
         
     | 
| 
       105 
     | 
    
         
            -
                        from testgen.util.randomizer import new_random_test_case
         
     | 
| 
       106 
     | 
    
         
            -
                        temp_case = new_random_test_case(self.file_name, self.class_name, self.fut)
         
     | 
| 
       107 
     | 
    
         
            -
                        analysis = testgen.util.coverage_utils.get_coverage_analysis(self.file_name, self.class_name, self.fut.name, temp_case.inputs)
         
     | 
| 
       108 
     | 
    
         
            -
                        executable_lines.update(analysis[1])  # Add executable lines from coverage analysis
         
     | 
| 
       109 
     | 
    
         
            -
                    else:
         
     | 
| 
       110 
     | 
    
         
            -
                        analysis = testgen.util.coverage_utils.get_coverage_analysis(self.file_name, self.class_name, self.fut.name, test_cases[0].inputs)
         
     | 
| 
       111 
     | 
    
         
            -
                    
         
     | 
| 
       112 
     | 
    
         
            -
                    executable_lines.update(analysis[1])  # Add executable lines from coverage analysis
         
     | 
| 
       113 
     | 
    
         
            -
                    # Get standard executable lines from coverage.py
         
     | 
| 
       114 
     | 
    
         
            -
                    executable_lines = list(executable_lines)
         
     | 
| 
       115 
     | 
    
         
            -
                    
         
     | 
| 
       116 
     | 
    
         
            -
                    # Parse the source file to find else branches
         
     | 
| 
       117 
     | 
    
         
            -
                    with open(self.file_name, 'r') as f:
         
     | 
| 
       118 
     | 
    
         
            -
                        source = f.read()
         
     | 
| 
       119 
     | 
    
         
            -
                    
         
     | 
| 
       120 
     | 
    
         
            -
                    # Parse the code
         
     | 
| 
       121 
     | 
    
         
            -
                    tree = ast.parse(source)        
         
     | 
| 
       122 
     | 
    
         
            -
                    # Find our specific function
         
     | 
| 
       123 
     | 
    
         
            -
                    for node in ast.walk(tree):
         
     | 
| 
       124 
     | 
    
         
            -
                        if isinstance(node, ast.ClassDef) and node.name == self.class_name:
         
     | 
| 
       125 
     | 
    
         
            -
                            # If we have a class, find the method
         
     | 
| 
       126 
     | 
    
         
            -
                            for method in node.body:
         
     | 
| 
       127 
     | 
    
         
            -
                                if isinstance(method, ast.FunctionDef) and method.name == self.fut.name:
         
     | 
| 
       128 
     | 
    
         
            -
                                    # Find all if statements in this method
         
     | 
| 
       129 
     | 
    
         
            -
                                    for if_node in ast.walk(method):
         
     | 
| 
       130 
     | 
    
         
            -
                                        if isinstance(if_node, ast.If) and if_node.orelse:
         
     | 
| 
       131 
     | 
    
         
            -
                                            # There's an else branch
         
     | 
| 
       132 
     | 
    
         
            -
                                            if isinstance(if_node.orelse[0], ast.If):
         
     | 
| 
       133 
     | 
    
         
            -
                                                # This is an elif - already counted
         
     | 
| 
       134 
     | 
    
         
            -
                                                continue
         
     | 
| 
       135 
     | 
    
         
            -
                                            
         
     | 
| 
       136 
     | 
    
         
            -
                                            # Get the line number of the first statement in the else block
         
     | 
| 
       137 
     | 
    
         
            -
                                            # and subtract 1 to get the 'else:' line
         
     | 
| 
       138 
     | 
    
         
            -
                                            else_line = if_node.orelse[0].lineno - 1
         
     | 
| 
       139 
     | 
    
         
            -
                                            
         
     | 
| 
       140 
     | 
    
         
            -
                                            # Check if this is actually an else line (not a nested if)
         
     | 
| 
       141 
     | 
    
         
            -
                                            with open(self.file_name, 'r') as f:
         
     | 
| 
       142 
     | 
    
         
            -
                                                lines = f.readlines()
         
     | 
| 
       143 
     | 
    
         
            -
                                                if else_line <= len(lines):
         
     | 
| 
       144 
     | 
    
         
            -
                                                    line_content = lines[else_line - 1].strip()
         
     | 
| 
       145 
     | 
    
         
            -
                                                    if line_content == "else:":
         
     | 
| 
       146 
     | 
    
         
            -
                                                        if else_line not in executable_lines:
         
     | 
| 
       147 
     | 
    
         
            -
                                                            executable_lines.append(else_line)
         
     | 
| 
       148 
     | 
    
         
            -
                        if isinstance(node, ast.FunctionDef) and node.name == self.fut.name:
         
     | 
| 
       149 
     | 
    
         
            -
                            # Find all if statements in this function
         
     | 
| 
       150 
     | 
    
         
            -
                            for if_node in ast.walk(node):
         
     | 
| 
       151 
     | 
    
         
            -
                                if isinstance(if_node, ast.If) and if_node.orelse:
         
     | 
| 
       152 
     | 
    
         
            -
                                    # There's an else branch
         
     | 
| 
       153 
     | 
    
         
            -
                                    if isinstance(if_node.orelse[0], ast.If):
         
     | 
| 
       154 
     | 
    
         
            -
                                        # This is an elif - already counted
         
     | 
| 
       155 
     | 
    
         
            -
                                        continue
         
     | 
| 
       156 
     | 
    
         
            -
                                    
         
     | 
| 
       157 
     | 
    
         
            -
                                    # Get the line number of the first statement in the else block
         
     | 
| 
       158 
     | 
    
         
            -
                                    # and subtract 1 to get the 'else:' line
         
     | 
| 
       159 
     | 
    
         
            -
                                    else_line = if_node.orelse[0].lineno - 1
         
     | 
| 
       160 
     | 
    
         
            -
                                    
         
     | 
| 
       161 
     | 
    
         
            -
                                    # Check if this is actually an else line (not a nested if)
         
     | 
| 
       162 
     | 
    
         
            -
                                    with open(self.file_name, 'r') as f:
         
     | 
| 
       163 
     | 
    
         
            -
                                        lines = f.readlines()
         
     | 
| 
       164 
     | 
    
         
            -
                                        if else_line <= len(lines):
         
     | 
| 
       165 
     | 
    
         
            -
                                            line_content = lines[else_line - 1].strip()
         
     | 
| 
       166 
     | 
    
         
            -
                                            if line_content == "else:":
         
     | 
| 
       167 
     | 
    
         
            -
                                                if else_line not in executable_lines:
         
     | 
| 
       168 
     | 
    
         
            -
                                                    executable_lines.append(else_line)
         
     | 
| 
       169 
93 
     | 
    
         | 
| 
       170 
94 
     | 
    
         
             
                    return sorted(executable_lines)
         
     | 
| 
       171 
95 
     | 
    
         | 
| 
         @@ -180,13 +104,7 @@ class ReinforcementEnvironment: 
     | 
|
| 
       180 
104 
     | 
    
         
             
                    # Execute all test cases
         
     | 
| 
       181 
105 
     | 
    
         
             
                    for test_case in self.test_cases:
         
     | 
| 
       182 
106 
     | 
    
         
             
                        try:
         
     | 
| 
       183 
     | 
    
         
            -
                             
     | 
| 
       184 
     | 
    
         
            -
                            if self.class_name:
         
     | 
| 
       185 
     | 
    
         
            -
                                class_obj = getattr(module, self.class_name)
         
     | 
| 
       186 
     | 
    
         
            -
                                instance = class_obj()
         
     | 
| 
       187 
     | 
    
         
            -
                                func = getattr(instance, self.fut.name)
         
     | 
| 
       188 
     | 
    
         
            -
                            else:
         
     | 
| 
       189 
     | 
    
         
            -
                                func = getattr(module, self.fut.name)
         
     | 
| 
      
 107 
     | 
    
         
            +
                            func = self.function_data.func
         
     | 
| 
       190 
108 
     | 
    
         
             
                            _ = func(*test_case.inputs)
         
     | 
| 
       191 
109 
     | 
    
         
             
                        except Exception as e:
         
     | 
| 
       192 
110 
     | 
    
         
             
                            import traceback
         
     | 
| 
         @@ -195,7 +113,7 @@ class ReinforcementEnvironment: 
     | 
|
| 
       195 
113 
     | 
    
         
             
                    self.cov.stop()
         
     | 
| 
       196 
114 
     | 
    
         | 
| 
       197 
115 
     | 
    
         
             
                    # Get detailed coverage data including branches
         
     | 
| 
       198 
     | 
    
         
            -
                    file_path = os.path.abspath(self. 
     | 
| 
      
 116 
     | 
    
         
            +
                    file_path = os.path.abspath(self.filepath)
         
     | 
| 
       199 
117 
     | 
    
         
             
                    data = self.cov.get_data()
         
     | 
| 
       200 
118 
     | 
    
         | 
| 
       201 
119 
     | 
    
         
             
                    # Extract function-specific coverage
         
     | 
| 
         @@ -225,13 +143,13 @@ class ReinforcementEnvironment: 
     | 
|
| 
       225 
143 
     | 
    
         
             
                    import ast
         
     | 
| 
       226 
144 
     | 
    
         | 
| 
       227 
145 
     | 
    
         
             
                    try:
         
     | 
| 
       228 
     | 
    
         
            -
                        with open(self. 
     | 
| 
      
 146 
     | 
    
         
            +
                        with open(self.filepath, 'r') as f:
         
     | 
| 
       229 
147 
     | 
    
         
             
                            source = f.read()
         
     | 
| 
       230 
148 
     | 
    
         | 
| 
       231 
149 
     | 
    
         
             
                        tree = ast.parse(source)
         
     | 
| 
       232 
150 
     | 
    
         | 
| 
       233 
151 
     | 
    
         
             
                        for node in ast.walk(tree):
         
     | 
| 
       234 
     | 
    
         
            -
                            if isinstance(node, ast.FunctionDef) and node.name == self. 
     | 
| 
      
 152 
     | 
    
         
            +
                            if isinstance(node, ast.FunctionDef) and node.name == self.function_data.function_name:
         
     | 
| 
       235 
153 
     | 
    
         
             
                                # Find the first line of the function
         
     | 
| 
       236 
154 
     | 
    
         
             
                                start_line = node.lineno
         
     | 
| 
       237 
155 
     | 
    
         | 
| 
         @@ -14,11 +14,11 @@ class StatementCoverageState(AbstractState): 
     | 
|
| 
       14 
14 
     | 
    
         
             
                    """Returns calculated coverage and length of test cases in a tuple"""
         
     | 
| 
       15 
15 
     | 
    
         
             
                    all_covered_statements = set()
         
     | 
| 
       16 
16 
     | 
    
         
             
                    for test_case in self.environment.test_cases:
         
     | 
| 
       17 
     | 
    
         
            -
                        analysis = testgen.util.coverage_utils.get_coverage_analysis(self.environment. 
     | 
| 
      
 17 
     | 
    
         
            +
                        analysis = testgen.util.coverage_utils.get_coverage_analysis(self.environment.filepath, self.environment.function_data, test_case.inputs)
         
     | 
| 
       18 
18 
     | 
    
         
             
                        covered = testgen.util.coverage_utils.get_list_of_covered_statements(analysis)
         
     | 
| 
       19 
19 
     | 
    
         
             
                        all_covered_statements.update(covered)
         
     | 
| 
       20 
20 
     | 
    
         | 
| 
       21 
     | 
    
         
            -
                    executable_statements = self.environment. 
     | 
| 
      
 21 
     | 
    
         
            +
                    executable_statements = testgen.util.coverage_utils.get_all_executable_statements(self.environment.filepath, self.environment.function_data, self.environment.test_cases)
         
     | 
| 
       22 
22 
     | 
    
         | 
| 
       23 
23 
     | 
    
         
             
                    if not executable_statements or executable_statements == 0:
         
     | 
| 
       24 
24 
     | 
    
         
             
                        calc_coverage = 0.0
         
     | 
| 
         @@ -26,10 +26,11 @@ class StatementCoverageState(AbstractState): 
     | 
|
| 
       26 
26 
     | 
    
         
             
                        calc_coverage: float = (len(all_covered_statements) / len(executable_statements)) * 100
         
     | 
| 
       27 
27 
     | 
    
         | 
| 
       28 
28 
     | 
    
         
             
                    self.logger.debug(f"GET STATE ALL COVERED STATEMENTS: {all_covered_statements}")
         
     | 
| 
       29 
     | 
    
         
            -
                    self.logger.debug(f"GET STATE ALL EXECUTABLE STATEMENTS: { 
     | 
| 
      
 29 
     | 
    
         
            +
                    self.logger.debug(f"GET STATE ALL EXECUTABLE STATEMENTS: {executable_statements}")
         
     | 
| 
       30 
30 
     | 
    
         
             
                    self.logger.debug(f"GET STATE FLOAT COVERAGE: {calc_coverage}")
         
     | 
| 
       31 
31 
     | 
    
         | 
| 
       32 
32 
     | 
    
         
             
                    if calc_coverage >= 100:
         
     | 
| 
       33 
     | 
    
         
            -
                        print(f"!!!!!!!!FULLY COVERED FUNCTION: {self.environment. 
     | 
| 
      
 33 
     | 
    
         
            +
                        print(f"!!!!!!!!FULLY COVERED FUNCTION: {self.environment.function_data.function_name}!!!!!!!!")
         
     | 
| 
      
 34 
     | 
    
         
            +
             
     | 
| 
       34 
35 
     | 
    
         
             
                    return calc_coverage, len(self.environment.test_cases)
         
     | 
| 
       35 
36 
     | 
    
         | 
| 
         @@ -2,9 +2,10 @@ import inspect 
     | 
|
| 
       2 
2 
     | 
    
         
             
            import ast
         
     | 
| 
       3 
3 
     | 
    
         
             
            import time
         
     | 
| 
       4 
4 
     | 
    
         
             
            from types import ModuleType
         
     | 
| 
       5 
     | 
    
         
            -
            from typing import Dict, List
         
     | 
| 
      
 5 
     | 
    
         
            +
            from typing import Dict, List, Any
         
     | 
| 
       6 
6 
     | 
    
         | 
| 
       7 
7 
     | 
    
         
             
            import testgen
         
     | 
| 
      
 8 
     | 
    
         
            +
            from testgen.analyzer.reinforcement_analyzer import ReinforcementAnalyzer
         
     | 
| 
       8 
9 
     | 
    
         
             
            from testgen.service.logging_service import get_logger
         
     | 
| 
       9 
10 
     | 
    
         
             
            import testgen.util.file_utils
         
     | 
| 
       10 
11 
     | 
    
         
             
            import testgen.util.file_utils as file_utils
         
     | 
| 
         @@ -37,11 +38,8 @@ class AnalysisService: 
     | 
|
| 
       37 
38 
     | 
    
         | 
| 
       38 
39 
     | 
    
         
             
                def generate_test_cases(self) -> List[TestCase]:
         
     | 
| 
       39 
40 
     | 
    
         
             
                        """Generate test cases using the current strategy."""
         
     | 
| 
       40 
     | 
    
         
            -
                         
     | 
| 
       41 
     | 
    
         
            -
             
     | 
| 
       42 
     | 
    
         
            -
                        else:
         
     | 
| 
       43 
     | 
    
         
            -
                            self.test_case_analyzer_context.do_logic()
         
     | 
| 
       44 
     | 
    
         
            -
                            return self.test_case_analyzer_context.test_cases
         
     | 
| 
      
 41 
     | 
    
         
            +
                        self.test_case_analyzer_context.do_strategy(30)
         
     | 
| 
      
 42 
     | 
    
         
            +
                        return self.test_case_analyzer_context.test_cases
         
     | 
| 
       45 
43 
     | 
    
         | 
| 
       46 
44 
     | 
    
         
             
                def create_analysis_context(self, filepath: str) -> AnalysisContext:
         
     | 
| 
       47 
45 
     | 
    
         
             
                    """Create an analysis context for the given file."""
         
     | 
| 
         @@ -88,6 +86,7 @@ class AnalysisService: 
     | 
|
| 
       88 
86 
     | 
    
         
             
                    mode = mode or self.reinforcement_mode
         
     | 
| 
       89 
87 
     | 
    
         
             
                    module: ModuleType = testgen.util.file_utils.load_module(filepath)
         
     | 
| 
       90 
88 
     | 
    
         
             
                    tree: ast.Module = testgen.util.file_utils.load_and_parse_file_for_tree(filepath)
         
     | 
| 
      
 89 
     | 
    
         
            +
                    list_of_function_data: List[FunctionMetadata] = self.get_function_data(filepath, module, class_name)
         
     | 
| 
       91 
90 
     | 
    
         
             
                    functions: List[ast.FunctionDef] = testgen.util.utils.get_functions(tree)
         
     | 
| 
       92 
91 
     | 
    
         
             
                    self.class_name = class_name
         
     | 
| 
       93 
92 
     | 
    
         
             
                    time_limit: int = 30
         
     | 
| 
         @@ -95,14 +94,14 @@ class AnalysisService: 
     | 
|
| 
       95 
94 
     | 
    
         | 
| 
       96 
95 
     | 
    
         
             
                    q_table = self._load_q_table()
         
     | 
| 
       97 
96 
     | 
    
         | 
| 
       98 
     | 
    
         
            -
                    for function in  
     | 
| 
       99 
     | 
    
         
            -
                        print(f"\nStarting reinforcement learning for function {function. 
     | 
| 
      
 97 
     | 
    
         
            +
                    for function in list_of_function_data:
         
     | 
| 
      
 98 
     | 
    
         
            +
                        print(f"\nStarting reinforcement learning for function {function.function_name}")
         
     | 
| 
       100 
99 
     | 
    
         
             
                        start_time = time.time()
         
     | 
| 
       101 
100 
     | 
    
         
             
                        function_test_cases: List[TestCase] = []
         
     | 
| 
       102 
101 
     | 
    
         
             
                        best_coverage: float = 0.0
         
     | 
| 
       103 
102 
     | 
    
         | 
| 
       104 
103 
     | 
    
         
             
                        # Create environment and agent once per function
         
     | 
| 
       105 
     | 
    
         
            -
                        environment = ReinforcementEnvironment(filepath, function,  
     | 
| 
      
 104 
     | 
    
         
            +
                        environment = ReinforcementEnvironment(filepath, function, function_test_cases, state=StatementCoverageState(None))
         
     | 
| 
       106 
105 
     | 
    
         
             
                        environment.state = StatementCoverageState(environment)
         
     | 
| 
       107 
106 
     | 
    
         | 
| 
       108 
107 
     | 
    
         
             
                        # Create agent with existing Q-table
         
     | 
| 
         @@ -115,10 +114,10 @@ class AnalysisService: 
     | 
|
| 
       115 
114 
     | 
    
         
             
                            new_test_cases = agent.collect_test_cases()
         
     | 
| 
       116 
115 
     | 
    
         
             
                            function_test_cases.extend(new_test_cases)
         
     | 
| 
       117 
116 
     | 
    
         | 
| 
       118 
     | 
    
         
            -
                        print(f"\nNumber of test cases for {function. 
     | 
| 
      
 117 
     | 
    
         
            +
                        print(f"\nNumber of test cases for {function.function_name}: {len(function_test_cases)}")
         
     | 
| 
       119 
118 
     | 
    
         | 
| 
       120 
119 
     | 
    
         
             
                        current_coverage: float = environment.run_tests()
         
     | 
| 
       121 
     | 
    
         
            -
                        print(f"Current coverage: {function. 
     | 
| 
      
 120 
     | 
    
         
            +
                        print(f"Current coverage: {function.function_name}: {current_coverage}")
         
     | 
| 
       122 
121 
     | 
    
         | 
| 
       123 
122 
     | 
    
         
             
                        q_table.update(agent.q_table)
         
     | 
| 
       124 
123 
     | 
    
         | 
| 
         @@ -134,8 +133,8 @@ class AnalysisService: 
     | 
|
| 
       134 
133 
     | 
    
         
             
                                unique_test_cases.append(case)
         
     | 
| 
       135 
134 
     | 
    
         | 
| 
       136 
135 
     | 
    
         
             
                        all_test_cases.extend(unique_test_cases)
         
     | 
| 
       137 
     | 
    
         
            -
                        print(f"Final coverage for {function. 
     | 
| 
       138 
     | 
    
         
            -
                        print(f"Final test cases for {function. 
     | 
| 
      
 136 
     | 
    
         
            +
                        print(f"Final coverage for {function.function_name}: {best_coverage}%")
         
     | 
| 
      
 137 
     | 
    
         
            +
                        print(f"Final test cases for {function.function_name}: {len(unique_test_cases)}")
         
     | 
| 
       139 
138 
     | 
    
         | 
| 
       140 
139 
     | 
    
         
             
                    self._save_q_table(q_table)
         
     | 
| 
       141 
140 
     | 
    
         | 
| 
         @@ -146,12 +145,23 @@ class AnalysisService: 
     | 
|
| 
       146 
145 
     | 
    
         
             
                def _create_function_metadata(self, filename: str, module: ModuleType, class_name: str | None,
         
     | 
| 
       147 
146 
     | 
    
         
             
                                              func_node: ast.FunctionDef) -> FunctionMetadata:
         
     | 
| 
       148 
147 
     | 
    
         
             
                    function_name = func_node.name
         
     | 
| 
       149 
     | 
    
         
            -
             
     | 
| 
      
 148 
     | 
    
         
            +
                    func = self._get_func_attr(function_name, module, class_name)
         
     | 
| 
       150 
149 
     | 
    
         
             
                    param_types = self._get_params(func_node)
         
     | 
| 
       151 
150 
     | 
    
         | 
| 
       152 
     | 
    
         
            -
                    return FunctionMetadata(filename, module, class_name, function_name, func_node, param_types)
         
     | 
| 
       153 
     | 
    
         
            -
             
     | 
| 
       154 
     | 
    
         
            -
                 
     | 
| 
      
 151 
     | 
    
         
            +
                    return FunctionMetadata(filename, module, class_name, func, function_name, func_node, param_types)
         
     | 
| 
      
 152 
     | 
    
         
            +
             
     | 
| 
      
 153 
     | 
    
         
            +
                @staticmethod
         
     | 
| 
      
 154 
     | 
    
         
            +
                def _get_func_attr(function_name: str, module: ModuleType, class_name: str | None) -> Any:
         
     | 
| 
      
 155 
     | 
    
         
            +
                    if class_name:
         
     | 
| 
      
 156 
     | 
    
         
            +
                        cls = getattr(module, class_name)
         
     | 
| 
      
 157 
     | 
    
         
            +
                        instance = cls()
         
     | 
| 
      
 158 
     | 
    
         
            +
                        func = getattr(instance, function_name)
         
     | 
| 
      
 159 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 160 
     | 
    
         
            +
                        func = getattr(module, function_name)
         
     | 
| 
      
 161 
     | 
    
         
            +
                    return func
         
     | 
| 
      
 162 
     | 
    
         
            +
             
     | 
| 
      
 163 
     | 
    
         
            +
                @staticmethod
         
     | 
| 
      
 164 
     | 
    
         
            +
                def _get_params(func_node: ast.FunctionDef) -> Dict[str, str]:
         
     | 
| 
       155 
165 
     | 
    
         
             
                    # Extract parameter types
         
     | 
| 
       156 
166 
     | 
    
         
             
                    param_types = {}
         
     | 
| 
       157 
167 
     | 
    
         
             
                    for arg in func_node.args.args:
         
     | 
| 
         @@ -173,18 +183,17 @@ class AnalysisService: 
     | 
|
| 
       173 
183 
     | 
    
         | 
| 
       174 
184 
     | 
    
         
             
                    if strategy == AST_STRAT:
         
     | 
| 
       175 
185 
     | 
    
         
             
                        analyzer = ASTAnalyzer(analysis_context)
         
     | 
| 
       176 
     | 
    
         
            -
                        self.test_case_analyzer_context = TestCaseAnalyzerContext(analysis_context, analyzer)
         
     | 
| 
       177 
186 
     | 
    
         
             
                    elif strategy == FUZZ_STRAT:
         
     | 
| 
       178 
187 
     | 
    
         
             
                        analyzer = FuzzAnalyzer(analysis_context)
         
     | 
| 
       179 
     | 
    
         
            -
                        self.test_case_analyzer_context = TestCaseAnalyzerContext(analysis_context, analyzer)
         
     | 
| 
       180 
188 
     | 
    
         
             
                    elif strategy == RANDOM_STRAT:
         
     | 
| 
       181 
189 
     | 
    
         
             
                        analyzer = RandomFeedbackAnalyzer(analysis_context)
         
     | 
| 
       182 
     | 
    
         
            -
                        self.test_case_analyzer_context = TestCaseAnalyzerContext(analysis_context, analyzer)
         
     | 
| 
       183 
190 
     | 
    
         
             
                    elif strategy == REINFORCE_STRAT:
         
     | 
| 
       184 
     | 
    
         
            -
                         
     | 
| 
      
 191 
     | 
    
         
            +
                        analyzer = ReinforcementAnalyzer(analysis_context, mode=self.reinforcement_mode)
         
     | 
| 
       185 
192 
     | 
    
         
             
                    else:
         
     | 
| 
       186 
193 
     | 
    
         
             
                        raise NotImplementedError(f"Test strategy {strategy} not implemented")
         
     | 
| 
       187 
     | 
    
         
            -
             
     | 
| 
      
 194 
     | 
    
         
            +
             
     | 
| 
      
 195 
     | 
    
         
            +
                    self.test_case_analyzer_context = TestCaseAnalyzerContext(analysis_context, analyzer)
         
     | 
| 
      
 196 
     | 
    
         
            +
             
     | 
| 
       188 
197 
     | 
    
         
             
                def set_file_path(self, path: str):
         
     | 
| 
       189 
198 
     | 
    
         
             
                    """Set the file path for analysis."""
         
     | 
| 
       190 
199 
     | 
    
         
             
                    self.file_path = path
         
     | 
    
        testgen/service/cfg_service.py
    CHANGED
    
    | 
         @@ -1,5 +1,7 @@ 
     | 
|
| 
       1 
1 
     | 
    
         
             
            import os
         
     | 
| 
       2 
2 
     | 
    
         
             
            from typing import List
         
     | 
| 
      
 3 
     | 
    
         
            +
             
     | 
| 
      
 4 
     | 
    
         
            +
            from testgen.models.function_metadata import FunctionMetadata
         
     | 
| 
       3 
5 
     | 
    
         
             
            from testgen.models.test_case import TestCase
         
     | 
| 
       4 
6 
     | 
    
         
             
            from testgen.service.logging_service import get_logger
         
     | 
| 
       5 
7 
     | 
    
         
             
            from testgen.util.coverage_visualizer import CoverageVisualizer
         
     | 
| 
         @@ -46,7 +48,7 @@ class CFGService: 
     | 
|
| 
       46 
48 
     | 
    
         
             
                    filename = os.path.basename(file_path).replace('.py', '')
         
     | 
| 
       47 
49 
     | 
    
         | 
| 
       48 
50 
     | 
    
         
             
                    for func in analysis_context.function_data:
         
     | 
| 
       49 
     | 
    
         
            -
                        self.visualizer.get_covered_lines(file_path,  
     | 
| 
      
 51 
     | 
    
         
            +
                        self.visualizer.get_covered_lines(file_path, func, test_cases)
         
     | 
| 
       50 
52 
     | 
    
         | 
| 
       51 
53 
     | 
    
         
             
                        base_filename = f"{filename}_{func.function_name}_coverage"
         
     | 
| 
       52 
54 
     | 
    
         
             
                        output_filepath = self.get_versioned_filename(visualization_dir, base_filename)
         
     |