testgenie-py 0.3.7__py3-none-any.whl → 0.3.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. testgen/analyzer/ast_analyzer.py +2 -11
  2. testgen/analyzer/fuzz_analyzer.py +1 -6
  3. testgen/analyzer/random_feedback_analyzer.py +20 -293
  4. testgen/analyzer/reinforcement_analyzer.py +59 -57
  5. testgen/analyzer/test_case_analyzer_context.py +0 -6
  6. testgen/controller/cli_controller.py +35 -29
  7. testgen/controller/docker_controller.py +1 -0
  8. testgen/db/dao.py +68 -0
  9. testgen/db/dao_impl.py +226 -0
  10. testgen/{sqlite → db}/db.py +15 -6
  11. testgen/generator/pytest_generator.py +2 -10
  12. testgen/generator/unit_test_generator.py +2 -11
  13. testgen/main.py +1 -3
  14. testgen/models/coverage_data.py +56 -0
  15. testgen/models/db_test_case.py +65 -0
  16. testgen/models/function.py +56 -0
  17. testgen/models/function_metadata.py +11 -1
  18. testgen/models/generator_context.py +30 -3
  19. testgen/models/source_file.py +29 -0
  20. testgen/models/test_result.py +38 -0
  21. testgen/models/test_suite.py +20 -0
  22. testgen/reinforcement/agent.py +1 -27
  23. testgen/reinforcement/environment.py +11 -93
  24. testgen/reinforcement/statement_coverage_state.py +5 -4
  25. testgen/service/analysis_service.py +31 -22
  26. testgen/service/cfg_service.py +3 -1
  27. testgen/service/coverage_service.py +115 -0
  28. testgen/service/db_service.py +140 -0
  29. testgen/service/generator_service.py +77 -20
  30. testgen/service/logging_service.py +2 -2
  31. testgen/service/service.py +62 -231
  32. testgen/service/test_executor_service.py +145 -0
  33. testgen/util/coverage_utils.py +38 -116
  34. testgen/util/coverage_visualizer.py +10 -9
  35. testgen/util/file_utils.py +10 -111
  36. testgen/util/randomizer.py +0 -26
  37. testgen/util/utils.py +197 -38
  38. {testgenie_py-0.3.7.dist-info → testgenie_py-0.3.9.dist-info}/METADATA +1 -1
  39. testgenie_py-0.3.9.dist-info/RECORD +72 -0
  40. testgen/inspector/inspector.py +0 -59
  41. testgen/presentation/__init__.py +0 -0
  42. testgen/presentation/cli_view.py +0 -12
  43. testgen/sqlite/__init__.py +0 -0
  44. testgen/sqlite/db_service.py +0 -239
  45. testgen/testgen.db +0 -0
  46. testgenie_py-0.3.7.dist-info/RECORD +0 -67
  47. /testgen/{inspector → db}/__init__.py +0 -0
  48. {testgenie_py-0.3.7.dist-info → testgenie_py-0.3.9.dist-info}/WHEEL +0 -0
  49. {testgenie_py-0.3.7.dist-info → testgenie_py-0.3.9.dist-info}/entry_points.txt +0 -0
@@ -1,29 +1,26 @@
1
- import doctest
2
- import importlib
3
1
  import inspect
4
2
  import json
5
3
  import os
6
4
  import re
7
- import sqlite3
8
- import sys
9
5
  import time
10
6
  import subprocess
11
7
 
12
- import coverage
13
- import testgen.util.coverage_utils as coverage_utils
14
8
  from types import ModuleType
15
- from typing import List
9
+ from typing import Any, Dict, List
16
10
 
17
- import testgen
18
11
  import testgen.util.file_utils as file_utils
12
+ import testgen.util.utils
13
+ from testgen.models.coverage_data import CoverageData
14
+ from testgen.models.function import Function
19
15
 
20
16
  from testgen.models.test_case import TestCase
21
17
  from testgen.service.analysis_service import AnalysisService
22
18
  from testgen.service.generator_service import GeneratorService
23
- from testgen.sqlite.db_service import DBService
19
+ from testgen.service.coverage_service import CoverageService
20
+ from testgen.service.db_service import DBService
24
21
  from testgen.models.analysis_context import AnalysisContext
25
22
  from testgen.service.logging_service import get_logger
26
-
23
+ from testgen.service.test_executor_service import TestExecutorService
27
24
 
28
25
  # Constants for test strategies
29
26
  AST_STRAT = 1
@@ -50,6 +47,9 @@ class Service:
50
47
  # Initialize specialized services
51
48
  self.analysis_service = AnalysisService()
52
49
  self.generator_service = GeneratorService(None, None, None)
50
+ self.coverage_service = CoverageService()
51
+ self.test_executor_service = TestExecutorService()
52
+ self.coverage_service = CoverageService()
53
53
  # Only initialize DB service if not running in Docker
54
54
  if os.environ.get("RUNNING_IN_DOCKER") is None:
55
55
  self.db_service = DBService()
@@ -100,12 +100,14 @@ class Service:
100
100
  # Only save to DB if not running in Docker
101
101
  if os.environ.get("RUNNING_IN_DOCKER") is None:
102
102
  file_path_to_use = self.generated_file_path if self.test_strategy == AST_STRAT else self.file_path
103
- self.db_service.save_test_generation_data(
103
+ # Don't save test to db foreign key constraint violation error
104
+ """self.db_service.save_test_generation_data(
104
105
  file_path_to_use,
105
106
  test_cases,
106
107
  self.test_strategy,
108
+ module,
107
109
  class_name
108
- )
110
+ )"""
109
111
 
110
112
  test_file = self.generate_test_file(test_cases, output_path, module, class_name)
111
113
 
@@ -147,99 +149,51 @@ class Service:
147
149
  class_name = self.analysis_service.get_class_name(module)
148
150
  functions = self.inspect_class(class_name)
149
151
  return self.generator_service.generate_function_code(self.file_path, class_name, functions)
150
-
151
- def run_coverage(self, test_file):
152
- """Run coverage analysis on the generated tests."""
153
- Service.wait_for_file(test_file)
154
- file_path_to_use = self.generated_file_path if self.test_strategy == AST_STRAT else self.file_path
155
- self.logger.debug(f"File path to use for coverage: {file_path_to_use}")
156
- coverage_output = ""
157
-
158
- try:
159
- if self.test_format == UNITTEST_FORMAT:
160
- subprocess.run(["python", "-m", "coverage", "run", "--source=.", "-m", "unittest", test_file], check=True)
161
- result = subprocess.run(
162
- ["python", "-m", "coverage", "report", file_path_to_use],
163
- check=True,
164
- capture_output=True,
165
- text=True
166
- )
167
- coverage_output = result.stdout
168
- print(coverage_output)
169
- elif self.test_format == PYTEST_FORMAT:
170
- self.execute_and_store_pytest(test_file)
171
- elif self.test_format == DOCTEST_FORMAT:
172
- self.execute_and_store_doctest(test_file)
173
- else:
174
- raise ValueError("Unsupported test format for test results.")
175
152
 
176
- #Run coverage analysis
177
- subprocess.run(["python", "-m", "coverage", "run", "--source=.", test_file], check=True)
178
- result = subprocess.run(["python", "-m", "coverage", "report", file_path_to_use], check=True, capture_output=True, text=True)
179
- self._save_coverage_data(coverage_output, file_path_to_use)
180
- coverage_output = result.stdout
153
def run_tests(self, test_file: str) -> List[Dict[str, Any]]:
    """Execute ``test_file`` with the service's configured test format.

    Args:
        test_file: Path to the generated test file.

    Returns:
        The per-test result dicts produced by
        ``TestExecutorService.execute_tests`` (keys ``name``, ``status``,
        ``error``). They are handed back to the caller but not persisted yet:
        ``save_test_results`` cannot reliably map a result to its stored
        test case.
    """
    # TODO: once result-to-test-case matching works, persist the results via
    # self.test_executor_service.save_test_results(self.db_service, results,
    #     self.file_path, self.test_format) when a DB service is available.
    return self.test_executor_service.execute_tests(test_file, self.test_format)
181
163
 
182
- self._save_coverage_data(coverage_output, file_path_to_use)
164
def run_coverage(self, test_file: str):
    """Run coverage for ``test_file`` and print a summary.

    Coverage is measured against the generated source file when the AST
    strategy is active, otherwise against the original source file.

    Args:
        test_file: Path to the generated test file to run under coverage.
    """
    file_path_to_use = self.generated_file_path if self.test_strategy == AST_STRAT else self.file_path

    self.logger.debug(f"Running coverage on: {test_file}")
    coverage_data = self.coverage_service.run_coverage(test_file, file_path_to_use)

    # The summary is console output and must not depend on the database, so it
    # is printed before the DB guard below (which previously suppressed it in
    # Docker or when no DB service exists).
    self._print_coverage_summary(file_path_to_use, coverage_data)

    # The DB service is only created when not running in Docker (see
    # __init__); getattr tolerates instances where it was never assigned.
    if os.environ.get("RUNNING_IN_DOCKER") is not None or getattr(self, "db_service", None) is None:
        self.logger.debug("Skipping database operations - running in Docker or no DB service")
        return

    # TODO: persisting coverage data is currently disabled. Re-enable with:
    # self.coverage_service.save_coverage_data(
    #     db_service=self.db_service,
    #     coverage_data=coverage_data,
    #     file_path=file_path_to_use,
    # )
182
+
183
@staticmethod
def _print_coverage_summary(file_path: str, coverage_data: CoverageData):
    """Print a short human-readable coverage report.

    Args:
        file_path: Source file the coverage numbers refer to.
        coverage_data: Object exposing ``executed_lines`` and ``missed_lines``
            (presumably integer line counts — they are summed and divided).
    """
    print("\nCoverage Summary:")
    print(f"File: {file_path}")
    missed = coverage_data.missed_lines
    executed = coverage_data.executed_lines
    total_lines = missed + executed
    print(f"Total lines: {total_lines}")
    print(f"Executed lines: {executed}")
    print(f"Missed lines: {missed}")

    # Guard the division: a file with no measurable lines has no percentage.
    if not total_lines > 0:
        print("Coverage: N/A (no lines to cover)")
        return
    print(f"Coverage: {executed / total_lines * 100:.2f}%")
243
197
 
244
198
  def serialize_test_cases(self, test_cases):
245
199
  """Serialize input arguments ot JSON-compatible format"""
@@ -315,36 +269,6 @@ class Service:
315
269
  """Create an analysis context for the given file."""
316
270
  return self.analysis_service.create_analysis_context(filepath)
317
271
 
318
- def get_coverage(self, file_path: str):
319
- """
320
- Use the coverage library to calculate and print the coverage for the specified Python file.
321
- Dynamically determine the source directory based on the file being tested.
322
- """
323
- # Dynamically determine the source directory
324
- source_dir = os.path.dirname(file_path)
325
- cov = coverage.Coverage(source=[source_dir]) # Use the directory of the file as the source
326
- cov.start()
327
-
328
- try:
329
- # Dynamically import and execute the specified file
330
- file_name = os.path.basename(file_path)
331
- module_name = file_name.rstrip(".py")
332
- spec = importlib.util.spec_from_file_location(module_name, file_path)
333
- module = importlib.util.module_from_spec(spec)
334
- spec.loader.exec_module(module)
335
-
336
- except Exception as e:
337
- print(f"Error while executing the file: {e}")
338
- return
339
-
340
- finally:
341
- cov.stop()
342
- cov.save()
343
-
344
- # Report the coverage
345
- print(f"Coverage report for {file_path}:")
346
- cov.report(file=sys.stdout)
347
-
348
272
  @staticmethod
349
273
  def wait_for_file(file_path, retries=5, delay=1):
350
274
  """Wait for the generated file to appear."""
@@ -354,25 +278,6 @@ class Service:
354
278
  if not os.path.exists(file_path):
355
279
  raise FileNotFoundError(f"File '{file_path}' not found after waiting.")
356
280
 
357
- def get_full_import_path(self) -> str:
358
- """Get the full import path for the current file."""
359
- package_root = self.find_package_root()
360
- if not package_root:
361
- raise ImportError(f"Could not determine the package root for {self.file_path}.")
362
-
363
- module_path = os.path.abspath(self.file_path)
364
- rel_path = os.path.relpath(module_path, package_root)
365
- package_path = rel_path.replace(os.sep, ".")
366
-
367
- if package_path.endswith(".py"):
368
- package_path = package_path[:-3]
369
-
370
- package_name = os.path.basename(package_root)
371
- if not package_path.startswith(package_name + "."):
372
- package_path = package_name + "." + package_path
373
-
374
- return package_path
375
-
376
281
  def find_package_root(self):
377
282
  """Find the package root directory."""
378
283
  current_dir = os.path.abspath(os.path.dirname(self.file_path))
@@ -446,91 +351,6 @@ class Service:
446
351
  if hasattr(self ,'analysis_service'):
447
352
  self.analysis_service.set_reinforcement_mode(mode)
448
353
 
449
- def _get_test_case_id(self, test_case_name: str) -> int:
450
- """
451
- Retrieve the test case ID from the database based on the test case name.
452
- Insert the test case if it does not exist.
453
- """
454
- if self.db_service is None:
455
- raise RuntimeError("Database service is not initialized.")
456
-
457
- # Query the database for the test case ID
458
- self.db_service.cursor.execute(
459
- "SELECT id FROM TestCase WHERE name = ?",
460
- (test_case_name,)
461
- )
462
- result = self.db_service.cursor.fetchone()
463
-
464
- if result:
465
- return result[0] # Return the test case ID
466
- else:
467
- # Insert the test case into the database
468
- self.db_service.cursor.execute(
469
- "INSERT INTO TestCase (name) VALUES (?)",
470
- (test_case_name,)
471
- )
472
- self.db_service.conn.commit()
473
- return self.db_service.cursor.lastrowid
474
-
475
- def execute_and_store_pytest(self, test_file):
476
- import pytest
477
- from _pytest.reports import TestReport
478
-
479
- class PytestResultPlugin:
480
- def __init__(self, db_service, get_test_case_id):
481
- self.db_service = db_service
482
- self.get_test_case_id = get_test_case_id
483
-
484
- def pytest_runtest_logreport(self, report: TestReport):
485
- if report.when == "call":
486
- test_case_id = self.get_test_case_id(report.nodeid)
487
- status = report.outcome == "passed"
488
- error_message = report.longreprtext if report.outcome == "failed" else None
489
- self.db_service.insert_test_result(test_case_id, status, error_message)
490
-
491
- pytest.main([test_file], plugins=[PytestResultPlugin(self.db_service, self._get_test_case_id)])
492
-
493
- pytest.main([test_file])
494
-
495
- def execute_and_store_unittest(self, file_path_to_use, test_file):
496
- import unittest
497
- loader = unittest.TestLoader()
498
- self.logger.debug(f"Discovering tests in: {os.path.dirname(file_path_to_use)} with pattern: {os.path.basename(test_file)}")
499
- test_module = os.path.relpath(test_file,
500
- start=os.getcwd()) # Get relative path from the current working directory
501
- test_module = test_module.replace("/", ".").replace("\\", ".").rstrip(".py") # Convert to module name
502
- if test_module.startswith("."):
503
- test_module = test_module[1:] # Remove leading dot if present
504
- self.logger.debug(f"Test module: {test_module}")
505
- suite = loader.loadTestsFromName(test_module)
506
- runner = unittest.TextTestRunner()
507
- result = runner.run(suite)
508
-
509
- for test_case, traceback in result.failures + result.errors:
510
- test_case_id = self._get_test_case_id(str(test_case))
511
- self.db_service.insert_test_result(test_case_id, status=False, error=traceback)
512
-
513
- successful_tests = set(str(test) for test in suite) - set(
514
- str(test) for test, _ in result.failures + result.errors)
515
- for test_case in successful_tests:
516
- test_case_id = self._get_test_case_id(str(test_case))
517
- self.db_service.insert_test_result(test_case_id, status=True, error=None)
518
-
519
- def execute_and_store_doctest(self, test_file):
520
- module_name = os.path.splitext(os.path.basename(test_file))[0]
521
- spec = importlib.util.spec_from_file_location(module_name, test_file)
522
- module = importlib.util.module_from_spec(spec)
523
- sys.modules[module_name] = module
524
- spec.loader.exec_module(module)
525
-
526
- # Now run doctests on the loaded module
527
- result = doctest.testmod(module)
528
-
529
- test_case_id = self._get_test_case_id(test_file)
530
- status = result.failed == 0
531
- error_message = f"{result.failed} of {result.attempted} tests failed" if result.failed > 0 else None
532
- self.db_service.insert_test_result(test_case_id, status, error_message)
533
-
534
354
  def query_test_file_data(self, test_file_name: str):
535
355
  if self.db_service is None:
536
356
  raise RuntimeError("Database service is not initialized.")
@@ -546,6 +366,17 @@ class Service:
546
366
  print(f"Results for file: {test_file_name}")
547
367
  print(tabulate(rows, headers="keys", tablefmt="grid"))
548
368
 
369
def get_all_functions(self, file_path: str):
    """Print every function discovered in ``file_path`` with its attributes.

    Functions come from ``testgen.util.utils.get_list_of_functions``; each
    instance attribute is printed except the internal ``_source_file_id``.
    """
    for function in testgen.util.utils.get_list_of_functions(file_path):
        print(f"Function: {function.name}")
        shown = (
            (attr_name, attr_value)
            for attr_name, attr_value in vars(function).items()
            if attr_name != "_source_file_id"
        )
        for attr_name, attr_value in shown:
            print(f" {attr_name}: {attr_value}")
379
+
549
380
  def select_all_from_db(self) -> None:
550
381
  rows = self.db_service.get_test_suites()
551
382
  for row in rows:
@@ -0,0 +1,145 @@
1
+ import doctest
2
+ import unittest
3
+ import subprocess
4
+ import sys
5
+ import os
6
+ from typing import Dict, List, Any
7
+
8
+ import pytest
9
+ from _pytest.reports import TestReport
10
+
11
+ from testgen.service.logging_service import get_logger
12
+ from testgen.util import utils
13
+ from testgen.service.db_service import DBService
14
+
15
+
16
class TestExecutorService:
    """Run generated test files and report their outcomes.

    Every ``execute_*`` method returns a list of dicts with the keys
    ``"name"``, ``"status"`` (``True`` when passed) and ``"error"`` (failure
    text or ``None``).
    """

    UNITTEST_FORMAT = 1
    PYTEST_FORMAT = 2
    DOCTEST_FORMAT = 3

    def __init__(self):
        self.logger = get_logger()

    def execute_tests(self, test_file: str, test_format: int) -> List[Dict[str, Any]]:
        """Run ``test_file`` with the runner matching ``test_format``.

        Never raises: a missing file, an unsupported format, or an unexpected
        exception is reported as a single failed result entry instead.
        """
        test_file = os.path.abspath(test_file)
        if not os.path.exists(test_file):
            self.logger.error(f"Test file not found: {test_file}")
            return [{"name": f"{test_file}::file_not_found", "status": False, "error": "Test file not found"}]

        try:
            if test_format == self.UNITTEST_FORMAT:
                return self.execute_unittest(test_file)
            elif test_format == self.PYTEST_FORMAT:
                return self.execute_pytest(test_file)
            elif test_format == self.DOCTEST_FORMAT:
                return self.execute_doctest(test_file)
            else:
                self.logger.error(f"Unsupported test format: {test_format}")
                return [{"name": f"{test_file}::invalid_format", "status": False,
                         "error": f"Unsupported test format: {test_format}"}]

        except Exception as e:
            self.logger.error(f"Error executing tests: {str(e)}")
            return [{"name": f"{test_file}::execution_error", "status": False, "error": str(e)}]

    @staticmethod
    def _run_python_module(module: str, test_file: str) -> subprocess.CompletedProcess:
        """Run ``python -m <module> <test_file>`` in a fresh interpreter and echo its output."""
        result = subprocess.run(
            [sys.executable, "-m", module, test_file],
            capture_output=True,
            text=True
        )
        print(result.stdout)
        print(result.stderr)
        return result

    @staticmethod
    def _summary_result(test_file: str, runner: str,
                        result: subprocess.CompletedProcess) -> Dict[str, Any]:
        """Collapse a runner's exit status into one result entry for the whole file."""
        passed = result.returncode == 0
        error = None
        if not passed:
            output = (result.stdout + result.stderr).strip()
            error = output or f"{runner} exited with code {result.returncode}"
        return {"name": f"{test_file}::{runner}", "status": passed, "error": error}

    def execute_unittest(self, test_file: str) -> List[Dict[str, Any]]:
        """Run a unittest file in a subprocess.

        Individual test outcomes are not parsed yet; a single summary entry
        reflects whether the runner exited successfully.
        """
        print(f"Running unittest on: {test_file}")
        result = self._run_python_module("unittest", test_file)
        return [self._summary_result(test_file, "unittest", result)]

    def execute_pytest(self, test_file: str) -> List[Dict[str, Any]]:
        """Run a pytest file in-process and collect one entry per test."""
        print(f"Executing pytest file: {test_file}")
        results = []

        try:
            # Custom plugin to collect results
            class PytestResultCollector:
                def __init__(self):
                    self.results = []

                def pytest_runtest_logreport(self, report: TestReport):
                    # Record the call phase of each test, plus failed setups:
                    # when setup fails the call phase never runs, so such tests
                    # would otherwise be missing from the results entirely.
                    if report.when == "call" or (report.when == "setup" and report.failed):
                        self.results.append({
                            "name": report.nodeid,
                            "status": report.outcome == "passed",
                            "error": report.longreprtext if report.outcome != "passed" else None
                        })

            # Run pytest and collect results
            collector = PytestResultCollector()
            pytest.main([test_file], plugins=[collector])
            results = collector.results

        except Exception as e:
            self.logger.error(f"Error running pytest: {e}")
            results.append({
                "name": f"{test_file}::pytest_execution",
                "status": False,
                "error": str(e)
            })

        return results

    def execute_doctest(self, test_file: str) -> List[Dict[str, Any]]:
        """Run the doctests in ``test_file`` in a subprocess.

        ``python -m doctest`` imports a ``.py`` file as a module and checks its
        docstring examples (other files are checked as text); it exits non-zero
        when any example fails. A single summary entry is returned.
        """
        print(f"Running doctest on: {test_file}")
        result = self._run_python_module("doctest", test_file)
        return [self._summary_result(test_file, "doctest", result)]

    # NOTE: not called yet - mapping a result name back to its stored test case
    # (and therefore its test-case ID) is not reliable.
    def save_test_results(self, db_service: "DBService", test_results: List[Dict[str, Any]],
                          file_path: str, test_format: int) -> None:
        """Persist ``test_results`` for the functions defined in ``file_path``.

        Each result is saved independently, so one unmatched or malformed
        entry is logged and skipped without aborting the rest of the batch.
        """
        if db_service is None:
            self.logger.debug("Skipping database operations - no DB service provided")
            return

        try:
            source_file_id = db_service.get_source_file_id_by_path(file_path)
            if source_file_id == -1:
                self.logger.error(f"Source file not found in database: {file_path}")
                return
            functions = db_service.get_functions_by_file(file_path)
        except Exception as e:
            self.logger.error(f"Error saving test results to database: {e}")
            return

        for result in test_results:
            try:
                self._save_test_result(db_service, result, source_file_id, functions, test_format)
            except Exception as e:
                self.logger.error(f"Error saving test result {result.get('name')}: {e}")

    def _save_test_result(self, db_service, result: Dict[str, Any], source_file_id,
                          functions, test_format: int) -> None:
        """Match one result to its stored test case and insert its outcome."""
        name = result["name"]
        test_case = utils.parse_test_case_from_result_name(name, test_format)
        function_id = db_service.match_test_case_to_function_for_id(source_file_id, test_case, functions)
        self.logger.debug(f"Result {name}: parsed test case {test_case}, function id {function_id}")

        if function_id == -1:
            self.logger.warning(f"Could not match test case {name} to a function")
            return

        inputs_str = str(test_case.inputs)
        expected_str = str(test_case.expected)
        self.logger.debug(f"Result {name}: inputs {inputs_str}, expected {expected_str}")

        test_case_id = db_service.get_test_case_id_by_func_id_input_expected(
            function_id, inputs_str, expected_str)

        if test_case_id == -1:
            self.logger.warning(f"Test case not found in database: {name}")
            return

        db_service.insert_test_result(test_case_id, result["status"], result.get("error"))