codeshift 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- codeshift/__init__.py +8 -0
- codeshift/analyzer/__init__.py +5 -0
- codeshift/analyzer/risk_assessor.py +388 -0
- codeshift/api/__init__.py +1 -0
- codeshift/api/auth.py +182 -0
- codeshift/api/config.py +73 -0
- codeshift/api/database.py +215 -0
- codeshift/api/main.py +103 -0
- codeshift/api/models/__init__.py +55 -0
- codeshift/api/models/auth.py +108 -0
- codeshift/api/models/billing.py +92 -0
- codeshift/api/models/migrate.py +42 -0
- codeshift/api/models/usage.py +116 -0
- codeshift/api/routers/__init__.py +5 -0
- codeshift/api/routers/auth.py +440 -0
- codeshift/api/routers/billing.py +395 -0
- codeshift/api/routers/migrate.py +304 -0
- codeshift/api/routers/usage.py +291 -0
- codeshift/api/routers/webhooks.py +289 -0
- codeshift/cli/__init__.py +5 -0
- codeshift/cli/commands/__init__.py +7 -0
- codeshift/cli/commands/apply.py +352 -0
- codeshift/cli/commands/auth.py +842 -0
- codeshift/cli/commands/diff.py +221 -0
- codeshift/cli/commands/scan.py +368 -0
- codeshift/cli/commands/upgrade.py +436 -0
- codeshift/cli/commands/upgrade_all.py +518 -0
- codeshift/cli/main.py +221 -0
- codeshift/cli/quota.py +210 -0
- codeshift/knowledge/__init__.py +50 -0
- codeshift/knowledge/cache.py +167 -0
- codeshift/knowledge/generator.py +231 -0
- codeshift/knowledge/models.py +151 -0
- codeshift/knowledge/parser.py +270 -0
- codeshift/knowledge/sources.py +388 -0
- codeshift/knowledge_base/__init__.py +17 -0
- codeshift/knowledge_base/loader.py +102 -0
- codeshift/knowledge_base/models.py +110 -0
- codeshift/migrator/__init__.py +23 -0
- codeshift/migrator/ast_transforms.py +256 -0
- codeshift/migrator/engine.py +395 -0
- codeshift/migrator/llm_migrator.py +320 -0
- codeshift/migrator/transforms/__init__.py +19 -0
- codeshift/migrator/transforms/fastapi_transformer.py +174 -0
- codeshift/migrator/transforms/pandas_transformer.py +236 -0
- codeshift/migrator/transforms/pydantic_v1_to_v2.py +637 -0
- codeshift/migrator/transforms/requests_transformer.py +218 -0
- codeshift/migrator/transforms/sqlalchemy_transformer.py +175 -0
- codeshift/scanner/__init__.py +6 -0
- codeshift/scanner/code_scanner.py +352 -0
- codeshift/scanner/dependency_parser.py +473 -0
- codeshift/utils/__init__.py +5 -0
- codeshift/utils/api_client.py +266 -0
- codeshift/utils/cache.py +318 -0
- codeshift/utils/config.py +71 -0
- codeshift/utils/llm_client.py +221 -0
- codeshift/validator/__init__.py +6 -0
- codeshift/validator/syntax_checker.py +183 -0
- codeshift/validator/test_runner.py +224 -0
- codeshift-0.2.0.dist-info/METADATA +326 -0
- codeshift-0.2.0.dist-info/RECORD +65 -0
- codeshift-0.2.0.dist-info/WHEEL +5 -0
- codeshift-0.2.0.dist-info/entry_points.txt +2 -0
- codeshift-0.2.0.dist-info/licenses/LICENSE +21 -0
- codeshift-0.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,221 @@
|
|
|
1
|
+
"""Anthropic Claude client wrapper for LLM-based migrations."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
|
|
6
|
+
from anthropic import Anthropic
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
|
|
10
|
+
class LLMResponse:
|
|
11
|
+
"""Response from the LLM."""
|
|
12
|
+
|
|
13
|
+
content: str
|
|
14
|
+
model: str
|
|
15
|
+
usage: dict
|
|
16
|
+
success: bool
|
|
17
|
+
error: str | None = None
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class LLMClient:
    """Client for interacting with Anthropic's Claude API.

    The underlying ``Anthropic`` SDK client is created lazily on first use, so
    constructing an ``LLMClient`` never fails even when no API key is set.
    """

    DEFAULT_MODEL = "claude-sonnet-4-20250514"
    MAX_TOKENS = 4096

    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
    ):
        """Initialize the LLM client.

        Args:
            api_key: Anthropic API key. Defaults to ANTHROPIC_API_KEY env var.
            model: Model to use. Defaults to ``DEFAULT_MODEL``.
        """
        self.api_key = api_key or os.environ.get("ANTHROPIC_API_KEY")
        self.model = model or self.DEFAULT_MODEL
        # Created on demand by the ``client`` property.
        self._client: Anthropic | None = None

    @property
    def client(self) -> Anthropic:
        """Get or create the Anthropic client.

        Raises:
            ValueError: If no API key was passed in or found in the environment.
        """
        if self._client is None:
            if not self.api_key:
                raise ValueError(
                    "Anthropic API key not found. Set ANTHROPIC_API_KEY environment variable "
                    "or pass api_key to LLMClient."
                )
            self._client = Anthropic(api_key=self.api_key)
        return self._client

    @property
    def is_available(self) -> bool:
        """Check if the LLM client is available (API key is set).

        Only the presence of a key is checked, not whether the key is valid.
        """
        return bool(self.api_key)

    def generate(
        self,
        prompt: str,
        system_prompt: str | None = None,
        max_tokens: int | None = None,
        temperature: float = 0.0,
    ) -> LLMResponse:
        """Generate a response from the LLM.

        Args:
            prompt: The user prompt
            system_prompt: Optional system prompt
            max_tokens: Maximum tokens in response. Falls back to ``MAX_TOKENS``
                when omitted (or 0).
            temperature: Sampling temperature (0.0 for deterministic)

        Returns:
            LLMResponse with the generated content. Failures (missing API key or
            any exception raised during the request) are reported through
            ``success=False`` and ``error`` rather than raised.
        """
        if not self.is_available:
            return LLMResponse(
                content="",
                model=self.model,
                usage={},
                success=False,
                error="API key not configured",
            )

        request: dict = {
            "model": self.model,
            "max_tokens": max_tokens or self.MAX_TOKENS,
            "temperature": temperature,
            "messages": [{"role": "user", "content": prompt}],
        }
        # ``system`` is optional; leave it out instead of sending an empty string.
        if system_prompt:
            request["system"] = system_prompt

        try:
            response = self.client.messages.create(**request)

            # Only blocks that expose ``text`` contribute to the result.
            content = "".join(
                block.text for block in response.content if hasattr(block, "text")
            )

            return LLMResponse(
                content=content,
                model=response.model,
                usage={
                    "input_tokens": response.usage.input_tokens,
                    "output_tokens": response.usage.output_tokens,
                },
                success=True,
            )

        except Exception as e:
            # Boundary handler: callers inspect ``success``/``error`` instead of
            # dealing with SDK-specific exception types.
            return LLMResponse(
                content="",
                model=self.model,
                usage={},
                success=False,
                error=str(e),
            )

    def migrate_code(
        self,
        code: str,
        library: str,
        from_version: str,
        to_version: str,
        context: str | None = None,
    ) -> LLMResponse:
        """Use the LLM to migrate code.

        Args:
            code: The source code to migrate
            library: The library being upgraded
            from_version: Current version
            to_version: Target version
            context: Optional context about the migration

        Returns:
            LLMResponse with the migrated code
        """
        system_prompt = f"""You are an expert Python developer specializing in code migrations.
Your task is to migrate Python code from {library} v{from_version} to v{to_version}.

Guidelines:
1. Only modify code that needs to change for the migration
2. Preserve all comments, formatting, and code style where possible
3. Add brief inline comments explaining non-obvious changes
4. If you're unsure about a change, add a TODO comment
5. Return ONLY the migrated code, no explanations before or after
"""

        # These hints describe the Pydantic v1 -> v2 API. Presenting them as
        # changes of another library would mislead the model (for example
        # ".copy() -> .model_copy()" applied to non-Pydantic objects), so they
        # are only added for Pydantic migrations.
        if library.strip().lower() == "pydantic":
            system_prompt += f"""
Important {library} v{from_version} to v{to_version} changes:
- Config class -> model_config = ConfigDict(...)
- @validator -> @field_validator with @classmethod
- @root_validator -> @model_validator with @classmethod
- .dict() -> .model_dump()
- .json() -> .model_dump_json()
- .schema() -> .model_json_schema()
- .parse_obj() -> .model_validate()
- .parse_raw() -> .model_validate_json()
- .copy() -> .model_copy()
- orm_mode -> from_attributes
- Field(regex=...) -> Field(pattern=...)
"""

        context_line = f"Context: {context}" if context else ""

        prompt = f"""Migrate the following Python code from {library} v{from_version} to v{to_version}.

{context_line}

Code to migrate:
```python
{code}
```

Return only the migrated Python code:"""

        return self.generate(prompt, system_prompt=system_prompt)

    def explain_change(
        self,
        original: str,
        transformed: str,
        library: str,
    ) -> LLMResponse:
        """Use the LLM to explain a migration change.

        Args:
            original: Original code
            transformed: Transformed code
            library: The library being upgraded

        Returns:
            LLMResponse with the explanation
        """
        system_prompt = """You are an expert Python developer.
Explain code changes clearly and concisely for other developers.
Focus on the 'why' not just the 'what'."""

        prompt = f"""Explain the following {library} migration change:

Original:
```python
{original}
```

Migrated:
```python
{transformed}
```

Provide a brief explanation (2-3 sentences) of what changed and why:"""

        # The explanation is meant to be short, so cap the response size.
        return self.generate(prompt, system_prompt=system_prompt, max_tokens=500)
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
# Module-wide client instance, created lazily by get_llm_client().
_default_client: LLMClient | None = None


def get_llm_client() -> LLMClient:
    """Return the shared LLMClient, constructing it on the first call."""
    global _default_client
    if _default_client is not None:
        return _default_client
    _default_client = LLMClient()
    return _default_client
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
"""Validator module for checking transformed code."""
|
|
2
|
+
|
|
3
|
+
from codeshift.validator.syntax_checker import SyntaxChecker, SyntaxCheckResult
|
|
4
|
+
from codeshift.validator.test_runner import TestResult, TestRunner
|
|
5
|
+
|
|
6
|
+
__all__ = ["SyntaxChecker", "SyntaxCheckResult", "TestRunner", "TestResult"]
|
|
@@ -0,0 +1,183 @@
|
|
|
1
|
+
"""Syntax checker for validating transformed code."""
|
|
2
|
+
|
|
3
|
+
import ast
|
|
4
|
+
import sys
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
|
|
10
|
+
class SyntaxIssue:
|
|
11
|
+
"""Represents a syntax error in code."""
|
|
12
|
+
|
|
13
|
+
message: str
|
|
14
|
+
line_number: int
|
|
15
|
+
column: int
|
|
16
|
+
line_text: str | None = None
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
class SyntaxCheckResult:
    """Outcome of checking one piece of source code or one file."""

    # True when no syntax problem was found.
    is_valid: bool
    # File the result refers to; None for in-memory source.
    file_path: Path | None = None
    # Problems that make the source invalid.
    errors: list[SyntaxIssue] = field(default_factory=list)
    # Non-fatal remarks about the source.
    warnings: list[str] = field(default_factory=list)

    @property
    def error_count(self) -> int:
        """Number of entries in ``errors``."""
        collected = self.errors
        return len(collected)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class SyntaxChecker:
    """Validates Python code syntax."""

    def __init__(self, python_version: tuple[int, int] | None = None):
        """Initialize the syntax checker.

        Args:
            python_version: Target Python version as (major, minor).
                Defaults to current Python version. The value is stored on the
                instance; the checks in this class use the running
                interpreter's grammar.
        """
        if python_version is None:
            python_version = (sys.version_info.major, sys.version_info.minor)
        self.python_version = python_version

    def check_code(self, source_code: str, filename: str = "<string>") -> SyntaxCheckResult:
        """Check if source code has valid Python syntax.

        Args:
            source_code: The Python source code to check
            filename: Optional filename for error messages

        Returns:
            SyntaxCheckResult with validation status. Invalid source yields a
            result with exactly one SyntaxIssue describing the first problem.
        """
        try:
            # compile() parses the source and runs the compiler's own checks,
            # so a separate ast.parse() pass over the same text adds nothing.
            compile(source_code, filename, "exec")
        except SyntaxError as e:
            issue = SyntaxIssue(
                message=str(e.msg) if hasattr(e, "msg") else str(e),
                line_number=e.lineno or 0,
                column=e.offset or 0,
                line_text=e.text,
            )
        except ValueError as e:
            # Raised instead of SyntaxError for source containing null bytes
            # on interpreters older than 3.12; no position is available.
            issue = SyntaxIssue(message=str(e), line_number=0, column=0)
        else:
            return SyntaxCheckResult(is_valid=True)

        return SyntaxCheckResult(
            is_valid=False,
            errors=[issue],
        )

    def check_file(self, file_path: Path) -> SyntaxCheckResult:
        """Check if a Python file has valid syntax.

        The file is decoded the way Python source files are decoded (UTF-8 by
        default, honouring a BOM or a PEP 263 coding cookie) rather than with
        the locale's default encoding.

        Args:
            file_path: Path to the Python file

        Returns:
            SyntaxCheckResult with validation status. An unreadable file is
            reported as invalid with a "Could not read file" issue.
        """
        import tokenize

        try:
            with tokenize.open(file_path) as handle:
                source_code = handle.read()
        except Exception as e:
            return SyntaxCheckResult(
                is_valid=False,
                file_path=file_path,
                errors=[
                    SyntaxIssue(
                        message=f"Could not read file: {e}",
                        line_number=0,
                        column=0,
                    )
                ],
            )

        result = self.check_code(source_code, str(file_path))
        result.file_path = file_path
        return result

    def check_directory(
        self, directory: Path, exclude_patterns: list[str] | None = None
    ) -> list[SyntaxCheckResult]:
        """Check all Python files in a directory.

        Args:
            directory: Path to the directory
            exclude_patterns: Glob patterns to exclude, matched against each
                file's path relative to ``directory``

        Returns:
            List of SyntaxCheckResult for each file with errors
        """
        import fnmatch

        exclude_patterns = exclude_patterns or []
        results = []

        for file_path in directory.rglob("*.py"):
            relative_path = str(file_path.relative_to(directory))

            # Skip files matching any exclude pattern.
            if any(fnmatch.fnmatch(relative_path, pattern) for pattern in exclude_patterns):
                continue

            result = self.check_file(file_path)
            if not result.is_valid:
                results.append(result)

        return results

    def validate_transform(self, original: str, transformed: str) -> tuple[bool, list[str]]:
        """Validate that a transformation didn't break syntax.

        Args:
            original: Original source code
            transformed: Transformed source code

        Returns:
            Tuple of (is_valid, list of issues). ``is_valid`` reflects only the
            transformed code; a broken original is reported as an issue but
            does not by itself make the result invalid.
        """
        issues = []

        # Check original syntax (should be valid)
        original_result = self.check_code(original, "<original>")
        if not original_result.is_valid:
            issues.append("Original code has syntax errors")

        # Check transformed syntax
        transformed_result = self.check_code(transformed, "<transformed>")
        if not transformed_result.is_valid:
            issues.extend(
                f"Line {error.line_number}: {error.message}"
                for error in transformed_result.errors
            )
            return False, issues

        return True, issues
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def quick_syntax_check(source_code: str) -> bool:
    """Quick check if code has valid Python syntax.

    Args:
        source_code: The Python source code to check

    Returns:
        True if syntax is valid, False otherwise
    """
    try:
        compile(source_code, "<string>", "exec")
    except (SyntaxError, ValueError):
        # ValueError is what compile() raises for source containing null bytes
        # on interpreters older than 3.12; treat it as invalid syntax too.
        return False
    return True
|
|
@@ -0,0 +1,224 @@
|
|
|
1
|
+
"""Test runner for validating migrations."""
|
|
2
|
+
|
|
3
|
+
import subprocess
|
|
4
|
+
import sys
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
|
|
10
|
+
class TestResult:
|
|
11
|
+
"""Result of running tests."""
|
|
12
|
+
|
|
13
|
+
success: bool
|
|
14
|
+
exit_code: int
|
|
15
|
+
stdout: str = ""
|
|
16
|
+
stderr: str = ""
|
|
17
|
+
tests_run: int = 0
|
|
18
|
+
tests_passed: int = 0
|
|
19
|
+
tests_failed: int = 0
|
|
20
|
+
tests_skipped: int = 0
|
|
21
|
+
duration: float = 0.0
|
|
22
|
+
error_message: str | None = None
|
|
23
|
+
|
|
24
|
+
@property
|
|
25
|
+
def summary(self) -> str:
|
|
26
|
+
"""Get a summary string of the test results."""
|
|
27
|
+
if self.success:
|
|
28
|
+
return f"✓ {self.tests_passed}/{self.tests_run} tests passed"
|
|
29
|
+
return f"✗ {self.tests_failed}/{self.tests_run} tests failed"
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class TestRunner:
    """Runs project tests to validate migrations."""

    def __init__(
        self,
        project_path: Path,
        test_command: list[str] | None = None,
        timeout: int = 300,
    ):
        """Initialize the test runner.

        Args:
            project_path: Path to the project root
            test_command: Custom test command. Defaults to the pytest command
                returned by ``_detect_test_command``.
            timeout: Maximum time in seconds to run tests
        """
        self.project_path = project_path
        self.test_command = test_command or self._detect_test_command()
        self.timeout = timeout

    def _detect_test_command(self) -> list[str]:
        """Return the default test command for the project.

        The previous detection logic had a unittest branch that could never be
        reached (its condition was already covered by the pytest check), so
        every path returned this same pytest command. Result parsing and
        ``run_quick_check`` in this class also assume pytest.
        """
        return [sys.executable, "-m", "pytest", "-v", "--tb=short"]

    def run(
        self,
        specific_tests: list[str] | None = None,
        extra_args: list[str] | None = None,
    ) -> TestResult:
        """Run the project tests.

        Args:
            specific_tests: List of specific test files or patterns to run
            extra_args: Additional arguments to pass to the test runner

        Returns:
            TestResult with the outcome. Timeouts, a missing executable and
            any other failure to run the command are reported with
            ``exit_code=-1`` and an ``error_message`` instead of being raised.
        """
        command = self.test_command.copy()

        if extra_args:
            command.extend(extra_args)

        if specific_tests:
            command.extend(specific_tests)

        try:
            result = subprocess.run(
                command,
                cwd=self.project_path,
                capture_output=True,
                text=True,
                # Test output may contain bytes that are not valid in the
                # locale encoding; keep the run usable instead of aborting.
                errors="replace",
                timeout=self.timeout,
            )

            # Parse pytest output
            tests_run, tests_passed, tests_failed, tests_skipped = self._parse_pytest_output(
                result.stdout + result.stderr
            )

            return TestResult(
                success=result.returncode == 0,
                exit_code=result.returncode,
                stdout=result.stdout,
                stderr=result.stderr,
                tests_run=tests_run,
                tests_passed=tests_passed,
                tests_failed=tests_failed,
                tests_skipped=tests_skipped,
            )

        except subprocess.TimeoutExpired:
            return TestResult(
                success=False,
                exit_code=-1,
                error_message=f"Tests timed out after {self.timeout} seconds",
            )
        except FileNotFoundError as e:
            return TestResult(
                success=False,
                exit_code=-1,
                error_message=f"Test command not found: {e}",
            )
        except Exception as e:
            return TestResult(
                success=False,
                exit_code=-1,
                error_message=f"Error running tests: {e}",
            )

    def _parse_pytest_output(self, output: str) -> tuple[int, int, int, int]:
        """Parse pytest output to extract test counts.

        Counts are taken from the last line that looks like pytest's final
        summary (it contains at least one count and an "in <seconds>s"
        duration), so numbers appearing earlier in captured test output or
        logs are not mistaken for results. If no such line exists, the whole
        output is searched and the first occurrence of each count is used.

        Args:
            output: Combined stdout and stderr from pytest

        Returns:
            Tuple of (total, passed, failed, skipped), where ``failed``
            includes errors.
        """
        import re

        # Patterns like "5 passed", "2 failed", "1 skipped", "1 error(s)".
        count_patterns = {
            "passed": r"(\d+) passed",
            "failed": r"(\d+) failed",
            "skipped": r"(\d+) skipped",
            "errors": r"(\d+) error",
        }
        duration_pattern = r"\bin \d+(?:\.\d+)?s\b"

        # Prefer the final summary line, e.g. "5 passed, 2 failed in 0.05s".
        text = output
        for line in reversed(output.splitlines()):
            if re.search(duration_pattern, line) and any(
                re.search(pattern, line) for pattern in count_patterns.values()
            ):
                text = line
                break

        counts = {}
        for label, pattern in count_patterns.items():
            match = re.search(pattern, text)
            counts[label] = int(match.group(1)) if match else 0

        passed = counts["passed"]
        failed = counts["failed"]
        skipped = counts["skipped"]
        errors = counts["errors"]

        total = passed + failed + skipped + errors
        return total, passed, failed + errors, skipped

    def run_quick_check(self) -> TestResult:
        """Run a quick smoke test (collection only, no execution).

        Appends pytest's ``--collect-only -q`` to the test command and uses a
        fixed 60 second timeout.

        Returns:
            TestResult indicating if tests can be collected; ``tests_run``
            holds the number of collected test ids found in the output.
        """
        command = self.test_command.copy()
        command.extend(["--collect-only", "-q"])

        try:
            result = subprocess.run(
                command,
                cwd=self.project_path,
                capture_output=True,
                text=True,
                # Same reasoning as in run(): never fail on undecodable output.
                errors="replace",
                timeout=60,
            )

            # Count collected tests: lines that contain "::" and mention "test".
            tests_collected = sum(
                1
                for line in result.stdout.splitlines()
                if "test" in line.lower() and "::" in line
            )

            return TestResult(
                success=result.returncode == 0,
                exit_code=result.returncode,
                stdout=result.stdout,
                stderr=result.stderr,
                tests_run=tests_collected,
            )

        except Exception as e:
            return TestResult(
                success=False,
                exit_code=-1,
                error_message=f"Error collecting tests: {e}",
            )
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def run_tests(project_path: Path, timeout: int = 300) -> TestResult:
    """Run a project's tests with a default-configured TestRunner.

    Args:
        project_path: Path to the project
        timeout: Maximum time in seconds

    Returns:
        TestResult with the outcome
    """
    return TestRunner(project_path, timeout=timeout).run()
|