testgenie-py 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. testgen/__init__.py +0 -0
  2. testgen/analyzer/__init__.py +0 -0
  3. testgen/analyzer/ast_analyzer.py +149 -0
  4. testgen/analyzer/contracts/__init__.py +0 -0
  5. testgen/analyzer/contracts/contract.py +13 -0
  6. testgen/analyzer/contracts/no_exception_contract.py +16 -0
  7. testgen/analyzer/contracts/nonnull_contract.py +15 -0
  8. testgen/analyzer/fuzz_analyzer.py +106 -0
  9. testgen/analyzer/random_feedback_analyzer.py +291 -0
  10. testgen/analyzer/reinforcement_analyzer.py +75 -0
  11. testgen/analyzer/test_case_analyzer.py +46 -0
  12. testgen/analyzer/test_case_analyzer_context.py +58 -0
  13. testgen/controller/__init__.py +0 -0
  14. testgen/controller/cli_controller.py +194 -0
  15. testgen/controller/docker_controller.py +169 -0
  16. testgen/docker/Dockerfile +22 -0
  17. testgen/docker/poetry.lock +361 -0
  18. testgen/docker/pyproject.toml +22 -0
  19. testgen/generator/__init__.py +0 -0
  20. testgen/generator/code_generator.py +66 -0
  21. testgen/generator/doctest_generator.py +208 -0
  22. testgen/generator/generator.py +55 -0
  23. testgen/generator/pytest_generator.py +77 -0
  24. testgen/generator/test_generator.py +26 -0
  25. testgen/generator/unit_test_generator.py +84 -0
  26. testgen/inspector/__init__.py +0 -0
  27. testgen/inspector/inspector.py +61 -0
  28. testgen/main.py +13 -0
  29. testgen/models/__init__.py +0 -0
  30. testgen/models/analysis_context.py +56 -0
  31. testgen/models/function_metadata.py +61 -0
  32. testgen/models/generator_context.py +63 -0
  33. testgen/models/test_case.py +8 -0
  34. testgen/presentation/__init__.py +0 -0
  35. testgen/presentation/cli_view.py +12 -0
  36. testgen/q_table/global_q_table.json +1 -0
  37. testgen/reinforcement/__init__.py +0 -0
  38. testgen/reinforcement/abstract_state.py +7 -0
  39. testgen/reinforcement/agent.py +153 -0
  40. testgen/reinforcement/environment.py +215 -0
  41. testgen/reinforcement/statement_coverage_state.py +33 -0
  42. testgen/service/__init__.py +0 -0
  43. testgen/service/analysis_service.py +260 -0
  44. testgen/service/cfg_service.py +55 -0
  45. testgen/service/generator_service.py +169 -0
  46. testgen/service/service.py +389 -0
  47. testgen/sqlite/__init__.py +0 -0
  48. testgen/sqlite/db.py +84 -0
  49. testgen/sqlite/db_service.py +219 -0
  50. testgen/tree/__init__.py +0 -0
  51. testgen/tree/node.py +7 -0
  52. testgen/tree/tree_utils.py +79 -0
  53. testgen/util/__init__.py +0 -0
  54. testgen/util/coverage_utils.py +168 -0
  55. testgen/util/coverage_visualizer.py +154 -0
  56. testgen/util/file_utils.py +110 -0
  57. testgen/util/randomizer.py +122 -0
  58. testgen/util/utils.py +143 -0
  59. testgen/util/z3_utils/__init__.py +0 -0
  60. testgen/util/z3_utils/ast_to_z3.py +99 -0
  61. testgen/util/z3_utils/branch_condition.py +72 -0
  62. testgen/util/z3_utils/constraint_extractor.py +36 -0
  63. testgen/util/z3_utils/variable_finder.py +10 -0
  64. testgen/util/z3_utils/z3_test_case.py +94 -0
  65. testgenie_py-0.1.0.dist-info/METADATA +24 -0
  66. testgenie_py-0.1.0.dist-info/RECORD +68 -0
  67. testgenie_py-0.1.0.dist-info/WHEEL +4 -0
  68. testgenie_py-0.1.0.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,63 @@
1
+ from types import ModuleType
2
+ from typing import List
3
+
4
+ from testgen.models.test_case import TestCase
5
+
6
+
7
class GeneratorContext:
    """Mutable bundle of the inputs a test generator works from.

    Every field is exposed as a read/write property backed by an underscore
    attribute; no validation is performed on assignment.
    """

    def __init__(self, filepath: str, filename: str, class_name: str | None, module: ModuleType, output_path: str, test_cases: List[TestCase]):
        """Store the generation inputs.

        Args:
            filepath: Path of the source file under test.
            filename: File name of the source file under test.
            class_name: Name of the class being tested, or None when there is
                no class (presumably module-level functions; confirm at callers).
            module: The loaded module object for the source under test.
            output_path: Destination for generated output.
            test_cases: Test cases to turn into generated tests.
        """
        self._filepath: str = filepath
        self._filename: str = filename
        # Annotated as Optional to match the constructor/getter/setter contract.
        self._class_name: str | None = class_name
        self._module: ModuleType = module
        self._output_path: str = output_path
        self._test_cases: List[TestCase] = test_cases

    @property
    def filepath(self) -> str:
        """Path of the source file under test."""
        return self._filepath

    @filepath.setter
    def filepath(self, value: str) -> None:
        self._filepath = value

    @property
    def filename(self) -> str:
        """File name of the source file under test."""
        return self._filename

    @filename.setter
    def filename(self, value: str) -> None:
        self._filename = value

    @property
    def class_name(self) -> str | None:
        """Name of the class under test, or None."""
        return self._class_name

    @class_name.setter
    def class_name(self, value: str | None) -> None:
        self._class_name = value

    @property
    def module(self) -> ModuleType:
        """Loaded module object for the source under test."""
        return self._module

    @module.setter
    def module(self, value: ModuleType) -> None:
        self._module = value

    @property
    def output_path(self) -> str:
        """Destination for generated output."""
        return self._output_path

    @output_path.setter
    def output_path(self, value: str) -> None:
        self._output_path = value

    @property
    def test_cases(self) -> List[TestCase]:
        """Test cases to generate code for."""
        return self._test_cases

    @test_cases.setter
    def test_cases(self, value: List[TestCase]) -> None:
        self._test_cases = value
63
+
@@ -0,0 +1,8 @@
1
import string
from typing import Any, List
3
+
4
class TestCase:
    """One generated test case for a single function.

    Attributes:
        func_name: Name of the function the case exercises.
        inputs: Positional arguments the function is called with
            (callers unpack them as ``func(*inputs)``).
        expected: Expected result of the call; may be any value.
    """

    # ``Any`` (not the builtin function ``any``) is the correct "anything" annotation.
    def __init__(self, func_name: str, inputs: tuple, expected: Any):
        self.func_name = func_name
        self.inputs = inputs
        self.expected = expected

    def __repr__(self) -> str:
        # Test cases are printed during generation/RL runs; show their contents.
        return (
            f"{type(self).__name__}(func_name={self.func_name!r}, "
            f"inputs={self.inputs!r}, expected={self.expected!r})"
        )
File without changes
@@ -0,0 +1,12 @@
1
class CLIView:
    """Minimal console view: prints tagged messages and reads a line of input."""

    def __init__(self):
        # Stateless view; nothing to initialise. (Previously a stray "pass"
        # string literal stood in for the ``pass`` statement.)
        pass

    def display_message(self, message: str) -> None:
        """Print *message* to stdout prefixed with ``[INFO]``."""
        print(f"[INFO] {message}")

    def display_error(self, error: str) -> None:
        """Print *error* to stdout prefixed with ``[ERROR]``."""
        print(f"[ERROR] {error}")

    def prompt_input(self, prompt: str) -> str:
        """Show *prompt* followed by ``:> `` and return the line entered."""
        return input(f"{prompt}:> ")
@@ -0,0 +1 @@
1
+ {"(0.0, 0)|add": 1.0147553256632154, "(62.5, 1)|z3": -0.020000000000000004, "(62.5, 3)|add": -0.010000000000000002, "(62.5, 4)|add": 0.18000000000000002, "(87.5, 5)|add": 0.07100000000000001, "(75.0, 1)|add": 0.6092791100000001, "(87.5, 2)|add": 0.18509310000000004, "(87.5, 3)|z3": 0.0, "(87.5, 3)|add": -0.010000000000000002, "(87.5, 4)|add": 0.26100000000000007, "(87.5, 2)|remove": -0.09000000000000001, "(62.5, 1)|add": -0.019000000000000003, "(62.5, 2)|add": -0.010000000000000002, "(62.5, 3)|z3": -0.020000000000000004, "(62.5, 5)|add": -0.010000000000000002, "(62.5, 6)|add": -0.010000000000000002, "(62.5, 7)|add": 0.07100000000000001, "(87.5, 8)|add": -0.010000000000000002, "(87.5, 9)|z3": -0.010000000000000002, "(87.5, 3)|remove": -0.09000000000000001, "(62.5, 2)|z3": -0.020900000000000002, "(62.5, 4)|merge": -0.010000000000000002, "(62.5, 5)|merge": -0.010000000000000002, "(62.5, 6)|remove": 0.010000000000000002, "(62.5, 5)|z3": -0.011000000000000001, "(87.5, 3)|merge": 0.08000000000000002, "(62.5, 4)|remove": 0.010000000000000002, "(62.5, 3)|merge": -0.0009999999999999996, "(87.5, 6)|z3": 0.0, "(87.5, 6)|merge": 0.09000000000000001, "(87.5, 2)|z3": -0.0009999999999999996, "(75.0, 1)|z3": 0.10056590000000001, "(60.0, 1)|add": 0.18917349, "(60.0, 2)|add": 0.17900000000000002, "(70.0, 3)|add": 0.23680000000000007, "(90.0, 4)|add": -0.010000000000000002, "(90.0, 5)|add": 0.1629, "(90.0, 6)|add": -0.01929, "(90.0, 7)|z3": -0.0046099999999999995, "(90.0, 7)|merge": -0.010000000000000002, "(90.0, 8)|add": -0.010000000000000002, "(0.0, 0)|z3": 0.20518, "(60.0, 4)|add": 0.24390000000000006, "(70.0, 5)|add": -0.010000000000000002, "(70.0, 6)|remove": 0.019000000000000003, "(70.0, 5)|z3": -0.020000000000000004, "(70.0, 7)|add": -0.010000000000000002, "(70.0, 8)|z3": -0.020000000000000004, "(70.0, 10)|add": 0.09000000000000001, "(60.0, 2)|z3": 0.08000000000000002, "(70.0, 4)|add": 0.06100000000000001, "(90.0, 5)|z3": -0.010000000000000002, "(90.0, 6)|z3": -0.0029, 
"(90.0, 6)|merge": -0.010000000000000002, "(90.0, 7)|add": 0.05390000000000001, "(60.0, 1)|z3": 0.26273490000000005, "(70.0, 5)|merge": -0.009, "(70.0, 5)|remove": -0.08100000000000002, "(80.0, 5)|z3": -0.010000000000000002, "(80.0, 6)|add": 0.18449000000000004, "(80.0, 2)|add": -0.010000000000000002, "(80.0, 3)|add": -0.010000000000000002, "(80.0, 4)|z3": 0.053800000000000014, "(90.0, 8)|z3": -0.010000000000000002, "(90.0, 9)|add": -0.010000000000000002, "(80.0, 3)|remove": -0.08100000000000002, "(90.0, 4)|z3": 0.0, "(90.0, 4)|remove": 0.054200000000000005, "(90.0, 3)|add": 0.09000000000000001, "(90.0, 3)|z3": -0.0072000000000000015, "(80.0, 2)|z3": -0.010000000000000002, "(80.0, 3)|z3": -0.0009999999999999996, "(60.0, 4)|remove": 0.010000000000000002, "(60.0, 3)|add": 0.09710000000000002, "(80.0, 5)|add": 0.08944900000000001, "(90.0, 6)|remove": -0.08100000000000002, "(53.84615384615385, 1)|add": 0.3104837900000001, "(76.92307692307693, 2)|remove": -0.09000000000000001, "(61.53846153846154, 1)|add": 0.17100000000000004, "(76.92307692307693, 2)|add": -0.010000000000000002, "(76.92307692307693, 3)|add": 0.081, "(76.92307692307693, 4)|remove": -0.0719, "(76.92307692307693, 3)|z3": 0.0, "(84.61538461538461, 4)|z3": 0.0, "(53.84615384615385, 2)|z3": -0.04000000000000001, "(53.84615384615385, 6)|add": 0.09000000000000001, "(69.23076923076923, 7)|merge": -0.010000000000000002, "(69.23076923076923, 8)|add": 0.15390000000000004, "(76.92307692307693, 9)|add": -0.010000000000000002, "(76.92307692307693, 10)|add": 0.09000000000000001, "(61.53846153846154, 2)|add": -0.010000000000000002, "(61.53846153846154, 3)|add": -0.010000000000000002, "(61.53846153846154, 4)|add": 0.07100000000000001, "(84.61538461538461, 5)|add": -0.010000000000000002, "(84.61538461538461, 6)|add": 0.17100000000000004, "(92.3076923076923, 7)|add": -0.010000000000000002, "(92.3076923076923, 8)|add": -0.010000000000000002, "(69.23076923076923, 2)|add": 0.171, "(69.23076923076923, 3)|merge": 
-0.010000000000000002, "(69.23076923076923, 4)|add": -0.010000000000000002, "(69.23076923076923, 5)|add": -0.010000000000000002, "(69.23076923076923, 6)|z3": -0.011000000000000001, "(69.23076923076923, 9)|add": -0.010000000000000002, "(69.23076923076923, 2)|z3": -0.030000000000000006, "(69.23076923076923, 5)|remove": 0.019000000000000003, "(69.23076923076923, 4)|z3": -0.030000000000000006, "(69.23076923076923, 7)|add": -0.0029, "(84.61538461538461, 9)|add": 0.09000000000000001, "(92.3076923076923, 10)|add": 0.07100000000000001, "(53.84615384615385, 2)|merge": -0.010000000000000002, "(53.84615384615385, 3)|add": -0.010000000000000002, "(53.84615384615385, 4)|add": 0.09100000000000001, "(69.23076923076923, 4)|remove": 0.010000000000000002, "(69.23076923076923, 3)|add": 0.23471000000000006, "(76.92307692307693, 4)|add": -0.010000000000000002, "(53.84615384615385, 1)|z3": -0.04000000000000001, "(53.84615384615385, 5)|z3": -0.04000000000000001, "(53.84615384615385, 9)|add": 0.09000000000000001, "(69.23076923076923, 10)|add": -0.010000000000000002, "(69.23076923076923, 11)|add": -0.010000000000000002, "(69.23076923076923, 12)|add": -0.010000000000000002, "(76.92307692307693, 2)|z3": 0.09000000000000001, "(84.61538461538461, 3)|add": -0.010000000000000002, "(84.61538461538461, 4)|add": -0.010000000000000002, "(84.61538461538461, 5)|z3": -0.0009999999999999996, "(92.3076923076923, 7)|z3": -0.010000000000000002, "(92.3076923076923, 8)|z3": -0.010000000000000002, "(76.92307692307693, 4)|z3": -0.020000000000000004, "(76.92307692307693, 6)|add": 0.05904900000000002, "(92.3076923076923, 7)|merge": -0.010000000000000002, "(92.3076923076923, 8)|merge": -0.010000000000000002, "(92.3076923076923, 9)|merge": -0.0009999999999999996, "(53.84615384615385, 2)|add": 0.09000000000000001, "(61.53846153846154, 3)|merge": -0.0009999999999999996, "(61.53846153846154, 5)|z3": 0.08000000000000002, "(76.92307692307693, 7)|remove": 0.034390000000000004, "(75.0, 2)|add": -0.010000000000000002, 
"(75.0, 3)|merge": -0.010000000000000002, "(87.5, 5)|merge": -0.0029, "(100.0, 6)|add": 0.09000000000000001, "(70.0, 3)|z3": 0.10368000000000002, "(70.0, 3)|remove": -0.06632, "(70.0, 4)|merge": -0.0038999999999999994, "(100.0, 8)|add": 0.09000000000000001, "(53.84615384615385, 1)|remove": -0.058951621, "(69.23076923076923, 4)|merge": -0.009, "(69.23076923076923, 7)|z3": -0.030000000000000006, "(69.23076923076923, 6)|remove": 0.010000000000000002, "(69.23076923076923, 9)|z3": -0.030000000000000006, "(84.61538461538461, 10)|add": 0.09000000000000001}
File without changes
@@ -0,0 +1,7 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Tuple
3
+
4
class AbstractState(ABC):
    """Interface for objects that summarise an environment as a state tuple.

    Concrete subclasses must implement :meth:`get_state`; the class itself
    cannot be instantiated.
    """

    @abstractmethod
    def get_state(self) -> Tuple:
        """Return a tuple describing the current state."""
        ...
@@ -0,0 +1,153 @@
1
+ import random
2
+ import time
3
+ from typing import List
4
+
5
+ from testgen.models.test_case import TestCase
6
+ from testgen.reinforcement.environment import ReinforcementEnvironment
7
+
8
+
9
class ReinforcementAgent:
    """Tabular Q-learning agent that grows/shrinks a test suite via an environment.

    States come from ``environment.get_state()`` / ``environment.step()`` and
    must be 2-tuples ``(coverage_percentage, test_case_count)``. Q-values are
    stored in ``q_table`` keyed by ``(state, action)``.
    """

    def __init__(self, file_name: str, environment: ReinforcementEnvironment, test_cases: List[TestCase], q_table=None):
        self.learning_rate = 0.1
        self.file_name = file_name
        self.env = environment
        # ``is not None`` so a caller-supplied dict (even an empty one) is
        # shared and filled in place rather than silently replaced.
        self.q_table = q_table if q_table is not None else {}
        self.actions = ["add", "merge", "remove", "z3"]

    @staticmethod
    def _validate_state(state, name: str = "state") -> None:
        """Raise ValueError unless *state* is a 2-tuple (coverage, test count)."""
        if not isinstance(state, tuple) or len(state) != 2:
            raise ValueError(
                f"Expected {name} to be a tuple (coverage_percentage, len(test_cases)), but got: {state}"
            )

    def collect_test_cases(self) -> List[TestCase]:
        """Roll out the learned policy to build a test suite.

        Steps until coverage reaches 100% or 30 seconds elapse, updating the
        Q-table after every step. Actions are chosen by :meth:`choose_action`,
        which still explores part of the time.

        Returns:
            The environment's test cases after the rollout, or ``[]`` if the
            Q-table is empty (nothing has been learned yet).

        Raises:
            ValueError: If the environment returns a malformed state.
        """
        max_time = 30
        if not self.q_table:
            print("Q_TABLE IS EMPTY, RUN TRAIN FIRST")
            return []

        current_state = self.env.get_state()
        goal_state: float = 100.0
        start_time = time.time()

        # ``<`` rather than ``!=`` so a coverage value above 100 still ends the loop.
        while current_state[0] < goal_state and time.time() - start_time < max_time:
            action = self.choose_action(current_state)
            next_state, reward = self.env.step(action)
            self._validate_state(next_state, "new_state")

            # Credit the action to the state it was taken FROM.
            self.update_q_table(current_state, action, next_state, reward)
            current_state = next_state

        return self.env.test_cases

    def do_q_learning(self, episodes=10):
        """Train the Q-table over *episodes* episodes and return the best suite seen.

        Each episode resets the environment and steps until coverage reaches
        100%, the per-episode step cap is hit, or 30 seconds elapse. "Best"
        means highest coverage; ties are broken by fewer test cases.

        Returns:
            A copy of the best test-case list observed (``[]`` if no state ever
            had coverage above 0).

        Raises:
            ValueError: If the environment returns a malformed state.
        """
        max_time = 30
        best_coverage = 0.0
        best_test_cases = []

        for episode in range(episodes):
            print(f"\nNEW EPISODE {episode}")
            self.env.reset()

            current_state = self.env.get_state()
            print(f"Current state after reset: {current_state}")
            # Consider the reset suite itself: if it already reaches the goal,
            # the loop below never runs and it would otherwise be lost.
            best_coverage, best_test_cases = self._track_best(current_state, best_coverage, best_test_cases)

            goal_state: float = 100.0
            steps_in_episode = 1
            max_steps_per_episode = 100

            start_time = time.time()

            while (current_state[0] < goal_state
                   and steps_in_episode < max_steps_per_episode
                   and time.time() - start_time < max_time):
                print(f"Step {steps_in_episode} in episode {episode}")

                action = self.choose_action(current_state)
                next_state, reward = self.env.step(action)
                self._validate_state(next_state, "new_state")

                print(f"AFTER NEW STATE, REWARD: {reward}")

                self.update_q_table(current_state, action, next_state, reward)
                current_state = next_state

                steps_in_episode += 1
                best_coverage, best_test_cases = self._track_best(current_state, best_coverage, best_test_cases)

        return best_test_cases

    def _track_best(self, state, best_coverage, best_test_cases):
        """Return updated ``(best_coverage, best_test_cases)`` after observing *state*.

        Higher coverage wins; at equal coverage the smaller suite wins.
        """
        coverage = state[0]
        current_cases = self.env.test_cases
        if coverage > best_coverage:
            best_test_cases = current_cases.copy()
            print(f"New best coverage: {coverage}% with {len(best_test_cases)} test cases")
            return coverage, best_test_cases
        if coverage == best_coverage and len(best_test_cases) > len(current_cases):
            best_test_cases = current_cases.copy()
            print(f"New best coverage: {best_coverage}% with {len(best_test_cases)} test cases")
        return best_coverage, best_test_cases

    def choose_action(self, state):
        """Pick an action for *state*: explore (~33%) or exploit the Q-table (~67%).

        Exploitation picks the valid action with the highest Q-value (unseen
        pairs count as 0; ties go to the first action in the list).

        Raises:
            ValueError: If *state* is not a 2-tuple.
        """
        # Validate before indexing so a bad state fails with a clear message.
        self._validate_state(state)

        EXPLORATION = 0
        EXPLOITATION = 1
        weights = [0.33, 0.67]

        choice = random.choices([EXPLORATION, EXPLOITATION], weights=weights, k=1)[0]
        action_list = self.get_action_list(state[1])

        if choice == EXPLORATION:
            chosen_action = random.choice(action_list)
            print(f"CHOSEN EXPLORATION ACTION: {chosen_action}")
            return chosen_action

        # action_list is never empty, so no ``default`` is needed for max().
        chosen_action = max(action_list, key=lambda action: self.q_table.get((state, action), 0))
        print(f"CHOSEN EXPLOITATION ACTION: {chosen_action}")
        return chosen_action

    def optimize_test_suit(self, current_state, executable_statements):
        """Try to shrink the suite with "remove" steps without losing coverage.

        Performs at most 10 removals and never goes below one test case. A
        removal is kept if coverage does not drop; otherwise an "add" step is
        taken (which adds a new random case, not the removed one) and
        optimisation stops.

        Args:
            current_state: ``(coverage_percentage, test_case_count)``.
            executable_statements: Unused; kept for backward compatibility.
                The old code tested membership in ``next_state[0]``, which is
                a coverage float, so it always raised.

        Returns:
            The last accepted state.
        """
        test_case_count = current_state[1]
        optimization_attempts = min(10, test_case_count - 1)

        for _ in range(optimization_attempts):
            if test_case_count <= 1:
                break

            next_state, _reward = self.env.step("remove")
            self._validate_state(next_state, "new_state")

            if next_state[0] >= current_state[0]:
                current_state = next_state
                test_case_count = current_state[1]
                print(f"Optimized to {test_case_count} test cases.")
            else:
                # Removing broke coverage: add a test case back and stop.
                self.env.step("add")
                break

        return current_state

    @staticmethod
    def get_action_list(test_case_length: int) -> List[str]:
        """Return the actions valid for a suite of *test_case_length* cases.

        "merge" and "remove" need at least two test cases.
        """
        action_list = ["add", "z3"]
        if test_case_length >= 2:
            action_list.extend(["merge", "remove"])
        return action_list

    def update_q_table(self, state: tuple, action: str, new_state: tuple, reward: float):
        """Apply one Q-learning update for taking *action* in *state*.

        ``Q(s, a) <- (1 - lr) * Q(s, a) + lr * (reward + max_a' Q(s', a'))``
        with no discount factor; unseen entries count as 0.
        """
        current_q = self.q_table.get((state, action), 0)
        print(f"CURRENT Q: {current_q}")
        valid_actions = self.get_action_list(new_state[1])

        max_next_q = max(self.q_table.get((new_state, a), 0) for a in valid_actions)
        print(f"MAX NEXT Q: {max_next_q}")

        print(f"UPDATING Q TABLE FOR STATE: {state}, ACTION: {action} WITH REWARD: {reward}")
        new_q = (1 - self.learning_rate) * current_q + self.learning_rate * (reward + max_next_q)

        self.q_table[(state, action)] = new_q
@@ -0,0 +1,215 @@
1
+ import ast
2
+ import io
3
+ from typing import List, Tuple
4
+
5
+ import coverage
6
+
7
+ import testgen.util.coverage_utils
8
+ import testgen.util.file_utils
9
+ from testgen.reinforcement.abstract_state import AbstractState
10
+ import testgen.util.randomizer as randomizer
11
+ from testgen.models.test_case import TestCase
12
+
13
+
14
class ReinforcementEnvironment:
    """RL environment whose actions mutate a test suite for one function under test.

    Actions ("add", "merge", "remove", "z3") change ``self.test_cases``; the
    observable state comes from the injected ``state`` object.
    """

    # Actions accepted by step(); "merge" and "remove" also need >= 2 test cases.
    _VALID_ACTIONS = ("add", "merge", "remove", "z3")

    def __init__(self, file_name, fut: ast.FunctionDef, module, initial_test_cases: List[TestCase], state: AbstractState):
        self.file_name = file_name
        self.fut = fut
        self.module = module
        self.initial_test_cases = initial_test_cases
        # Work on a copy so reset() can restore the initial suite.
        self.test_cases = initial_test_cases.copy()
        self.state = state
        self.cov = coverage.Coverage()

    def get_state(self) -> Tuple:
        """Return the current state as computed by the injected state object."""
        return self.state.get_state()

    def step(self, action) -> Tuple[Tuple, float]:
        """Apply *action* to the test suite and return ``(new_state, reward)``.

        Raises:
            ValueError: If *action* is unknown, or is "merge"/"remove" while
                fewer than two test cases exist.
        """
        if action not in self._VALID_ACTIONS:
            raise ValueError(f"Invalid action: {action!r}")
        if action in ("merge", "remove") and len(self.test_cases) <= 1:
            raise ValueError(f"Invalid action: {action!r} needs at least two test cases")

        # State computation can be expensive (StatementCoverageState re-runs
        # coverage for every test case), so take one snapshot before and one
        # after the action.
        prev_state = self.get_state()
        prev_coverage, prev_test_cases = prev_state[0], prev_state[1]
        print(f"STEP: Previous coverage: {prev_coverage} before action: {action}")

        if action == "add":
            self.test_cases.append(randomizer.new_random_test_case(self.file_name, self.fut))
        elif action == "merge":
            self.test_cases.append(randomizer.combine_cases(self.test_cases))
        elif action == "remove":
            self.test_cases = randomizer.remove_case(self.test_cases)
        else:  # "z3"
            self.test_cases = randomizer.get_z3_test_cases(self.file_name, self.fut, self.test_cases)

        new_state = self.get_state()
        new_coverage, num_test_cases = new_state[0], new_state[1]

        coverage_delta = new_coverage - prev_coverage
        num_test_cases_delta = num_test_cases - prev_test_cases
        reward = self.get_reward(coverage_delta, num_test_cases_delta)

        print(f"Action: {action}, Previous coverage: {prev_coverage}, New coverage: {new_coverage}, Reward: {reward}")

        return new_state, reward

    def reset(self) -> None:
        """Restore the test suite to a fresh copy of the initial test cases."""
        self.test_cases = self.initial_test_cases.copy()

    def render(self):
        """No-op; present for an environment-style interface."""
        pass

    @staticmethod
    def get_reward(coverage_delta, num_test_cases_delta) -> float:
        """Score one step from the change in coverage and in suite size.

        * +1.0 if coverage increased, 0.0 if unchanged, -1.0 if it decreased.
        * Plus ``-0.1`` per test case added (``+0.1`` per test case removed).
        """
        reward: float
        if coverage_delta > 0:
            reward = 1.0
        elif coverage_delta == 0:
            reward = 0.0
        else:
            reward = -1.0

        print(f"Coverage delta reward: {reward}")

        test_cases_factor = (num_test_cases_delta * -0.1)
        reward = reward + test_cases_factor
        print(f"Reward or penalty added to coverage delta reward: {test_cases_factor}")

        print(f"Final reward {reward}")
        return reward

    def _read_source(self) -> str:
        """Return the text of ``self.file_name``, decoded as Python would decode it.

        ``tokenize.open`` honours a PEP 263 coding cookie / BOM and otherwise
        uses UTF-8, instead of the platform's locale encoding.
        """
        import tokenize

        with tokenize.open(self.file_name) as f:
            return f.read()

    def get_all_executable_statements(self):
        """Return sorted executable line numbers of the function under test.

        Starts from ``analysis[1]`` of ``coverage_utils.get_coverage_analysis``
        (run with the first matching test case's inputs, or a fresh random case
        if none exists). It then adds the line of each ``else:`` (not ``elif``)
        attached to an ``if`` inside the function, when that line
        immediately precedes the first statement of the else block.
        """
        test_cases = [tc for tc in self.test_cases if tc.func_name == self.fut.name]

        if test_cases:
            inputs = test_cases[0].inputs
        else:
            print("Warning: No test cases available to determine executable statements")
            inputs = randomizer.new_random_test_case(self.file_name, self.fut).inputs
        analysis = testgen.util.coverage_utils.get_coverage_analysis(self.file_name, self.fut.name, inputs)

        # Standard executable lines from coverage.py.
        executable_lines = list(analysis[1])

        # Read the file once; split on "\n" so indices match the line numbering
        # used by the Python parser (str.splitlines would also split on \f etc.).
        source = self._read_source()
        lines = source.split("\n")
        tree = ast.parse(source)

        for node in ast.walk(tree):
            if not (isinstance(node, ast.FunctionDef) and node.name == self.fut.name):
                continue
            for if_node in ast.walk(node):
                if not (isinstance(if_node, ast.If) and if_node.orelse):
                    continue
                if isinstance(if_node.orelse[0], ast.If):
                    # elif (or else-block starting with an if): already counted.
                    continue

                # Heuristic: the 'else:' sits on the line just before the
                # first statement of the else block.
                else_line = if_node.orelse[0].lineno - 1
                if 1 <= else_line <= len(lines) and lines[else_line - 1].strip() == "else:":
                    if else_line not in executable_lines:
                        executable_lines.append(else_line)

        return sorted(executable_lines)

    def run_tests(self) -> float:
        """Run every test case under branch-aware coverage and return a percentage.

        If the function's line range is known, the result is the share of
        coverage.py's executable lines in that range that were executed.
        Otherwise it is the total from ``Coverage.report()``.
        """
        import os
        import traceback

        self.cov = coverage.Coverage(branch=True)
        self.cov.start()
        try:
            for test_case in self.test_cases:
                try:
                    # Module is reloaded for each test case, as before.
                    module = testgen.util.file_utils.load_module(self.file_name)
                    func = getattr(module, self.fut.name)
                    func(*test_case.inputs)
                except Exception:
                    # Best effort: a failing test case must not abort the run.
                    print(f"[ERROR]: {traceback.format_exc()}")
        finally:
            # Always stop tracing, even if a BaseException escapes a test case.
            self.cov.stop()

        file_path = os.path.abspath(self.file_name)
        # Kept from the original: gather collected data before analysis.
        self.cov.get_data()

        function_range = self._get_function_line_range()
        if function_range:
            start_line, end_line = function_range
            # analysis2 -> (filename, statements, excluded, missing, missing_str)
            analysis = self.cov.analysis2(file_path)

            if len(analysis) >= 4:
                executable_in_func = [line for line in analysis[1] if start_line <= line <= end_line]
                missed_in_func = [line for line in analysis[3] if start_line <= line <= end_line]

                if executable_in_func:
                    return (len(executable_in_func) - len(missed_in_func)) / len(executable_in_func) * 100

        # Fall back to whole-report coverage.
        fake_file = io.StringIO()
        total_coverage = self.cov.report(file=fake_file)
        self.cov.save()
        return total_coverage

    def _get_function_line_range(self):
        """Return ``(first_line, last_line)`` of the function under test, or None.

        Uses the first ``def`` named ``self.fut.name`` yielded by ``ast.walk``.
        Read/parse errors are printed and yield None (best effort: run_tests
        then falls back to whole-report coverage).
        """
        try:
            tree = ast.parse(self._read_source())

            for node in ast.walk(tree):
                if isinstance(node, ast.FunctionDef) and node.name == self.fut.name:
                    start_line = node.lineno
                    # end_lineno (3.8+) covers the whole body, including
                    # continuation lines of a multi-line final statement.
                    end_line = getattr(node, "end_lineno", None)
                    if end_line is None:
                        end_line = max(
                            (getattr(child, "lineno", start_line) for child in ast.walk(node)),
                            default=start_line,
                        )
                    return (start_line, end_line)
        except Exception as e:
            print(f"Error getting function range: {e}")

        return None
@@ -0,0 +1,33 @@
1
+ from typing import Tuple
2
+
3
+ import testgen.util.coverage_utils
4
+ from testgen.reinforcement.abstract_state import AbstractState
5
+ from testgen.util import utils
6
+
7
class StatementCoverageState(AbstractState):
    """State = (statement coverage % of the function under test, number of test cases)."""

    def __init__(self, environment):
        # The environment must expose test_cases, file_name, fut and
        # get_all_executable_statements() (presumably a ReinforcementEnvironment).
        self.environment = environment

    def get_state(self) -> Tuple[float, int]:
        """Return ``(coverage_percentage, len(test_cases))`` for the current suite.

        Coverage is the number of distinct executable lines hit by any test
        case, divided by the number of executable lines, times 100. It is
        0.0 when there are no executable lines.
        """
        env = self.environment

        all_covered_statements = set()
        for test_case in env.test_cases:
            analysis = testgen.util.coverage_utils.get_coverage_analysis(env.file_name, env.fut.name, test_case.inputs)
            covered = testgen.util.coverage_utils.get_list_of_covered_statements(analysis)
            all_covered_statements.update(covered)

        # Computed once: this re-runs coverage and re-parses the source file.
        executable_statements = env.get_all_executable_statements()

        if not executable_statements:
            calc_coverage = 0.0
        else:
            # Count only covered lines that are executable, so the result
            # cannot exceed 100%.
            covered_executable = all_covered_statements.intersection(executable_statements)
            calc_coverage: float = (len(covered_executable) / len(executable_statements)) * 100

        print(f"GET STATE ALL COVERED STATEMENTS: {all_covered_statements}")
        print(f"GET STATE ALL EXECUTABLE STATEMENTS: {executable_statements}")
        print(f"GET STATE FLOAT COVERAGE: {calc_coverage}")

        if calc_coverage >= 100:
            print(f"!!!!!!!!FULLY COVERED FUNCTION: {env.fut.name}!!!!!!!!")
        return calc_coverage, len(env.test_cases)
33
+
File without changes