testgenie-py 0.3.7__py3-none-any.whl → 0.3.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49):
  1. testgen/analyzer/ast_analyzer.py +2 -11
  2. testgen/analyzer/fuzz_analyzer.py +1 -6
  3. testgen/analyzer/random_feedback_analyzer.py +20 -293
  4. testgen/analyzer/reinforcement_analyzer.py +59 -57
  5. testgen/analyzer/test_case_analyzer_context.py +0 -6
  6. testgen/controller/cli_controller.py +35 -29
  7. testgen/controller/docker_controller.py +1 -0
  8. testgen/db/dao.py +68 -0
  9. testgen/db/dao_impl.py +226 -0
  10. testgen/{sqlite → db}/db.py +15 -6
  11. testgen/generator/pytest_generator.py +2 -10
  12. testgen/generator/unit_test_generator.py +2 -11
  13. testgen/main.py +1 -3
  14. testgen/models/coverage_data.py +56 -0
  15. testgen/models/db_test_case.py +65 -0
  16. testgen/models/function.py +56 -0
  17. testgen/models/function_metadata.py +11 -1
  18. testgen/models/generator_context.py +30 -3
  19. testgen/models/source_file.py +29 -0
  20. testgen/models/test_result.py +38 -0
  21. testgen/models/test_suite.py +20 -0
  22. testgen/reinforcement/agent.py +1 -27
  23. testgen/reinforcement/environment.py +11 -93
  24. testgen/reinforcement/statement_coverage_state.py +5 -4
  25. testgen/service/analysis_service.py +31 -22
  26. testgen/service/cfg_service.py +3 -1
  27. testgen/service/coverage_service.py +115 -0
  28. testgen/service/db_service.py +140 -0
  29. testgen/service/generator_service.py +77 -20
  30. testgen/service/logging_service.py +2 -2
  31. testgen/service/service.py +62 -231
  32. testgen/service/test_executor_service.py +145 -0
  33. testgen/util/coverage_utils.py +38 -116
  34. testgen/util/coverage_visualizer.py +10 -9
  35. testgen/util/file_utils.py +10 -111
  36. testgen/util/randomizer.py +0 -26
  37. testgen/util/utils.py +197 -38
  38. {testgenie_py-0.3.7.dist-info → testgenie_py-0.3.9.dist-info}/METADATA +1 -1
  39. testgenie_py-0.3.9.dist-info/RECORD +72 -0
  40. testgen/inspector/inspector.py +0 -59
  41. testgen/presentation/__init__.py +0 -0
  42. testgen/presentation/cli_view.py +0 -12
  43. testgen/sqlite/__init__.py +0 -0
  44. testgen/sqlite/db_service.py +0 -239
  45. testgen/testgen.db +0 -0
  46. testgenie_py-0.3.7.dist-info/RECORD +0 -67
  47. /testgen/{inspector → db}/__init__.py +0 -0
  48. {testgenie_py-0.3.7.dist-info → testgenie_py-0.3.9.dist-info}/WHEEL +0 -0
  49. {testgenie_py-0.3.7.dist-info → testgenie_py-0.3.9.dist-info}/entry_points.txt +0 -0
@@ -1,5 +1,4 @@
1
1
  import argparse
2
- import inspect
3
2
  import os
4
3
  import sys
5
4
 
@@ -10,8 +9,7 @@ from testgen.service.logging_service import LoggingService, get_logger
10
9
  from testgen.util.file_utils import adjust_file_path_for_docker, get_project_root_in_docker
11
10
  from testgen.controller.docker_controller import DockerController
12
11
  from testgen.service.service import Service
13
- from testgen.presentation.cli_view import CLIView
14
- from testgen.sqlite.db_service import DBService
12
+ from testgen.service.db_service import DBService
15
13
 
16
14
  AST_STRAT = 1
17
15
  FUZZ_STRAT = 2
@@ -24,9 +22,9 @@ DOCTEST_FORMAT = 3
24
22
 
25
23
  class CLIController:
26
24
  #TODO: Possibly create a view 'interface' and use dependency injection to extend other views
27
- def __init__(self, service: Service, view: CLIView):
25
+ def __init__(self, service: Service):
28
26
  self.service = service
29
- self.view = view
27
+ self.logger = None
30
28
 
31
29
  def run(self):
32
30
 
@@ -35,12 +33,16 @@ class CLIController:
35
33
  args = parser.parse_args()
36
34
 
37
35
  LoggingService.get_instance().initialize(
38
- debug_mode=args.debug if hasattr(args, 'debug') else False,
36
+ debug_mode=args.debug if args.debug else False,
39
37
  log_file=args.log_file if hasattr(args, 'log_file') else None,
40
38
  console_output=True
41
39
  )
42
40
 
43
- logger = get_logger()
41
+ self.logger = get_logger()
42
+
43
+ if args.functions:
44
+ self.service.get_all_functions(args.file_path)
45
+ return
44
46
 
45
47
  if args.query:
46
48
  print(f"Querying database for file: {args.file_path}")
@@ -48,7 +50,7 @@ class CLIController:
48
50
  return
49
51
 
50
52
  if args.coverage:
51
- self.service.get_coverage(args.file_path)
53
+ self.service.run_coverage(args.file_path)
52
54
  return
53
55
 
54
56
  running_in_docker = os.environ.get("RUNNING_IN_DOCKER") is not None
@@ -59,7 +61,7 @@ class CLIController:
59
61
  client = self.docker_available()
60
62
  # Skip Docker-dependent operations if client is None
61
63
  if client is None and args.safe:
62
- self.view.display_message("Running with --safe flag requires Docker. Continuing without safe mode.")
64
+ self.logger.debug("Running with --safe flag requires Docker. Continuing without safe mode.")
63
65
  args.safe = False
64
66
  self.execute_generation(args)
65
67
  else:
@@ -69,15 +71,13 @@ class CLIController:
69
71
  if not successful:
70
72
  if hasattr(args, 'db') and args.db:
71
73
  self.service.db_service = DBService(args.db)
72
- self.view.display_message(f"Using database: {args.db}")
74
+ self.logger.debug(f"Using database: {args.db}")
73
75
  self.execute_generation(args)
74
- # Else successful, do nothing - we're done
75
76
  else:
76
- # Initialize database service with specified path
77
77
  if hasattr(args, 'db') and args.db:
78
78
  self.service.db_service = DBService(args.db)
79
- self.view.display_message(f"Using database: {args.db}")
80
- self.view.display_message("Running in local mode...")
79
+ self.logger.debug(f"Using database: {args.db}")
80
+ self.logger.debug("Running in local mode...")
81
81
  self.execute_generation(args)
82
82
 
83
83
  def execute_generation(self, args: argparse.Namespace, running_in_docker: bool = False):
@@ -85,22 +85,23 @@ class CLIController:
85
85
  self.set_service_args(args)
86
86
 
87
87
  if running_in_docker:
88
- self.view.display_message("Running in Docker mode...")
88
+ self.logger.debug("Running in Docker mode...")
89
89
  self.service.generate_test_cases()
90
90
 
91
91
  else:
92
92
  test_file = self.service.generate_tests(args.output)
93
- self.view.display_message(f"Unit tests saved to: {test_file}")
94
- self.view.display_message("Running coverage...")
93
+ self.logger.debug(f"Unit tests saved to: {test_file}")
94
+ print("Executing tests...")
95
+ self.service.run_tests(test_file)
96
+ print("Running coverage...")
95
97
  self.service.run_coverage(test_file)
96
- self.view.display_message("Tests and coverage data saved to database.")
98
+ self.logger.debug("Tests and coverage data saved to database.")
97
99
 
98
100
  if args.visualize:
99
101
  self.service.visualize_test_coverage()
100
102
 
101
103
  except Exception as e:
102
- self.view.display_error(f"An error occurred: {e}")
103
- # Make sure to close the DB connection on error
104
+ self.logger.error(f"An error occurred: {e}")
104
105
  if hasattr(self.service, 'db_service'):
105
106
  self.service.db_service.close()
106
107
 
@@ -168,6 +169,11 @@ class CLIController:
168
169
  action="store_true",
169
170
  help="Run coverage analysis on the generated tests"
170
171
  )
172
+ parser.add_argument(
173
+ "-f", "--functions",
174
+ action="store_true",
175
+ help="List all functions in file"
176
+ )
171
177
  return parser
172
178
 
173
179
  def set_test_format(self, args: argparse.Namespace):
@@ -180,32 +186,32 @@ class CLIController:
180
186
 
181
187
  def set_test_strategy(self, args: argparse.Namespace):
182
188
  if args.test_mode == "random":
183
- self.view.display_message("Using Random Feedback-Directed Test Generation Strategy.")
189
+ print("Using Random Feedback-Directed Test Generation Strategy.")
184
190
  self.service.set_test_analysis_strategy(RANDOM_STRAT)
185
191
  elif args.test_mode == "fuzz":
186
- self.view.display_message("Using Fuzz Test Generation Strategy...")
192
+ print("Using Fuzz Test Generation Strategy...")
187
193
  self.service.set_test_analysis_strategy(FUZZ_STRAT)
188
194
  elif args.test_mode == "reinforce":
189
- self.view.display_message("Using Reinforcement Learning Test Generation Strategy...")
195
+ print("Using Reinforcement Learning Test Generation Strategy...")
190
196
  if args.reinforce_mode == "train":
191
- self.view.display_message("Training mode enabled - will update Q-table")
197
+ print("Training mode enabled - will update Q-table")
192
198
  else:
193
- self.view.display_message("Training mode disabled - will use existing Q-table")
199
+ print("Training mode disabled - will use existing Q-table")
194
200
  self.service.set_test_analysis_strategy(REINFORCE_STRAT)
195
201
  self.service.set_reinforcement_mode(args.reinforce_mode)
196
202
  else:
197
- self.view.display_message("Generating function code using AST analysis...")
203
+ print("Generating function code using AST analysis...")
198
204
  generated_file_path = self.service.generate_function_code()
199
- self.view.display_message(f"Generated code saved to: {generated_file_path}")
205
+ print(f"Generated code saved to: {generated_file_path}")
200
206
  if not args.generate_only:
201
- self.view.display_message("Using Simple AST Traversal Test Generation Strategy...")
207
+ print("Using Simple AST Traversal Test Generation Strategy...")
202
208
  self.service.set_test_analysis_strategy(AST_STRAT)
203
209
 
204
210
  def docker_available(self) -> DockerClient | None:
205
211
  try:
206
212
  client = docker.from_env()
207
213
  client.ping()
208
- self.view.display_message("Docker daemon is running and connected.")
214
+ print("Docker daemon is running and connected.")
209
215
  return client
210
216
  except docker.errors.DockerException as err:
211
217
  print(f"Docker is not available: {err}")
@@ -98,6 +98,7 @@ class DockerController:
98
98
 
99
99
  if not args.generate_only:
100
100
  print("Running coverage...")
101
+ self.service.run_tests(test_file)
101
102
  self.service.run_coverage(test_file)
102
103
 
103
104
  # Add explicit return True here
testgen/db/dao.py ADDED
@@ -0,0 +1,68 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import List, Tuple, Any
3
+
4
+ from testgen.models.function import Function
5
+
6
+
7
+ class Dao(ABC):
8
+ @abstractmethod
9
+ def insert_test_suite(self, name: str) -> int:
10
+ pass
11
+
12
+ @abstractmethod
13
+ def insert_source_file(self, path: str, lines_of_code: int, last_modified) -> int:
14
+ pass
15
+
16
+ @abstractmethod
17
+ def insert_function(self, name: str, params, start_line: int, end_line: int, source_file_id: int) -> int:
18
+ pass
19
+
20
+ @abstractmethod
21
+ def insert_test_case(self, test_case: Any, test_suite_id: int, function_id: int, test_method_type: int) -> int:
22
+ pass
23
+
24
+ @abstractmethod
25
+ def insert_test_result(self, test_case_id: int, status: bool, error: str = None) -> int:
26
+ pass
27
+
28
+ @abstractmethod
29
+ def insert_coverage_data(self, file_name: str, executed_lines: int, missed_lines: int,
30
+ branch_coverage: float, source_file_id: int, function_id: int = None) -> int:
31
+ pass
32
+
33
+ @abstractmethod
34
+ def get_test_suites(self) -> List[Any]:
35
+ pass
36
+
37
+ @abstractmethod
38
+ def get_test_cases_by_function(self, function_name: str) -> List[Any]:
39
+ pass
40
+
41
+ @abstractmethod
42
+ def get_source_file_id_by_path(self, filepath: str) -> int:
43
+ pass
44
+
45
+ @abstractmethod
46
+ def get_coverage_by_file(self, file_path: str) -> List[Any]:
47
+ pass
48
+
49
+ @abstractmethod
50
+ def get_test_file_data(self, file_path: str) -> List[Any]:
51
+ pass
52
+
53
+ @abstractmethod
54
+ def get_function_by_name_file_id_start(self, name: str, source_file_id: int, start_line: int)-> int:
55
+ pass
56
+
57
+ @abstractmethod
58
+ def get_functions_by_file(self, filepath: str) -> List[Function]:
59
+ pass
60
+
61
+ @abstractmethod
62
+ def get_test_suite_id_by_name(self, name: str) -> int:
63
+ pass
64
+
65
+ @abstractmethod
66
+ def get_test_case_id_by_func_id_input_expected(self, function_id: int, inputs: str, expected: str) -> int:
67
+ pass
68
+
testgen/db/dao_impl.py ADDED
@@ -0,0 +1,226 @@
1
+ import os
2
+ import sqlite3
3
+ import json
4
+ from typing import List, Tuple, Any
5
+ from datetime import datetime
6
+
7
+ from testgen.db.dao import Dao
8
+ from testgen.models.function import Function
9
+ from testgen.models.test_case import TestCase
10
+ from testgen.db.db import create_database
11
+
12
class DaoImpl(Dao):
    """SQLite-backed implementation of :class:`Dao`.

    The database file *db_name* is opened on construction. If it does not
    exist yet, it is created first via ``create_database``. Rows come back as
    :class:`sqlite3.Row`, so columns can be read by name, and every insert
    commits immediately.
    """

    def __init__(self, db_name="testgen.db"):
        self.db_name = db_name
        self.conn = None
        self.cursor = None
        self._connect()

    def _connect(self):
        """Open the connection, building the schema first for a new file."""
        if not os.path.exists(self.db_name):
            create_database(self.db_name)
        self.conn = sqlite3.connect(self.db_name)
        self.conn.row_factory = sqlite3.Row
        self.cursor = self.conn.cursor()
        # SQLite only enforces FOREIGN KEY clauses when enabled per connection.
        self.cursor.execute("PRAGMA foreign_keys = ON;")

    def close(self):
        """Close the connection. Calling it again is a no-op."""
        if self.conn:
            self.conn.close()
            self.conn = None
            self.cursor = None

    @staticmethod
    def _now() -> str:
        """Return the current local time as ISO text for a TIMESTAMP column.

        sqlite3's implicit datetime adapter is deprecated since Python 3.12.
        This produces the same ``isoformat(" ")`` text it used to store, so
        old and new rows stay consistent.
        """
        return datetime.now().isoformat(" ")

    def insert_test_suite(self, name: str) -> int:
        """Insert a TestSuite row stamped with the current time; return its id."""
        self.cursor.execute(
            "INSERT INTO TestSuite (name, creation_date) VALUES (?, ?)",
            (name, self._now())
        )
        self.conn.commit()
        return self.cursor.lastrowid

    def insert_source_file(self, path: str, lines_of_code: int, last_modified) -> int:
        """Return the SourceFile id for *path*, inserting the row if absent.

        An existing row is returned as-is; its *lines_of_code* and
        *last_modified* are not updated.
        """
        self.cursor.execute("SELECT id FROM SourceFile WHERE path = ?", (path,))
        existing = self.cursor.fetchone()
        if existing:
            return existing[0]
        self.cursor.execute(
            "INSERT INTO SourceFile (path, lines_of_code, last_modified) VALUES (?, ?, ?)",
            (path, lines_of_code, last_modified)
        )
        self.conn.commit()
        return self.cursor.lastrowid

    def insert_function(self, name: str, params, start_line: int, end_line: int, source_file_id: int) -> int:
        """Return the Function id for (name, params, file), inserting if absent.

        ``num_lines`` is derived from the inclusive line range.
        """
        print(f"INSERTING FUNCTION: {name}, {params}, {start_line}, {end_line}, {source_file_id}")

        num_lines = end_line - start_line + 1
        # Use ``IS`` rather than ``=`` so a NULL ``params`` still matches its
        # existing row. NULLs are distinct under the UNIQUE constraint, so
        # ``=`` would let duplicate functions be inserted.
        self.cursor.execute(
            "SELECT id FROM Function WHERE name = ? AND source_file_id = ? AND params IS ?",
            (name, source_file_id, params)
        )
        existing = self.cursor.fetchone()
        if existing:
            return existing[0]
        self.cursor.execute(
            "INSERT INTO Function (name, params, start_line, end_line, num_lines, source_file_id) VALUES (?, ?, ?, ?, ?, ?)",
            (name, params, start_line, end_line, num_lines, source_file_id)
        )
        self.conn.commit()
        return self.cursor.lastrowid

    def insert_test_case(self, test_case: TestCase, test_suite_id: int, function_id: int, test_method_type: int) -> int:
        """Insert *test_case* unless an identical one exists; return its id.

        Inputs and expected output are stored as key-sorted JSON. Equal
        values therefore serialise identically, which makes the duplicate
        check on (function_id, input, expected_output) reliable.
        """
        inputs_json = json.dumps(test_case.inputs, sort_keys=True)
        expected_json = json.dumps(test_case.expected, sort_keys=True)

        print(f"INSERTING TEST CASE: {test_case.func_name}, {test_case.inputs}, {test_case.expected}, {test_suite_id}, {function_id}, {test_method_type}")

        # Check for an existing test case with the same function_id, inputs and expected_output.
        self.cursor.execute(
            "SELECT id FROM TestCase WHERE function_id = ? AND input = ? AND expected_output = ?",
            (function_id, inputs_json, expected_json)
        )
        existing = self.cursor.fetchone()
        if existing:
            return existing[0]

        self.cursor.execute(
            "INSERT INTO TestCase (expected_output, input, test_function, last_run_time, test_method_type, test_suite_id, function_id) "
            "VALUES (?, ?, ?, ?, ?, ?, ?)",
            (
                expected_json,
                inputs_json,
                test_case.func_name,
                self._now(),
                test_method_type,
                test_suite_id,
                function_id
            )
        )
        self.conn.commit()
        return self.cursor.lastrowid

    def insert_test_result(self, test_case_id: int, status: bool, error: str | None = None) -> int:
        """Record one execution outcome of a test case; return the row id."""
        self.cursor.execute(
            "INSERT INTO TestResult (test_case_id, status, error, execution_time) VALUES (?, ?, ?, ?)",
            (test_case_id, status, error, self._now())
        )
        self.conn.commit()
        return self.cursor.lastrowid

    def insert_coverage_data(self, file_name: str, executed_lines: int, missed_lines: int,
                             branch_coverage: float, source_file_id: int, function_id: int | None = None) -> int:
        """Insert a CoverageData row and return its id.

        The row is typed ``"file"`` when *function_id* is None and
        ``"function"`` otherwise. *function_id* defaults to None, matching
        the abstract ``Dao`` signature. *file_name* is accepted for
        interface compatibility but not stored, because the table no longer
        has that column.
        """
        coverage_type = "file" if function_id is None else "function"
        self.cursor.execute(
            "INSERT INTO CoverageData (coverage_type, executed_lines, missed_lines, branch_coverage, source_file_id, function_id) "
            "VALUES (?, ?, ?, ?, ?, ?)",
            (coverage_type, executed_lines, missed_lines, branch_coverage, source_file_id, function_id)
        )
        self.conn.commit()
        return self.cursor.lastrowid

    def get_test_suites(self) -> List[Any]:
        """Return all TestSuite rows, newest first."""
        self.cursor.execute("SELECT * FROM TestSuite ORDER BY creation_date DESC")
        return self.cursor.fetchall()

    def get_test_cases_by_function(self, function_name: str) -> List[Any]:
        """Return TestCase rows for every function named *function_name*, in any file."""
        self.cursor.execute(
            "SELECT tc.* FROM TestCase tc JOIN Function f ON tc.function_id = f.id WHERE f.name = ?",
            (function_name,)
        )
        return self.cursor.fetchall()

    def get_coverage_by_file(self, file_path: str) -> List[Any]:
        """Return all CoverageData rows for the source file at *file_path*."""
        self.cursor.execute(
            "SELECT cd.* FROM CoverageData cd JOIN SourceFile sf ON cd.source_file_id = sf.id WHERE sf.path = ?",
            (file_path,)
        )
        return self.cursor.fetchall()

    def get_test_file_data(self, file_path: str) -> List[Any]:
        """Return test-case and coverage rows joined for the file at *file_path*.

        The current TestCase schema has no ``name`` column, so selecting
        ``tc.name`` made this query fail. ``test_function`` is exposed under
        the old ``test_case_name`` alias so callers reading that key keep
        working. NOTE(review): confirm that is the intended substitute.
        NOTE(review): TestResult is joined but none of its columns are
        selected, and CoverageData is joined per file. Each test case row
        therefore repeats once per result and once per coverage row.
        """
        query = """
            SELECT
                tc.id AS test_case_id,
                tc.test_function AS test_case_name,
                tc.test_function AS test_case_test_function,
                tc.test_method_type AS test_case_method_type,
                COALESCE(cd.missed_lines, 'None') AS coverage_data_missed_lines
            FROM SourceFile sf
            LEFT JOIN Function f ON sf.id = f.source_file_id
            LEFT JOIN TestCase tc ON f.id = tc.function_id
            LEFT JOIN TestResult tr ON tc.id = tr.test_case_id
            LEFT JOIN CoverageData cd ON sf.id = cd.source_file_id
            WHERE sf.path = ?;
        """
        self.cursor.execute(query, (file_path,))
        return self.cursor.fetchall()

    def get_function_by_name_file_id_start(self, name: str, source_file_id: int, start_line: int) -> int:
        """Return the id of the matching Function row, or -1 if none exists."""
        print(f"GETTING FUNCTION ID FOR {name}, {source_file_id}, {start_line}")
        self.cursor.execute(
            """
            SELECT id FROM Function
            WHERE name = ? AND source_file_id = ? AND start_line = ?
            """,
            (name, source_file_id, start_line)
        )
        result = self.cursor.fetchone()
        return result["id"] if result else -1

    def get_functions_by_file(self, filepath: str) -> List[Function]:
        """Return the functions stored for the file at *filepath* as ``Function`` models."""
        self.cursor.execute(
            """
            SELECT id, name, params, start_line, end_line, num_lines, source_file_id
            FROM Function
            WHERE source_file_id = (
                SELECT id FROM SourceFile WHERE path = ?
            )
            """,
            (filepath,)
        )
        rows = self.cursor.fetchall()
        return [
            Function(
                name=row["name"],
                params=row["params"],
                start_line=row["start_line"],
                end_line=row["end_line"],
                num_lines=row["num_lines"],
                source_file_id=row["source_file_id"]
            )
            for row in rows
        ]

    def get_test_suite_id_by_name(self, name: str) -> int:
        """Return the id of the TestSuite called *name*, or -1 if none exists."""
        self.cursor.execute(
            """
            SELECT id FROM TestSuite
            WHERE name = ?
            """,
            (name,)
        )
        result = self.cursor.fetchone()
        return result["id"] if result else -1

    def get_source_file_id_by_path(self, filepath: str) -> int:
        """Return the id of the SourceFile at *filepath*, or -1 if none exists."""
        self.cursor.execute(
            """
            SELECT id FROM SourceFile
            WHERE path = ?
            """,
            (filepath,)
        )
        result = self.cursor.fetchone()
        return result["id"] if result else -1

    def get_test_case_id_by_func_id_input_expected(self, function_id: int, inputs: str, expected: str) -> int:
        """Return the id of the matching TestCase, or -1 if none exists.

        *inputs* and *expected* must be the stored strings. ``insert_test_case``
        writes them as ``json.dumps(..., sort_keys=True)``.
        """
        self.cursor.execute(
            """
            SELECT id FROM TestCase
            WHERE function_id = ?
            AND input = ?
            AND expected_output = ?
            """,
            (function_id, inputs, expected)
        )
        result = self.cursor.fetchone()
        return result["id"] if result else -1
@@ -29,10 +29,12 @@ def create_database(db_name="testgen.db"):
29
29
  CREATE TABLE IF NOT EXISTS Function (
30
30
  id INTEGER PRIMARY KEY,
31
31
  name TEXT,
32
+ params TEXT,
32
33
  start_line INTEGER,
33
34
  end_line INTEGER,
34
35
  num_lines INTEGER,
35
36
  source_file_id INTEGER,
37
+ UNIQUE(name, params, source_file_id),
36
38
  FOREIGN KEY (source_file_id) REFERENCES SourceFile(id)
37
39
  );
38
40
  """)
@@ -40,16 +42,16 @@ def create_database(db_name="testgen.db"):
40
42
  cursor.execute("""
41
43
  CREATE TABLE IF NOT EXISTS TestCase (
42
44
  id INTEGER PRIMARY KEY,
43
- name TEXT,
44
- expected_output TEXT, -- storing JSON as TEXT
45
- input TEXT, -- storing JSON as TEXT
45
+ expected_output TEXT,
46
+ input TEXT,
46
47
  test_function TEXT,
47
48
  last_run_time TIMESTAMP,
48
49
  test_method_type INTEGER,
49
50
  test_suite_id INTEGER,
50
51
  function_id INTEGER,
51
52
  FOREIGN KEY (test_suite_id) REFERENCES TestSuite(id),
52
- FOREIGN KEY (function_id) REFERENCES Function(id)
53
+ FOREIGN KEY (function_id) REFERENCES Function(id),
54
+ UNIQUE(function_id, input, expected_output)
53
55
  );
54
56
  """)
55
57
 
@@ -67,12 +69,19 @@ def create_database(db_name="testgen.db"):
67
69
  cursor.execute("""
68
70
  CREATE TABLE IF NOT EXISTS CoverageData (
69
71
  id INTEGER PRIMARY KEY,
70
- file_name TEXT,
72
+ coverage_type TEXT CHECK(coverage_type IN ('file', 'function')),
71
73
  executed_lines INTEGER,
72
74
  missed_lines INTEGER,
73
75
  branch_coverage REAL,
74
76
  source_file_id INTEGER,
75
- FOREIGN KEY (source_file_id) REFERENCES SourceFile(id)
77
+ function_id INTEGER,
78
+ FOREIGN KEY (source_file_id) REFERENCES SourceFile(id),
79
+ FOREIGN KEY (function_id) REFERENCES Function(id),
80
+ CHECK (
81
+ -- Only one of source_file_id or function_id is required based on coverage_type
82
+ (coverage_type = 'file' AND source_file_id IS NOT NULL AND function_id IS NULL) OR
83
+ (coverage_type = 'function' AND function_id IS NOT NULL)
84
+ )
76
85
  );
77
86
  """)
78
87
 
@@ -1,5 +1,4 @@
1
1
  from testgen.generator.test_generator import TestGenerator
2
- from testgen.util.file_utils import get_import_info
3
2
  from testgen.models.generator_context import GeneratorContext
4
3
 
5
4
 
@@ -10,17 +9,10 @@ class PyTestGenerator(TestGenerator):
10
9
 
11
10
  def generate_test_header(self):
12
11
  self.test_code.append("import pytest\n")
13
- import_info = get_import_info(self._generator_context.filepath)
14
12
  if self._generator_context.class_name == "" or self._generator_context.class_name is None:
15
- if import_info['is_package']:
16
- self.test_code.append(f"import {import_info['import_path']} as {self._generator_context.module.__name__}\n")
17
- else:
18
- self.test_code.append(f"import {import_info['import_path']} as {self._generator_context.module.__name__}\n")
13
+ self.test_code.append(f"import {self._generator_context.import_path} as {self._generator_context.module.__name__}\n")
19
14
  else:
20
- if import_info['is_package']:
21
- self.test_code.append(f"from {import_info['import_path']} import {self._generator_context.class_name}\n")
22
- else:
23
- self.test_code.append(f"from {import_info['import_path']} import {self._generator_context.class_name}\n")
15
+ self.test_code.append(f"from {self._generator_context.import_path} import {self._generator_context.class_name}\n")
24
16
 
25
17
  def generate_test_function(self, unique_func_name, func_name, cases):
26
18
  self.test_code.append(f"def test_{unique_func_name}():")
@@ -1,4 +1,3 @@
1
- from testgen.util.file_utils import get_import_info
2
1
  from testgen.generator.test_generator import TestGenerator
3
2
  from testgen.models.generator_context import GeneratorContext
4
3
 
@@ -13,19 +12,11 @@ class UnitTestGenerator(TestGenerator):
13
12
 
14
13
  def generate_test_header(self):
15
14
 
16
- import_info = get_import_info(self._generator_context.filepath)
17
-
18
15
  self.test_code.append("import unittest\n")
19
16
  if self._generator_context.class_name == "" or self._generator_context.class_name is None:
20
- if import_info['is_package']:
21
- self.test_code.append(f"import {import_info['import_path']} as {self._generator_context.module.__name__}\n")
22
- else:
23
- self.test_code.append(f"import {import_info['import_path']} as {self._generator_context.module.__name__}\n")
17
+ self.test_code.append(f"import {self._generator_context.import_path} as {self._generator_context.module.__name__}\n")
24
18
  else:
25
- if import_info['is_package']:
26
- self.test_code.append(f"from {import_info['import_path']} import {self._generator_context.class_name}\n")
27
- else:
28
- self.test_code.append(f"from {import_info['import_path']} import {self._generator_context.class_name}\n")
19
+ self.test_code.append(f"from {self._generator_context.import_path} import {self._generator_context.class_name}\n")
29
20
  self.test_code.append(f"class Test{self._generator_context.class_name}(unittest.TestCase): \n")
30
21
 
31
22
  def generate_test_function(self, unique_func_name, func_name, cases):
testgen/main.py CHANGED
@@ -1,12 +1,10 @@
1
1
  from testgen.service.service import Service
2
2
  from testgen.controller.cli_controller import CLIController
3
3
  from testgen.generator.unit_test_generator import UnitTestGenerator
4
- from testgen.presentation.cli_view import CLIView
5
4
 
6
5
  def main():
7
6
  service = Service()
8
- view = CLIView()
9
- controller = CLIController(service, view)
7
+ controller = CLIController(service)
10
8
  controller.run()
11
9
 
12
10
  if __name__ == '__main__':
@@ -0,0 +1,56 @@
1
+ class CoverageData:
2
+ def __init__(self, coverage_type: str, executed_lines: int, missed_lines: int, branch_coverage: float, source_file_id: int, function_id: int):
3
+ self._coverage_type: str = coverage_type
4
+ self._executed_lines: int = executed_lines
5
+ self._missed_lines: int = missed_lines
6
+ self._branch_coverage: float = branch_coverage
7
+ self._source_file_id: int = source_file_id
8
+ self._function_id: int = function_id
9
+
10
+ @property
11
+ def coverage_type(self) -> str:
12
+ return self._coverage_type
13
+
14
+ @coverage_type.setter
15
+ def coverage_type(self, value: str) -> None:
16
+ self._coverage_type = value
17
+
18
+ @property
19
+ def executed_lines(self) -> int:
20
+ return self._executed_lines
21
+
22
+ @executed_lines.setter
23
+ def executed_lines(self, value: int) -> None:
24
+ self._executed_lines = value
25
+
26
+ @property
27
+ def missed_lines(self) -> int:
28
+ return self._missed_lines
29
+
30
+ @missed_lines.setter
31
+ def missed_lines(self, value: int) -> None:
32
+ self._missed_lines = value
33
+
34
+ @property
35
+ def branch_coverage(self) -> float:
36
+ return self._branch_coverage
37
+
38
+ @branch_coverage.setter
39
+ def branch_coverage(self, value: float) -> None:
40
+ self._branch_coverage = value
41
+
42
+ @property
43
+ def source_file_id(self) -> int:
44
+ return self._source_file_id
45
+
46
+ @source_file_id.setter
47
+ def source_file_id(self, value: int) -> None:
48
+ self._source_file_id = value
49
+
50
+ @property
51
+ def function_id(self) -> int:
52
+ return self._function_id
53
+
54
+ @function_id.setter
55
+ def function_id(self, value: int) -> None:
56
+ self._function_id = value