testgenie-py 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- testgen/analyzer/random_feedback_analyzer.py +344 -114
- testgen/controller/cli_controller.py +5 -10
- testgen/generated_samplecodebin.py +545 -0
- testgen/generator/code_generator.py +36 -18
- testgen/reinforcement/environment.py +1 -1
- testgen/service/generator_service.py +16 -1
- testgen/service/service.py +124 -32
- testgen/sqlite/db_service.py +22 -2
- testgen/util/coverage_utils.py +35 -0
- testgen/util/randomizer.py +29 -12
- testgenie_py-0.2.3.dist-info/METADATA +139 -0
- {testgenie_py-0.2.1.dist-info → testgenie_py-0.2.3.dist-info}/RECORD +14 -28
- testgen/.coverage +0 -0
- testgen/code_to_test/__init__.py +0 -0
- testgen/code_to_test/boolean.py +0 -146
- testgen/code_to_test/calculator.py +0 -29
- testgen/code_to_test/code_to_fuzz.py +0 -234
- testgen/code_to_test/code_to_fuzz_lite.py +0 -397
- testgen/code_to_test/decisions.py +0 -57
- testgen/code_to_test/math_utils.py +0 -47
- testgen/code_to_test/no_types.py +0 -35
- testgen/code_to_test/sample_code_bin.py +0 -141
- testgen/q_table/global_q_table.json +0 -1
- testgen/testgen.db +0 -0
- testgen/tests/__init__.py +0 -0
- testgen/tests/test_boolean.py +0 -69
- testgen/tests/test_decisions.py +0 -195
- testgenie_py-0.2.1.dist-info/METADATA +0 -26
- {testgenie_py-0.2.1.dist-info → testgenie_py-0.2.3.dist-info}/WHEEL +0 -0
- {testgenie_py-0.2.1.dist-info → testgenie_py-0.2.3.dist-info}/entry_points.txt +0 -0
@@ -4,6 +4,7 @@ import random
|
|
4
4
|
import time
|
5
5
|
import traceback
|
6
6
|
from typing import List, Dict, Set
|
7
|
+
import z3
|
7
8
|
|
8
9
|
import testgen.util.randomizer
|
9
10
|
import testgen.util.utils as utils
|
@@ -16,6 +17,8 @@ from testgen.analyzer.test_case_analyzer import TestCaseAnalyzerStrategy
|
|
16
17
|
from abc import ABC
|
17
18
|
|
18
19
|
from testgen.models.function_metadata import FunctionMetadata
|
20
|
+
from testgen.util.z3_utils.constraint_extractor import extract_branch_conditions
|
21
|
+
from testgen.util.z3_utils.ast_to_z3 import ast_to_z3_constraint
|
19
22
|
|
20
23
|
|
21
24
|
# Citation in which this method and algorithm were taken from:
|
@@ -29,84 +32,61 @@ class RandomFeedbackAnalyzer(TestCaseAnalyzerStrategy, ABC):
|
|
29
32
|
super().__init__(analysis_context)
|
30
33
|
self.test_cases = []
|
31
34
|
self.covered_lines: Dict[str, Set[int]] = {}
|
35
|
+
self.covered_functions: Set[str] = set()
|
32
36
|
|
33
|
-
|
34
|
-
|
35
|
-
# Contracts express invariant properties that hold both at entry and exit from a call
|
36
|
-
# Contract takes as input the current state of the system (runtime values created in the sequence so far, and any exception thrown by the last call), and returns satisfied or violated
|
37
|
-
# Output is the runtime values and boolean flag violated
|
38
|
-
# Filters determine which values of a sequence are extensible and should be used as inputs
|
39
|
-
def generate_sequences(self, function_metadata: List[FunctionMetadata], classes=None, contracts: List[Contract] = None, filters=None, time_limit=20):
|
40
|
-
contracts = [NonNullContract(), NoExceptionContract()]
|
41
|
-
error_seqs = [] # execution violates a contract
|
42
|
-
non_error_seqs = [] # execution does not violate a contract
|
43
|
-
|
44
|
-
functions = self._analysis_context.function_data
|
45
|
-
start_time = time.time()
|
46
|
-
while(time.time() - start_time) >= time_limit:
|
47
|
-
# Get random function
|
48
|
-
func = random.choice(functions)
|
49
|
-
param_types: dict = func.params
|
50
|
-
vals: dict = self.random_seqs_and_vals(param_types)
|
51
|
-
new_seq = (func.function_name, vals)
|
52
|
-
if new_seq in error_seqs or new_seq in non_error_seqs:
|
53
|
-
continue
|
54
|
-
outs_violated: tuple = self.execute_sequence(new_seq, contracts)
|
55
|
-
violated: bool = outs_violated[1]
|
56
|
-
# Create tuple of sequence ((func name, args), output)
|
57
|
-
new_seq_out = (new_seq, outs_violated[0])
|
58
|
-
if violated:
|
59
|
-
error_seqs.append(new_seq_out)
|
60
|
-
else:
|
61
|
-
# Question: Should I use the failed contract to be the assertion in unit test??
|
62
|
-
non_error_seqs.append(new_seq_out)
|
63
|
-
return error_seqs, non_error_seqs
|
64
|
-
|
65
|
-
def generate_sequences_new(self, contracts: List[Contract] = None, filters=None, time_limit=20):
|
66
|
-
contracts = [NonNullContract(), NoExceptionContract()]
|
67
|
-
error_seqs = [] # execution violates a contract
|
68
|
-
non_error_seqs = [] # execution does not violate a contract
|
69
|
-
|
70
|
-
functions = self._analysis_context.function_data.copy()
|
37
|
+
def collect_test_cases(self, function_metadata: FunctionMetadata, time_limit: int = 5) -> List[TestCase]:
|
38
|
+
self.test_cases = []
|
71
39
|
start_time = time.time()
|
72
|
-
|
40
|
+
|
73
41
|
while (time.time() - start_time) < time_limit:
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
42
|
+
|
43
|
+
try:
|
44
|
+
param_values = self.generate_random_inputs(function_metadata.params)
|
45
|
+
module = self.analysis_context.module
|
46
|
+
func_name = function_metadata.function_name
|
47
|
+
|
48
|
+
if self._analysis_context.class_name:
|
49
|
+
cls = getattr(module, self._analysis_context.class_name)
|
50
|
+
obj = cls()
|
51
|
+
function = getattr(obj, func_name)
|
52
|
+
else:
|
53
|
+
function = getattr(module, func_name)
|
54
|
+
|
55
|
+
import inspect
|
56
|
+
sig = inspect.signature(function)
|
57
|
+
param_names = [p.name for p in sig.parameters.values() if p.name != 'self']
|
58
|
+
|
59
|
+
ordered_args = []
|
60
|
+
for name in param_names:
|
61
|
+
if name in param_values:
|
62
|
+
ordered_args.append(param_values[name])
|
63
|
+
else:
|
64
|
+
ordered_args.append(None)
|
65
|
+
|
66
|
+
result = function(*ordered_args)
|
67
|
+
test_case = TestCase(func_name, tuple(ordered_args), result)
|
68
|
+
|
69
|
+
if not self.is_duplicate_test_case(test_case):
|
70
|
+
self.test_cases.append(test_case)
|
71
|
+
|
72
|
+
covered = self.covered(function_metadata)
|
73
|
+
if covered:
|
74
|
+
break
|
75
|
+
else:
|
76
|
+
# Optionally log duplicate detection
|
77
|
+
print(f"Skipping duplicate test case: {func_name}{test_case.inputs}")
|
78
|
+
|
79
|
+
except Exception as e:
|
80
|
+
print(f"Error testing {function_metadata.function_name}: {e}")
|
109
81
|
|
82
|
+
return self.test_cases
|
83
|
+
|
84
|
+
def is_duplicate_test_case(self, new_test_case: TestCase) -> bool:
|
85
|
+
for existing_test_case in self.test_cases:
|
86
|
+
if (existing_test_case.func_name == new_test_case.func_name and
|
87
|
+
existing_test_case.inputs == new_test_case.inputs):
|
88
|
+
return True
|
89
|
+
return False
|
110
90
|
|
111
91
|
def covered(self, func: FunctionMetadata) -> bool:
|
112
92
|
if func.function_name not in self.covered_lines:
|
@@ -117,8 +97,10 @@ class RandomFeedbackAnalyzer(TestCaseAnalyzerStrategy, ABC):
|
|
117
97
|
func.function_name, test_case.inputs)
|
118
98
|
covered = coverage_utils.get_list_of_covered_statements(analysis)
|
119
99
|
self.covered_lines[func.function_name].update(covered)
|
100
|
+
print(f"Covered lines for {func.function_name}: {self.covered_lines[func.function_name]}")
|
120
101
|
|
121
|
-
executable_statements = self.get_all_executable_statements(func)
|
102
|
+
executable_statements = set(self.get_all_executable_statements(func))
|
103
|
+
print(f"Executable statements for {func.function_name}: {executable_statements}")
|
122
104
|
|
123
105
|
return self.covered_lines[func.function_name] == executable_statements
|
124
106
|
|
@@ -203,44 +185,118 @@ class RandomFeedbackAnalyzer(TestCaseAnalyzerStrategy, ABC):
|
|
203
185
|
inputs = {}
|
204
186
|
for param, param_type in param_types.items():
|
205
187
|
if param_type == "int":
|
206
|
-
random_integer = random.randint(
|
188
|
+
random_integer = random.randint(-500, 500) # Wider range for better edge cases
|
207
189
|
inputs[param] = random_integer
|
208
|
-
|
190
|
+
elif param_type == "bool":
|
209
191
|
random_choice = random.choice([True, False])
|
210
192
|
inputs[param] = random_choice
|
211
|
-
|
212
|
-
random_float = random.
|
193
|
+
elif param_type == "float":
|
194
|
+
random_float = random.uniform(-500.0, 500.0) # Wider range for better edge cases
|
213
195
|
inputs[param] = random_float
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
196
|
+
elif param_type == "str":
|
197
|
+
# Generate diverse strings instead of always "abc"
|
198
|
+
string_type = random.choice([
|
199
|
+
"empty", "short", "medium", "long", "special", "numeric", "whitespace"
|
200
|
+
])
|
201
|
+
|
202
|
+
if string_type == "empty":
|
203
|
+
inputs[param] = ""
|
204
|
+
elif string_type == "short":
|
205
|
+
inputs[param] = ''.join(random.choices('abcdefghijklmnopqrstuvwxyz', k=random.randint(1, 3)))
|
206
|
+
elif string_type == "medium":
|
207
|
+
inputs[param] = ''.join(random.choices('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ', k=random.randint(4, 10)))
|
208
|
+
elif string_type == "long":
|
209
|
+
inputs[param] = ''.join(random.choices('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789', k=random.randint(11, 30)))
|
210
|
+
elif string_type == "special":
|
211
|
+
inputs[param] = ''.join(random.choices('!@#$%^&*()_+-=[]{}|;:,./<>?', k=random.randint(1, 8)))
|
212
|
+
elif string_type == "numeric":
|
213
|
+
inputs[param] = ''.join(random.choices('0123456789', k=random.randint(1, 10)))
|
214
|
+
else: # whitespace
|
215
|
+
inputs[param] = ' ' * random.randint(1, 5)
|
216
|
+
else:
|
217
|
+
# For unknown types, try a default value
|
218
|
+
inputs[param] = None
|
219
|
+
|
221
220
|
return inputs
|
221
|
+
|
222
|
+
# Algorithm described in above article
|
223
|
+
# Classes is the classes for which we want to generate sequences
|
224
|
+
# Contracts express invariant properties that hold both at entry and exit from a call
|
225
|
+
# Contract takes as input the current state of the system (runtime values created in the sequence so far, and any exception thrown by the last call), and returns satisfied or violated
|
226
|
+
# Output is the runtime values and boolean flag violated
|
227
|
+
# Filters determine which values of a sequence are extensible and should be used as inputs
|
228
|
+
def generate_sequences(self, function_metadata: List[FunctionMetadata], classes=None, contracts: List[Contract] = None, filters=None, time_limit=20):
|
229
|
+
contracts = [NonNullContract(), NoExceptionContract()]
|
230
|
+
error_seqs = [] # execution violates a contract
|
231
|
+
non_error_seqs = [] # execution does not violate a contract
|
222
232
|
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
233
|
+
functions = self._analysis_context.function_data
|
234
|
+
start_time = time.time()
|
235
|
+
while(time.time() - start_time) >= time_limit:
|
236
|
+
# Get random function
|
237
|
+
func = random.choice(functions)
|
238
|
+
param_types: dict = func.params
|
239
|
+
vals: dict = self.random_seqs_and_vals(param_types)
|
240
|
+
new_seq = (func.function_name, vals)
|
241
|
+
if new_seq in error_seqs or new_seq in non_error_seqs:
|
242
|
+
continue
|
243
|
+
outs_violated: tuple = self.execute_sequence(new_seq, contracts)
|
244
|
+
violated: bool = outs_violated[1]
|
245
|
+
# Create tuple of sequence ((func name, args), output)
|
246
|
+
new_seq_out = (new_seq, outs_violated[0])
|
247
|
+
if violated:
|
248
|
+
error_seqs.append(new_seq_out)
|
249
|
+
else:
|
250
|
+
# Question: Should I use the failed contract to be the assertion in unit test??
|
251
|
+
non_error_seqs.append(new_seq_out)
|
252
|
+
return error_seqs, non_error_seqs
|
227
253
|
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
test_cases.append(TestCase(error_seq[0][0], tuple(error_seq[0][1].values()), error_seq[1]))
|
254
|
+
def generate_sequences_new(self, contracts: List[Contract] = None, filters=None, time_limit=20):
|
255
|
+
contracts = [NonNullContract(), NoExceptionContract()]
|
256
|
+
error_seqs = [] # execution violates a contract
|
257
|
+
non_error_seqs = [] # execution does not violate a contract
|
233
258
|
|
234
|
-
|
235
|
-
|
236
|
-
for non_error_seq in non_error_seqs:
|
237
|
-
print("NON ERROR SEQ OUTPUT:", non_error_seq[1])
|
238
|
-
test_cases.append(TestCase(non_error_seq[0][0], tuple(non_error_seq[0][1].values()), non_error_seq[1]))
|
259
|
+
functions = self._analysis_context.function_data.copy()
|
260
|
+
start_time = time.time()
|
239
261
|
|
240
|
-
|
262
|
+
while (time.time() - start_time) < time_limit:
|
263
|
+
# Get random function
|
264
|
+
func = random.choice(functions)
|
265
|
+
param_types: dict = func.params
|
266
|
+
vals: dict = self.random_seqs_and_vals(param_types)
|
267
|
+
new_seq = (func.function_name, vals)
|
241
268
|
|
269
|
+
if new_seq in [seq[0] for seq in error_seqs] or new_seq in [seq[0] for seq in non_error_seqs]:
|
270
|
+
continue
|
271
|
+
|
272
|
+
outs_violated: tuple = self.execute_sequence(new_seq, contracts)
|
273
|
+
violated: bool = outs_violated[1]
|
274
|
+
|
275
|
+
# Create tuple of sequence ((func name, args), output)
|
276
|
+
new_seq_out = (new_seq, outs_violated[0])
|
277
|
+
|
278
|
+
if violated:
|
279
|
+
error_seqs.append(new_seq_out)
|
280
|
+
|
281
|
+
else:
|
282
|
+
non_error_seqs.append(new_seq_out)
|
283
|
+
|
284
|
+
test_case = TestCase(new_seq_out[0][0], tuple(new_seq_out[0][1].values()), new_seq_out[1])
|
285
|
+
self.test_cases.append(test_case)
|
286
|
+
fully_covered = self.covered(func)
|
287
|
+
if fully_covered:
|
288
|
+
print(f"Function {func.function_name} is fully covered")
|
289
|
+
functions.remove(func)
|
290
|
+
|
291
|
+
if not functions:
|
292
|
+
self.test_cases.sort(key=lambda tc: tc.func_name)
|
293
|
+
print("All functions covered")
|
294
|
+
break
|
295
|
+
|
296
|
+
self.test_cases.sort(key=lambda tc: tc.func_name)
|
297
|
+
return error_seqs, non_error_seqs
|
298
|
+
|
242
299
|
def get_all_executable_statements(self, func: FunctionMetadata):
|
243
|
-
"""Get all executable statements including else branches"""
|
244
300
|
import ast
|
245
301
|
|
246
302
|
test_cases = [tc for tc in self.test_cases if tc.func_name == func.function_name]
|
@@ -254,32 +310,21 @@ class RandomFeedbackAnalyzer(TestCaseAnalyzerStrategy, ABC):
|
|
254
310
|
else:
|
255
311
|
analysis = coverage_utils.get_coverage_analysis(self._analysis_context.filepath, self._analysis_context.class_name, func.function_name, test_cases[0].inputs)
|
256
312
|
|
257
|
-
# Get standard executable lines from coverage.py
|
258
313
|
executable_lines = list(analysis[1])
|
259
314
|
|
260
|
-
# Parse the source file to find else branches
|
261
315
|
with open(self._analysis_context.filepath, 'r') as f:
|
262
316
|
source = f.read()
|
263
317
|
|
264
|
-
# Parse the code
|
265
318
|
tree = ast.parse(source)
|
266
319
|
|
267
|
-
# Find our specific function
|
268
320
|
for node in ast.walk(tree):
|
269
321
|
if isinstance(node, ast.FunctionDef) and node.name == func.func_def.name:
|
270
|
-
# Find all if statements in this function
|
271
322
|
for if_node in ast.walk(node):
|
272
323
|
if isinstance(if_node, ast.If) and if_node.orelse:
|
273
|
-
# There's an else branch
|
274
324
|
if isinstance(if_node.orelse[0], ast.If):
|
275
|
-
# This is an elif - already counted
|
276
325
|
continue
|
277
|
-
|
278
|
-
# Get the line number of the first statement in the else block
|
279
|
-
# and subtract 1 to get the 'else:' line
|
280
326
|
else_line = if_node.orelse[0].lineno - 1
|
281
327
|
|
282
|
-
# Check if this is actually an else line (not a nested if)
|
283
328
|
with open(self._analysis_context.filepath, 'r') as f:
|
284
329
|
lines = f.readlines()
|
285
330
|
if else_line <= len(lines):
|
@@ -288,4 +333,189 @@ class RandomFeedbackAnalyzer(TestCaseAnalyzerStrategy, ABC):
|
|
288
333
|
if else_line not in executable_lines:
|
289
334
|
executable_lines.append(else_line)
|
290
335
|
|
291
|
-
return sorted(executable_lines)
|
336
|
+
return sorted(executable_lines)
|
337
|
+
|
338
|
+
"""
|
339
|
+
def collect_test_cases_with_z3(self, function_metadata: FunctionMetadata) -> List[TestCase]:
|
340
|
+
test_cases = []
|
341
|
+
|
342
|
+
z3_test_cases = self.generate_z3_test_cases(function_metadata)
|
343
|
+
if z3_test_cases:
|
344
|
+
test_cases.extend(z3_test_cases)
|
345
|
+
|
346
|
+
if not test_cases:
|
347
|
+
test_cases = self.generate_sequences_new()[1]
|
348
|
+
|
349
|
+
self.test_cases = test_cases
|
350
|
+
return test_cases
|
351
|
+
|
352
|
+
def generate_z3_test_cases(self, function_metadata: FunctionMetadata) -> List[TestCase]:
|
353
|
+
test_cases = []
|
354
|
+
|
355
|
+
branch_conditions, param_types = extract_branch_conditions(function_metadata.func_def)
|
356
|
+
|
357
|
+
if not branch_conditions:
|
358
|
+
random_inputs = self.generate_random_inputs(function_metadata.params)
|
359
|
+
try:
|
360
|
+
module = self.analysis_context.module
|
361
|
+
func_name = function_metadata.function_name
|
362
|
+
|
363
|
+
if self._analysis_context.class_name:
|
364
|
+
cls = getattr(module, self._analysis_context.class_name)
|
365
|
+
obj = cls()
|
366
|
+
func = getattr(obj, func_name)
|
367
|
+
ordered_args = self._order_arguments(func, random_inputs)
|
368
|
+
output = func(*ordered_args)
|
369
|
+
else:
|
370
|
+
func = getattr(module, func_name)
|
371
|
+
ordered_args = self._order_arguments(func, random_inputs)
|
372
|
+
output = func(*ordered_args)
|
373
|
+
|
374
|
+
test_cases.append(TestCase(func_name, tuple(ordered_args), output))
|
375
|
+
except Exception as e:
|
376
|
+
print(f"Error executing function with random inputs: {e}")
|
377
|
+
|
378
|
+
return test_cases
|
379
|
+
|
380
|
+
for branch_condition in branch_conditions:
|
381
|
+
try:
|
382
|
+
z3_expr, z3_vars = ast_to_z3_constraint(branch_condition, function_metadata.params)
|
383
|
+
|
384
|
+
solver = z3.Solver()
|
385
|
+
solver.add(z3_expr)
|
386
|
+
|
387
|
+
neg_solver = z3.Solver()
|
388
|
+
neg_solver.add(z3.Not(z3_expr))
|
389
|
+
|
390
|
+
for current_solver in [solver, neg_solver]:
|
391
|
+
if current_solver.check() == z3.sat:
|
392
|
+
model = current_solver.model()
|
393
|
+
|
394
|
+
param_values = self._extract_z3_solution(model, z3_vars, function_metadata.params)
|
395
|
+
|
396
|
+
ordered_params = self._order_parameters(function_metadata.func_def, param_values)
|
397
|
+
|
398
|
+
try:
|
399
|
+
module = self.analysis_context.module
|
400
|
+
func_name = function_metadata.function_name
|
401
|
+
|
402
|
+
if self._analysis_context.class_name:
|
403
|
+
cls = getattr(module, self._analysis_context.class_name)
|
404
|
+
obj = cls()
|
405
|
+
func = getattr(obj, func_name)
|
406
|
+
else:
|
407
|
+
func = getattr(module, func_name)
|
408
|
+
|
409
|
+
result = func(*ordered_params)
|
410
|
+
test_cases.append(TestCase(func_name, tuple(ordered_params), result))
|
411
|
+
except Exception as e:
|
412
|
+
print(f"Error executing function with Z3 solution: {e}")
|
413
|
+
self._add_random_test_case(function_metadata, test_cases)
|
414
|
+
else:
|
415
|
+
self._add_random_test_case(function_metadata, test_cases)
|
416
|
+
|
417
|
+
except Exception as e:
|
418
|
+
print(f"Error processing branch condition with Z3: {e}")
|
419
|
+
self._add_random_test_case(function_metadata, test_cases)
|
420
|
+
|
421
|
+
return test_cases
|
422
|
+
|
423
|
+
def _extract_z3_solution(self, model, z3_vars, param_types):
|
424
|
+
param_values = {}
|
425
|
+
|
426
|
+
for var_name, z3_var in z3_vars.items():
|
427
|
+
if var_name in param_types:
|
428
|
+
try:
|
429
|
+
model_value = model.evaluate(z3_var)
|
430
|
+
|
431
|
+
if param_types[var_name] == "int":
|
432
|
+
param_values[var_name] = model_value.as_long()
|
433
|
+
elif param_types[var_name] == "float":
|
434
|
+
param_values[var_name] = float(model_value.as_decimal(10))
|
435
|
+
elif param_types[var_name] == "bool":
|
436
|
+
param_values[var_name] = z3.is_true(model_value)
|
437
|
+
elif param_types[var_name] == "str":
|
438
|
+
str_val = str(model_value)
|
439
|
+
if str_val.startswith('"') and str_val.endswith('"'):
|
440
|
+
str_val = str_val[1:-1]
|
441
|
+
param_values[var_name] = str_val
|
442
|
+
else:
|
443
|
+
# Default to int for unrecognized types
|
444
|
+
param_values[var_name] = model_value.as_long()
|
445
|
+
except Exception as e:
|
446
|
+
print(f"Couldn't get {var_name} from model: {e}")
|
447
|
+
# Use default values for parameters not in the model
|
448
|
+
if param_types[var_name] == "int":
|
449
|
+
param_values[var_name] = 0
|
450
|
+
elif param_types[var_name] == "float":
|
451
|
+
param_values[var_name] = 0.0
|
452
|
+
elif param_types[var_name] == "bool":
|
453
|
+
param_values[var_name] = False
|
454
|
+
elif param_types[var_name] == "str":
|
455
|
+
param_values[var_name] = ""
|
456
|
+
else:
|
457
|
+
param_values[var_name] = None
|
458
|
+
|
459
|
+
return param_values
|
460
|
+
|
461
|
+
def _order_parameters(self, func_node, param_values):
|
462
|
+
ordered_params = []
|
463
|
+
|
464
|
+
for arg in func_node.args.args:
|
465
|
+
arg_name = arg.arg
|
466
|
+
if arg_name == 'self': # Skip self parameter
|
467
|
+
continue
|
468
|
+
if arg_name in param_values:
|
469
|
+
ordered_params.append(param_values[arg_name])
|
470
|
+
else:
|
471
|
+
# Default value handling if parameter not in solution
|
472
|
+
if arg.annotation and hasattr(arg.annotation, 'id'):
|
473
|
+
if arg.annotation.id == 'int':
|
474
|
+
ordered_params.append(0)
|
475
|
+
elif arg.annotation.id == 'float':
|
476
|
+
ordered_params.append(0.0)
|
477
|
+
elif arg.annotation.id == 'bool':
|
478
|
+
ordered_params.append(False)
|
479
|
+
elif arg.annotation.id == 'str':
|
480
|
+
ordered_params.append('')
|
481
|
+
else:
|
482
|
+
ordered_params.append(None)
|
483
|
+
else:
|
484
|
+
ordered_params.append(None)
|
485
|
+
|
486
|
+
return ordered_params
|
487
|
+
|
488
|
+
def _order_arguments(self, func, args_dict):
|
489
|
+
import inspect
|
490
|
+
sig = inspect.signature(func)
|
491
|
+
param_names = [p.name for p in sig.parameters.values() if p.name != 'self']
|
492
|
+
|
493
|
+
ordered_args = []
|
494
|
+
for name in param_names:
|
495
|
+
if name in args_dict:
|
496
|
+
ordered_args.append(args_dict[name])
|
497
|
+
else:
|
498
|
+
ordered_args.append(None) # Default to None if missing
|
499
|
+
|
500
|
+
return ordered_args
|
501
|
+
|
502
|
+
def _add_random_test_case(self, function_metadata, test_cases):
|
503
|
+
random_inputs = self.generate_random_inputs(function_metadata.params)
|
504
|
+
try:
|
505
|
+
module = self.analysis_context.module
|
506
|
+
func_name = function_metadata.function_name
|
507
|
+
|
508
|
+
if self._analysis_context.class_name:
|
509
|
+
cls = getattr(module, self._analysis_context.class_name)
|
510
|
+
obj = cls()
|
511
|
+
func = getattr(obj, func_name)
|
512
|
+
else:
|
513
|
+
func = getattr(module, func_name)
|
514
|
+
|
515
|
+
ordered_args = self._order_arguments(func, random_inputs)
|
516
|
+
|
517
|
+
output = func(*ordered_args)
|
518
|
+
test_cases.append(TestCase(func_name, tuple(ordered_args), output))
|
519
|
+
except Exception as e:
|
520
|
+
print(f"Error executing function with random inputs: {e}")
|
521
|
+
"""
|
@@ -42,10 +42,9 @@ class CLIController:
|
|
42
42
|
|
43
43
|
logger = get_logger()
|
44
44
|
|
45
|
-
if args.
|
46
|
-
|
47
|
-
|
48
|
-
self.service.select_all_from_db()
|
45
|
+
if args.query:
|
46
|
+
print(f"Querying database for file: {args.file_path}")
|
47
|
+
self.service.query_test_file_data(args.file_path)
|
49
48
|
return
|
50
49
|
|
51
50
|
running_in_docker = os.environ.get("RUNNING_IN_DOCKER") is not None
|
@@ -111,6 +110,7 @@ class CLIController:
|
|
111
110
|
parser = argparse.ArgumentParser(description="A CLI tool for generating unit tests.")
|
112
111
|
parser.add_argument("file_path", type=str, help="Path to the Python file.")
|
113
112
|
parser.add_argument("--output", "-o", type=str, help="Path to output directory.")
|
113
|
+
parser.add_argument("-q", "--query", action="store_true", help="Query the database for test cases, coverage data, and test results for a specific file")
|
114
114
|
parser.add_argument(
|
115
115
|
"--generate-only", "-g",
|
116
116
|
action="store_true",
|
@@ -146,12 +146,7 @@ class CLIController:
|
|
146
146
|
help="Path to SQLite database file (default: testgen.db)"
|
147
147
|
)
|
148
148
|
parser.add_argument(
|
149
|
-
"
|
150
|
-
action="store_true",
|
151
|
-
help="Select all from sqlite db"
|
152
|
-
)
|
153
|
-
parser.add_argument(
|
154
|
-
"--visualize",
|
149
|
+
"-viz", "--visualize",
|
155
150
|
action="store_true",
|
156
151
|
help = "Visualize the tests with graphviz"
|
157
152
|
)
|