testgenie-py 0.3.6__py3-none-any.whl → 0.3.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- testgen/analyzer/ast_analyzer.py +2 -11
- testgen/analyzer/fuzz_analyzer.py +1 -6
- testgen/analyzer/random_feedback_analyzer.py +20 -293
- testgen/analyzer/reinforcement_analyzer.py +59 -57
- testgen/analyzer/test_case_analyzer_context.py +0 -6
- testgen/controller/cli_controller.py +35 -29
- testgen/controller/docker_controller.py +3 -2
- testgen/db/dao.py +68 -0
- testgen/db/dao_impl.py +226 -0
- testgen/{sqlite → db}/db.py +15 -6
- testgen/generator/pytest_generator.py +2 -10
- testgen/generator/unit_test_generator.py +2 -11
- testgen/main.py +1 -3
- testgen/models/coverage_data.py +56 -0
- testgen/models/db_test_case.py +65 -0
- testgen/models/function.py +56 -0
- testgen/models/function_metadata.py +11 -1
- testgen/models/generator_context.py +32 -2
- testgen/models/source_file.py +29 -0
- testgen/models/test_result.py +38 -0
- testgen/models/test_suite.py +20 -0
- testgen/reinforcement/agent.py +1 -27
- testgen/reinforcement/environment.py +11 -93
- testgen/reinforcement/statement_coverage_state.py +5 -4
- testgen/service/analysis_service.py +31 -22
- testgen/service/cfg_service.py +3 -1
- testgen/service/coverage_service.py +115 -0
- testgen/service/db_service.py +140 -0
- testgen/service/generator_service.py +77 -20
- testgen/service/logging_service.py +2 -2
- testgen/service/service.py +62 -231
- testgen/service/test_executor_service.py +145 -0
- testgen/util/coverage_utils.py +38 -116
- testgen/util/coverage_visualizer.py +10 -9
- testgen/util/file_utils.py +10 -111
- testgen/util/randomizer.py +0 -26
- testgen/util/utils.py +197 -38
- {testgenie_py-0.3.6.dist-info → testgenie_py-0.3.8.dist-info}/METADATA +1 -1
- testgenie_py-0.3.8.dist-info/RECORD +72 -0
- testgen/inspector/inspector.py +0 -59
- testgen/presentation/__init__.py +0 -0
- testgen/presentation/cli_view.py +0 -12
- testgen/sqlite/__init__.py +0 -0
- testgen/sqlite/db_service.py +0 -239
- testgen/testgen.db +0 -0
- testgenie_py-0.3.6.dist-info/RECORD +0 -67
- /testgen/{inspector → db}/__init__.py +0 -0
- {testgenie_py-0.3.6.dist-info → testgenie_py-0.3.8.dist-info}/WHEEL +0 -0
- {testgenie_py-0.3.6.dist-info → testgenie_py-0.3.8.dist-info}/entry_points.txt +0 -0
@@ -1,5 +1,4 @@
|
|
1
1
|
import argparse
|
2
|
-
import inspect
|
3
2
|
import os
|
4
3
|
import sys
|
5
4
|
|
@@ -10,8 +9,7 @@ from testgen.service.logging_service import LoggingService, get_logger
|
|
10
9
|
from testgen.util.file_utils import adjust_file_path_for_docker, get_project_root_in_docker
|
11
10
|
from testgen.controller.docker_controller import DockerController
|
12
11
|
from testgen.service.service import Service
|
13
|
-
from testgen.
|
14
|
-
from testgen.sqlite.db_service import DBService
|
12
|
+
from testgen.service.db_service import DBService
|
15
13
|
|
16
14
|
AST_STRAT = 1
|
17
15
|
FUZZ_STRAT = 2
|
@@ -24,9 +22,9 @@ DOCTEST_FORMAT = 3
|
|
24
22
|
|
25
23
|
class CLIController:
|
26
24
|
#TODO: Possibly create a view 'interface' and use dependency injection to extend other views
|
27
|
-
def __init__(self, service: Service
|
25
|
+
def __init__(self, service: Service):
|
28
26
|
self.service = service
|
29
|
-
self.
|
27
|
+
self.logger = None
|
30
28
|
|
31
29
|
def run(self):
|
32
30
|
|
@@ -35,12 +33,16 @@ class CLIController:
|
|
35
33
|
args = parser.parse_args()
|
36
34
|
|
37
35
|
LoggingService.get_instance().initialize(
|
38
|
-
debug_mode=args.debug if
|
36
|
+
debug_mode=args.debug if args.debug else False,
|
39
37
|
log_file=args.log_file if hasattr(args, 'log_file') else None,
|
40
38
|
console_output=True
|
41
39
|
)
|
42
40
|
|
43
|
-
logger = get_logger()
|
41
|
+
self.logger = get_logger()
|
42
|
+
|
43
|
+
if args.functions:
|
44
|
+
self.service.get_all_functions(args.file_path)
|
45
|
+
return
|
44
46
|
|
45
47
|
if args.query:
|
46
48
|
print(f"Querying database for file: {args.file_path}")
|
@@ -48,7 +50,7 @@ class CLIController:
|
|
48
50
|
return
|
49
51
|
|
50
52
|
if args.coverage:
|
51
|
-
self.service.
|
53
|
+
self.service.run_coverage(args.file_path)
|
52
54
|
return
|
53
55
|
|
54
56
|
running_in_docker = os.environ.get("RUNNING_IN_DOCKER") is not None
|
@@ -59,7 +61,7 @@ class CLIController:
|
|
59
61
|
client = self.docker_available()
|
60
62
|
# Skip Docker-dependent operations if client is None
|
61
63
|
if client is None and args.safe:
|
62
|
-
self.
|
64
|
+
self.logger.debug("Running with --safe flag requires Docker. Continuing without safe mode.")
|
63
65
|
args.safe = False
|
64
66
|
self.execute_generation(args)
|
65
67
|
else:
|
@@ -69,15 +71,13 @@ class CLIController:
|
|
69
71
|
if not successful:
|
70
72
|
if hasattr(args, 'db') and args.db:
|
71
73
|
self.service.db_service = DBService(args.db)
|
72
|
-
self.
|
74
|
+
self.logger.debug(f"Using database: {args.db}")
|
73
75
|
self.execute_generation(args)
|
74
|
-
# Else successful, do nothing - we're done
|
75
76
|
else:
|
76
|
-
# Initialize database service with specified path
|
77
77
|
if hasattr(args, 'db') and args.db:
|
78
78
|
self.service.db_service = DBService(args.db)
|
79
|
-
self.
|
80
|
-
self.
|
79
|
+
self.logger.debug(f"Using database: {args.db}")
|
80
|
+
self.logger.debug("Running in local mode...")
|
81
81
|
self.execute_generation(args)
|
82
82
|
|
83
83
|
def execute_generation(self, args: argparse.Namespace, running_in_docker: bool = False):
|
@@ -85,22 +85,23 @@ class CLIController:
|
|
85
85
|
self.set_service_args(args)
|
86
86
|
|
87
87
|
if running_in_docker:
|
88
|
-
self.
|
88
|
+
self.logger.debug("Running in Docker mode...")
|
89
89
|
self.service.generate_test_cases()
|
90
90
|
|
91
91
|
else:
|
92
92
|
test_file = self.service.generate_tests(args.output)
|
93
|
-
self.
|
94
|
-
|
93
|
+
self.logger.debug(f"Unit tests saved to: {test_file}")
|
94
|
+
print("Executing tests...")
|
95
|
+
self.service.run_tests(test_file)
|
96
|
+
print("Running coverage...")
|
95
97
|
self.service.run_coverage(test_file)
|
96
|
-
self.
|
98
|
+
self.logger.debug("Tests and coverage data saved to database.")
|
97
99
|
|
98
100
|
if args.visualize:
|
99
101
|
self.service.visualize_test_coverage()
|
100
102
|
|
101
103
|
except Exception as e:
|
102
|
-
self.
|
103
|
-
# Make sure to close the DB connection on error
|
104
|
+
self.logger.error(f"An error occurred: {e}")
|
104
105
|
if hasattr(self.service, 'db_service'):
|
105
106
|
self.service.db_service.close()
|
106
107
|
|
@@ -168,6 +169,11 @@ class CLIController:
|
|
168
169
|
action="store_true",
|
169
170
|
help="Run coverage analysis on the generated tests"
|
170
171
|
)
|
172
|
+
parser.add_argument(
|
173
|
+
"-f", "--functions",
|
174
|
+
action="store_true",
|
175
|
+
help="List all functions in file"
|
176
|
+
)
|
171
177
|
return parser
|
172
178
|
|
173
179
|
def set_test_format(self, args: argparse.Namespace):
|
@@ -180,32 +186,32 @@ class CLIController:
|
|
180
186
|
|
181
187
|
def set_test_strategy(self, args: argparse.Namespace):
|
182
188
|
if args.test_mode == "random":
|
183
|
-
|
189
|
+
print("Using Random Feedback-Directed Test Generation Strategy.")
|
184
190
|
self.service.set_test_analysis_strategy(RANDOM_STRAT)
|
185
191
|
elif args.test_mode == "fuzz":
|
186
|
-
|
192
|
+
print("Using Fuzz Test Generation Strategy...")
|
187
193
|
self.service.set_test_analysis_strategy(FUZZ_STRAT)
|
188
194
|
elif args.test_mode == "reinforce":
|
189
|
-
|
195
|
+
print("Using Reinforcement Learning Test Generation Strategy...")
|
190
196
|
if args.reinforce_mode == "train":
|
191
|
-
|
197
|
+
print("Training mode enabled - will update Q-table")
|
192
198
|
else:
|
193
|
-
|
199
|
+
print("Training mode disabled - will use existing Q-table")
|
194
200
|
self.service.set_test_analysis_strategy(REINFORCE_STRAT)
|
195
201
|
self.service.set_reinforcement_mode(args.reinforce_mode)
|
196
202
|
else:
|
197
|
-
|
203
|
+
print("Generating function code using AST analysis...")
|
198
204
|
generated_file_path = self.service.generate_function_code()
|
199
|
-
|
205
|
+
print(f"Generated code saved to: {generated_file_path}")
|
200
206
|
if not args.generate_only:
|
201
|
-
|
207
|
+
print("Using Simple AST Traversal Test Generation Strategy...")
|
202
208
|
self.service.set_test_analysis_strategy(AST_STRAT)
|
203
209
|
|
204
210
|
def docker_available(self) -> DockerClient | None:
|
205
211
|
try:
|
206
212
|
client = docker.from_env()
|
207
213
|
client.ping()
|
208
|
-
|
214
|
+
print("Docker daemon is running and connected.")
|
209
215
|
return client
|
210
216
|
except docker.errors.DockerException as err:
|
211
217
|
print(f"Docker is not available: {err}")
|
@@ -59,10 +59,10 @@ class DockerController:
|
|
59
59
|
self.debug(f"project_root: {project_root}")
|
60
60
|
container = self.run_container(docker_client, image_name, docker_args, project_root)
|
61
61
|
|
62
|
-
self.clean_up(dest_path)
|
63
|
-
|
64
62
|
# Stream the logs to the console
|
65
63
|
logs_output = self.get_logs(container)
|
64
|
+
|
65
|
+
self.clean_up(dest_path)
|
66
66
|
self.debug(logs_output)
|
67
67
|
|
68
68
|
except Exception as e:
|
@@ -98,6 +98,7 @@ class DockerController:
|
|
98
98
|
|
99
99
|
if not args.generate_only:
|
100
100
|
print("Running coverage...")
|
101
|
+
self.service.run_tests(test_file)
|
101
102
|
self.service.run_coverage(test_file)
|
102
103
|
|
103
104
|
# Add explicit return True here
|
testgen/db/dao.py
ADDED
@@ -0,0 +1,68 @@
|
|
1
|
+
from abc import ABC, abstractmethod
|
2
|
+
from typing import List, Tuple, Any
|
3
|
+
|
4
|
+
from testgen.models.function import Function
|
5
|
+
|
6
|
+
|
7
|
+
class Dao(ABC):
|
8
|
+
@abstractmethod
|
9
|
+
def insert_test_suite(self, name: str) -> int:
|
10
|
+
pass
|
11
|
+
|
12
|
+
@abstractmethod
|
13
|
+
def insert_source_file(self, path: str, lines_of_code: int, last_modified) -> int:
|
14
|
+
pass
|
15
|
+
|
16
|
+
@abstractmethod
|
17
|
+
def insert_function(self, name: str, params, start_line: int, end_line: int, source_file_id: int) -> int:
|
18
|
+
pass
|
19
|
+
|
20
|
+
@abstractmethod
|
21
|
+
def insert_test_case(self, test_case: Any, test_suite_id: int, function_id: int, test_method_type: int) -> int:
|
22
|
+
pass
|
23
|
+
|
24
|
+
@abstractmethod
|
25
|
+
def insert_test_result(self, test_case_id: int, status: bool, error: str = None) -> int:
|
26
|
+
pass
|
27
|
+
|
28
|
+
@abstractmethod
|
29
|
+
def insert_coverage_data(self, file_name: str, executed_lines: int, missed_lines: int,
|
30
|
+
branch_coverage: float, source_file_id: int, function_id: int = None) -> int:
|
31
|
+
pass
|
32
|
+
|
33
|
+
@abstractmethod
|
34
|
+
def get_test_suites(self) -> List[Any]:
|
35
|
+
pass
|
36
|
+
|
37
|
+
@abstractmethod
|
38
|
+
def get_test_cases_by_function(self, function_name: str) -> List[Any]:
|
39
|
+
pass
|
40
|
+
|
41
|
+
@abstractmethod
|
42
|
+
def get_source_file_id_by_path(self, filepath: str) -> int:
|
43
|
+
pass
|
44
|
+
|
45
|
+
@abstractmethod
|
46
|
+
def get_coverage_by_file(self, file_path: str) -> List[Any]:
|
47
|
+
pass
|
48
|
+
|
49
|
+
@abstractmethod
|
50
|
+
def get_test_file_data(self, file_path: str) -> List[Any]:
|
51
|
+
pass
|
52
|
+
|
53
|
+
@abstractmethod
|
54
|
+
def get_function_by_name_file_id_start(self, name: str, source_file_id: int, start_line: int)-> int:
|
55
|
+
pass
|
56
|
+
|
57
|
+
@abstractmethod
|
58
|
+
def get_functions_by_file(self, filepath: str) -> List[Function]:
|
59
|
+
pass
|
60
|
+
|
61
|
+
@abstractmethod
|
62
|
+
def get_test_suite_id_by_name(self, name: str) -> int:
|
63
|
+
pass
|
64
|
+
|
65
|
+
@abstractmethod
|
66
|
+
def get_test_case_id_by_func_id_input_expected(self, function_id: int, inputs: str, expected: str) -> int:
|
67
|
+
pass
|
68
|
+
|
testgen/db/dao_impl.py
ADDED
@@ -0,0 +1,226 @@
|
|
1
|
+
import os
|
2
|
+
import sqlite3
|
3
|
+
import json
|
4
|
+
from typing import List, Tuple, Any
|
5
|
+
from datetime import datetime
|
6
|
+
|
7
|
+
from testgen.db.dao import Dao
|
8
|
+
from testgen.models.function import Function
|
9
|
+
from testgen.models.test_case import TestCase
|
10
|
+
from testgen.db.db import create_database
|
11
|
+
|
12
|
+
class DaoImpl(Dao):
|
13
|
+
def __init__(self, db_name="testgen.db"):
|
14
|
+
self.db_name = db_name
|
15
|
+
self.conn = None
|
16
|
+
self.cursor = None
|
17
|
+
self._connect()
|
18
|
+
|
19
|
+
def _connect(self):
|
20
|
+
if not os.path.exists(self.db_name):
|
21
|
+
create_database(self.db_name)
|
22
|
+
self.conn = sqlite3.connect(self.db_name)
|
23
|
+
self.conn.row_factory = sqlite3.Row
|
24
|
+
self.cursor = self.conn.cursor()
|
25
|
+
self.cursor.execute("PRAGMA foreign_keys = ON;")
|
26
|
+
|
27
|
+
def close(self):
|
28
|
+
if self.conn:
|
29
|
+
self.conn.close()
|
30
|
+
self.conn = None
|
31
|
+
self.cursor = None
|
32
|
+
|
33
|
+
def insert_test_suite(self, name: str) -> int:
|
34
|
+
self.cursor.execute(
|
35
|
+
"INSERT INTO TestSuite (name, creation_date) VALUES (?, ?)",
|
36
|
+
(name, datetime.now())
|
37
|
+
)
|
38
|
+
self.conn.commit()
|
39
|
+
return self.cursor.lastrowid
|
40
|
+
|
41
|
+
def insert_source_file(self, path: str, lines_of_code: int, last_modified) -> int:
|
42
|
+
self.cursor.execute("SELECT id FROM SourceFile WHERE path = ?", (path,))
|
43
|
+
existing = self.cursor.fetchone()
|
44
|
+
if existing:
|
45
|
+
return existing[0]
|
46
|
+
self.cursor.execute(
|
47
|
+
"INSERT INTO SourceFile (path, lines_of_code, last_modified) VALUES (?, ?, ?)",
|
48
|
+
(path, lines_of_code, last_modified)
|
49
|
+
)
|
50
|
+
self.conn.commit()
|
51
|
+
return self.cursor.lastrowid
|
52
|
+
|
53
|
+
def insert_function(self, name: str, params, start_line: int, end_line: int, source_file_id: int) -> int:
|
54
|
+
print(f"INSERTING FUNCTION: {name}, {params}, {start_line}, {end_line}, {source_file_id}")
|
55
|
+
|
56
|
+
num_lines = end_line - start_line + 1
|
57
|
+
self.cursor.execute(
|
58
|
+
"SELECT id FROM Function WHERE name = ? AND source_file_id = ? AND params = ?",
|
59
|
+
(name, source_file_id, params)
|
60
|
+
)
|
61
|
+
existing = self.cursor.fetchone()
|
62
|
+
if existing:
|
63
|
+
return existing[0]
|
64
|
+
self.cursor.execute(
|
65
|
+
"INSERT INTO Function (name, params, start_line, end_line, num_lines, source_file_id) VALUES (?, ?, ?, ?, ?, ?)",
|
66
|
+
(name, params, start_line, end_line, num_lines, source_file_id)
|
67
|
+
)
|
68
|
+
self.conn.commit()
|
69
|
+
return self.cursor.lastrowid
|
70
|
+
|
71
|
+
def insert_test_case(self, test_case: TestCase, test_suite_id: int, function_id: int, test_method_type: int) -> int:
|
72
|
+
inputs_json = json.dumps(test_case.inputs, sort_keys=True)
|
73
|
+
expected_json = json.dumps(test_case.expected, sort_keys=True)
|
74
|
+
|
75
|
+
print(f"INSERTING TEST CASE: {test_case.func_name}, {test_case.inputs}, {test_case.expected}, {test_suite_id}, {function_id}, {test_method_type}")
|
76
|
+
|
77
|
+
# Check for existing test case with same function_id, inputs, and expected_output
|
78
|
+
self.cursor.execute(
|
79
|
+
"SELECT id FROM TestCase WHERE function_id = ? AND input = ? AND expected_output = ?",
|
80
|
+
(function_id, inputs_json, expected_json)
|
81
|
+
)
|
82
|
+
existing = self.cursor.fetchone()
|
83
|
+
if existing:
|
84
|
+
return existing[0]
|
85
|
+
|
86
|
+
self.cursor.execute(
|
87
|
+
"INSERT INTO TestCase (expected_output, input, test_function, last_run_time, test_method_type, test_suite_id, function_id) "
|
88
|
+
"VALUES (?, ?, ?, ?, ?, ?, ?)",
|
89
|
+
(
|
90
|
+
expected_json,
|
91
|
+
inputs_json,
|
92
|
+
test_case.func_name,
|
93
|
+
datetime.now(),
|
94
|
+
test_method_type,
|
95
|
+
test_suite_id,
|
96
|
+
function_id
|
97
|
+
)
|
98
|
+
)
|
99
|
+
self.conn.commit()
|
100
|
+
return self.cursor.lastrowid
|
101
|
+
|
102
|
+
def insert_test_result(self, test_case_id: int, status: bool, error: str = None) -> int:
|
103
|
+
self.cursor.execute(
|
104
|
+
"INSERT INTO TestResult (test_case_id, status, error, execution_time) VALUES (?, ?, ?, ?)",
|
105
|
+
(test_case_id, status, error, datetime.now())
|
106
|
+
)
|
107
|
+
self.conn.commit()
|
108
|
+
return self.cursor.lastrowid
|
109
|
+
|
110
|
+
def insert_coverage_data(self, file_name: str, executed_lines: int, missed_lines: int,
|
111
|
+
branch_coverage: float, source_file_id: int, function_id: int | None) -> int:
|
112
|
+
coverage_type = "file" if function_id is None else "function"
|
113
|
+
self.cursor.execute(
|
114
|
+
"INSERT INTO CoverageData (coverage_type, executed_lines, missed_lines, branch_coverage, source_file_id, function_id) "
|
115
|
+
"VALUES (?, ?, ?, ?, ?, ?)",
|
116
|
+
(coverage_type, executed_lines, missed_lines, branch_coverage, source_file_id, function_id)
|
117
|
+
)
|
118
|
+
self.conn.commit()
|
119
|
+
return self.cursor.lastrowid
|
120
|
+
|
121
|
+
def get_test_suites(self) -> List[Any]:
|
122
|
+
self.cursor.execute("SELECT * FROM TestSuite ORDER BY creation_date DESC")
|
123
|
+
return self.cursor.fetchall()
|
124
|
+
|
125
|
+
def get_test_cases_by_function(self, function_name: str) -> List[Any]:
|
126
|
+
self.cursor.execute(
|
127
|
+
"SELECT tc.* FROM TestCase tc JOIN Function f ON tc.function_id = f.id WHERE f.name = ?",
|
128
|
+
(function_name,)
|
129
|
+
)
|
130
|
+
return self.cursor.fetchall()
|
131
|
+
|
132
|
+
def get_coverage_by_file(self, file_path: str) -> List[Any]:
|
133
|
+
self.cursor.execute(
|
134
|
+
"SELECT cd.* FROM CoverageData cd JOIN SourceFile sf ON cd.source_file_id = sf.id WHERE sf.path = ?",
|
135
|
+
(file_path,)
|
136
|
+
)
|
137
|
+
return self.cursor.fetchall()
|
138
|
+
|
139
|
+
def get_test_file_data(self, file_path: str) -> List[Any]:
|
140
|
+
query = """
|
141
|
+
SELECT
|
142
|
+
tc.id AS test_case_id,
|
143
|
+
tc.name AS test_case_name,
|
144
|
+
tc.test_function AS test_case_test_function,
|
145
|
+
tc.test_method_type AS test_case_method_type,
|
146
|
+
COALESCE(cd.missed_lines, 'None') AS coverage_data_missed_lines
|
147
|
+
FROM SourceFile sf
|
148
|
+
LEFT JOIN Function f ON sf.id = f.source_file_id
|
149
|
+
LEFT JOIN TestCase tc ON f.id = tc.function_id
|
150
|
+
LEFT JOIN TestResult tr ON tc.id = tr.test_case_id
|
151
|
+
LEFT JOIN CoverageData cd ON sf.id = cd.source_file_id
|
152
|
+
WHERE sf.path = ?;
|
153
|
+
"""
|
154
|
+
self.cursor.execute(query, (file_path,))
|
155
|
+
return self.cursor.fetchall()
|
156
|
+
|
157
|
+
def get_function_by_name_file_id_start(self, name: str, source_file_id: int, start_line: int) -> int:
|
158
|
+
print(f"GETTING FUNCTION ID FOR {name}, {source_file_id}, {start_line}")
|
159
|
+
self.cursor.execute(
|
160
|
+
"""
|
161
|
+
SELECT id FROM Function
|
162
|
+
WHERE name = ? AND source_file_id = ? AND start_line = ?
|
163
|
+
""",
|
164
|
+
(name, source_file_id, start_line)
|
165
|
+
)
|
166
|
+
result = self.cursor.fetchone()
|
167
|
+
return result["id"] if result else -1
|
168
|
+
|
169
|
+
def get_functions_by_file(self, filepath: str) -> List[Function]:
|
170
|
+
self.cursor.execute(
|
171
|
+
"""
|
172
|
+
SELECT id, name, params, start_line, end_line, num_lines, source_file_id
|
173
|
+
FROM Function
|
174
|
+
WHERE source_file_id = (
|
175
|
+
SELECT id FROM SourceFile WHERE path = ?
|
176
|
+
)
|
177
|
+
""",
|
178
|
+
(filepath,)
|
179
|
+
)
|
180
|
+
rows = self.cursor.fetchall()
|
181
|
+
return [
|
182
|
+
Function(
|
183
|
+
name=row["name"],
|
184
|
+
params=row["params"],
|
185
|
+
start_line=row["start_line"],
|
186
|
+
end_line=row["end_line"],
|
187
|
+
num_lines=row["num_lines"],
|
188
|
+
source_file_id=row["source_file_id"]
|
189
|
+
)
|
190
|
+
for row in rows
|
191
|
+
]
|
192
|
+
|
193
|
+
def get_test_suite_id_by_name(self, name: str) -> int:
|
194
|
+
self.cursor.execute(
|
195
|
+
"""
|
196
|
+
SELECT id FROM TestSuite
|
197
|
+
WHERE name = ?
|
198
|
+
""",
|
199
|
+
(name,)
|
200
|
+
)
|
201
|
+
result = self.cursor.fetchone()
|
202
|
+
return result["id"] if result else -1
|
203
|
+
|
204
|
+
def get_source_file_id_by_path(self, filepath: str) -> int:
|
205
|
+
self.cursor.execute(
|
206
|
+
"""
|
207
|
+
SELECT id FROM SourceFile
|
208
|
+
WHERE path = ?
|
209
|
+
""",
|
210
|
+
(filepath,)
|
211
|
+
)
|
212
|
+
result = self.cursor.fetchone()
|
213
|
+
return result["id"] if result else -1
|
214
|
+
|
215
|
+
def get_test_case_id_by_func_id_input_expected(self, function_id: int, inputs: str, expected: str) -> int:
|
216
|
+
self.cursor.execute(
|
217
|
+
"""
|
218
|
+
SELECT id FROM TestCase
|
219
|
+
WHERE function_id = ?
|
220
|
+
AND input = ?
|
221
|
+
AND expected_output = ?
|
222
|
+
""",
|
223
|
+
(function_id, inputs, expected)
|
224
|
+
)
|
225
|
+
result = self.cursor.fetchone()
|
226
|
+
return result["id"] if result else -1
|
testgen/{sqlite → db}/db.py
RENAMED
@@ -29,10 +29,12 @@ def create_database(db_name="testgen.db"):
|
|
29
29
|
CREATE TABLE IF NOT EXISTS Function (
|
30
30
|
id INTEGER PRIMARY KEY,
|
31
31
|
name TEXT,
|
32
|
+
params TEXT,
|
32
33
|
start_line INTEGER,
|
33
34
|
end_line INTEGER,
|
34
35
|
num_lines INTEGER,
|
35
36
|
source_file_id INTEGER,
|
37
|
+
UNIQUE(name, params, source_file_id),
|
36
38
|
FOREIGN KEY (source_file_id) REFERENCES SourceFile(id)
|
37
39
|
);
|
38
40
|
""")
|
@@ -40,16 +42,16 @@ def create_database(db_name="testgen.db"):
|
|
40
42
|
cursor.execute("""
|
41
43
|
CREATE TABLE IF NOT EXISTS TestCase (
|
42
44
|
id INTEGER PRIMARY KEY,
|
43
|
-
|
44
|
-
|
45
|
-
input TEXT, -- storing JSON as TEXT
|
45
|
+
expected_output TEXT,
|
46
|
+
input TEXT,
|
46
47
|
test_function TEXT,
|
47
48
|
last_run_time TIMESTAMP,
|
48
49
|
test_method_type INTEGER,
|
49
50
|
test_suite_id INTEGER,
|
50
51
|
function_id INTEGER,
|
51
52
|
FOREIGN KEY (test_suite_id) REFERENCES TestSuite(id),
|
52
|
-
FOREIGN KEY (function_id) REFERENCES Function(id)
|
53
|
+
FOREIGN KEY (function_id) REFERENCES Function(id),
|
54
|
+
UNIQUE(function_id, input, expected_output)
|
53
55
|
);
|
54
56
|
""")
|
55
57
|
|
@@ -67,12 +69,19 @@ def create_database(db_name="testgen.db"):
|
|
67
69
|
cursor.execute("""
|
68
70
|
CREATE TABLE IF NOT EXISTS CoverageData (
|
69
71
|
id INTEGER PRIMARY KEY,
|
70
|
-
|
72
|
+
coverage_type TEXT CHECK(coverage_type IN ('file', 'function')),
|
71
73
|
executed_lines INTEGER,
|
72
74
|
missed_lines INTEGER,
|
73
75
|
branch_coverage REAL,
|
74
76
|
source_file_id INTEGER,
|
75
|
-
|
77
|
+
function_id INTEGER,
|
78
|
+
FOREIGN KEY (source_file_id) REFERENCES SourceFile(id),
|
79
|
+
FOREIGN KEY (function_id) REFERENCES Function(id),
|
80
|
+
CHECK (
|
81
|
+
-- Only one of source_file_id or function_id is required based on coverage_type
|
82
|
+
(coverage_type = 'file' AND source_file_id IS NOT NULL AND function_id IS NULL) OR
|
83
|
+
(coverage_type = 'function' AND function_id IS NOT NULL)
|
84
|
+
)
|
76
85
|
);
|
77
86
|
""")
|
78
87
|
|
@@ -1,5 +1,4 @@
|
|
1
1
|
from testgen.generator.test_generator import TestGenerator
|
2
|
-
from testgen.util.file_utils import get_import_info
|
3
2
|
from testgen.models.generator_context import GeneratorContext
|
4
3
|
|
5
4
|
|
@@ -10,17 +9,10 @@ class PyTestGenerator(TestGenerator):
|
|
10
9
|
|
11
10
|
def generate_test_header(self):
|
12
11
|
self.test_code.append("import pytest\n")
|
13
|
-
import_info = get_import_info(self._generator_context.filepath)
|
14
12
|
if self._generator_context.class_name == "" or self._generator_context.class_name is None:
|
15
|
-
|
16
|
-
self.test_code.append(f"import {import_info['import_path']} as {self._generator_context.module.__name__}\n")
|
17
|
-
else:
|
18
|
-
self.test_code.append(f"import {import_info['import_path']} as {self._generator_context.module.__name__}\n")
|
13
|
+
self.test_code.append(f"import {self._generator_context.import_path} as {self._generator_context.module.__name__}\n")
|
19
14
|
else:
|
20
|
-
|
21
|
-
self.test_code.append(f"from {import_info['import_path']} import {self._generator_context.class_name}\n")
|
22
|
-
else:
|
23
|
-
self.test_code.append(f"from {import_info['import_path']} import {self._generator_context.class_name}\n")
|
15
|
+
self.test_code.append(f"from {self._generator_context.import_path} import {self._generator_context.class_name}\n")
|
24
16
|
|
25
17
|
def generate_test_function(self, unique_func_name, func_name, cases):
|
26
18
|
self.test_code.append(f"def test_{unique_func_name}():")
|
@@ -1,4 +1,3 @@
|
|
1
|
-
from testgen.util.file_utils import get_import_info
|
2
1
|
from testgen.generator.test_generator import TestGenerator
|
3
2
|
from testgen.models.generator_context import GeneratorContext
|
4
3
|
|
@@ -13,19 +12,11 @@ class UnitTestGenerator(TestGenerator):
|
|
13
12
|
|
14
13
|
def generate_test_header(self):
|
15
14
|
|
16
|
-
import_info = get_import_info(self._generator_context.filepath)
|
17
|
-
|
18
15
|
self.test_code.append("import unittest\n")
|
19
16
|
if self._generator_context.class_name == "" or self._generator_context.class_name is None:
|
20
|
-
|
21
|
-
self.test_code.append(f"import {import_info['import_path']} as {self._generator_context.module.__name__}\n")
|
22
|
-
else:
|
23
|
-
self.test_code.append(f"import {import_info['import_path']} as {self._generator_context.module.__name__}\n")
|
17
|
+
self.test_code.append(f"import {self._generator_context.import_path} as {self._generator_context.module.__name__}\n")
|
24
18
|
else:
|
25
|
-
|
26
|
-
self.test_code.append(f"from {import_info['import_path']} import {self._generator_context.class_name}\n")
|
27
|
-
else:
|
28
|
-
self.test_code.append(f"from {import_info['import_path']} import {self._generator_context.class_name}\n")
|
19
|
+
self.test_code.append(f"from {self._generator_context.import_path} import {self._generator_context.class_name}\n")
|
29
20
|
self.test_code.append(f"class Test{self._generator_context.class_name}(unittest.TestCase): \n")
|
30
21
|
|
31
22
|
def generate_test_function(self, unique_func_name, func_name, cases):
|
testgen/main.py
CHANGED
@@ -1,12 +1,10 @@
|
|
1
1
|
from testgen.service.service import Service
|
2
2
|
from testgen.controller.cli_controller import CLIController
|
3
3
|
from testgen.generator.unit_test_generator import UnitTestGenerator
|
4
|
-
from testgen.presentation.cli_view import CLIView
|
5
4
|
|
6
5
|
def main():
|
7
6
|
service = Service()
|
8
|
-
|
9
|
-
controller = CLIController(service, view)
|
7
|
+
controller = CLIController(service)
|
10
8
|
controller.run()
|
11
9
|
|
12
10
|
if __name__ == '__main__':
|
@@ -0,0 +1,56 @@
|
|
1
|
+
class CoverageData:
|
2
|
+
def __init__(self, coverage_type: str, executed_lines: int, missed_lines: int, branch_coverage: float, source_file_id: int, function_id: int):
|
3
|
+
self._coverage_type: str = coverage_type
|
4
|
+
self._executed_lines: int = executed_lines
|
5
|
+
self._missed_lines: int = missed_lines
|
6
|
+
self._branch_coverage: float = branch_coverage
|
7
|
+
self._source_file_id: int = source_file_id
|
8
|
+
self._function_id: int = function_id
|
9
|
+
|
10
|
+
@property
|
11
|
+
def coverage_type(self) -> str:
|
12
|
+
return self._coverage_type
|
13
|
+
|
14
|
+
@coverage_type.setter
|
15
|
+
def coverage_type(self, value: str) -> None:
|
16
|
+
self._coverage_type = value
|
17
|
+
|
18
|
+
@property
|
19
|
+
def executed_lines(self) -> int:
|
20
|
+
return self._executed_lines
|
21
|
+
|
22
|
+
@executed_lines.setter
|
23
|
+
def executed_lines(self, value: int) -> None:
|
24
|
+
self._executed_lines = value
|
25
|
+
|
26
|
+
@property
|
27
|
+
def missed_lines(self) -> int:
|
28
|
+
return self._missed_lines
|
29
|
+
|
30
|
+
@missed_lines.setter
|
31
|
+
def missed_lines(self, value: int) -> None:
|
32
|
+
self._missed_lines = value
|
33
|
+
|
34
|
+
@property
|
35
|
+
def branch_coverage(self) -> float:
|
36
|
+
return self._branch_coverage
|
37
|
+
|
38
|
+
@branch_coverage.setter
|
39
|
+
def branch_coverage(self, value: float) -> None:
|
40
|
+
self._branch_coverage = value
|
41
|
+
|
42
|
+
@property
|
43
|
+
def source_file_id(self) -> int:
|
44
|
+
return self._source_file_id
|
45
|
+
|
46
|
+
@source_file_id.setter
|
47
|
+
def source_file_id(self, value: int) -> None:
|
48
|
+
self._source_file_id = value
|
49
|
+
|
50
|
+
@property
|
51
|
+
def function_id(self) -> int:
|
52
|
+
return self._function_id
|
53
|
+
|
54
|
+
@function_id.setter
|
55
|
+
def function_id(self, value: int) -> None:
|
56
|
+
self._function_id = value
|