akitallm 0.1.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,181 @@
+ from abc import ABC, abstractmethod
+ from typing import List, Optional
+ import requests
+ from pydantic import BaseModel
+
+ class ModelInfo(BaseModel):
+     id: str
+     name: Optional[str] = None
+
+ class BaseProvider(ABC):
+     @property
+     @abstractmethod
+     def name(self) -> str:
+         pass
+
+     @abstractmethod
+     def validate_key(self, api_key: str) -> bool:
+         pass
+
+     @abstractmethod
+     def list_models(self, api_key: str) -> List[ModelInfo]:
+         pass
+
+ class OpenAIProvider(BaseProvider):
+     @property
+     def name(self) -> str:
+         return "openai"
+
+     def validate_key(self, api_key: str) -> bool:
+         if not api_key.startswith("sk-"):
+             return False
+         # Simple validation request
+         try:
+             response = requests.get(
+                 "https://api.openai.com/v1/models",
+                 headers={"Authorization": f"Bearer {api_key}"},
+                 timeout=5
+             )
+             return response.status_code == 200
+         except Exception:
+             return False
+
+     def list_models(self, api_key: str) -> List[ModelInfo]:
+         response = requests.get(
+             "https://api.openai.com/v1/models",
+             headers={"Authorization": f"Bearer {api_key}"},
+             timeout=10
+         )
+         response.raise_for_status()
+         data = response.json()
+         exclude_keywords = ["vision", "instruct", "audio", "realtime", "tts", "dall-e", "embedding", "moderation", "davinci", "babbage", "curie", "ada"]
+
+         models = []
+         for m in data["data"]:
+             model_id = m["id"]
+             if not any(kw in model_id.lower() for kw in exclude_keywords):
+                 if model_id.startswith("gpt-") or model_id.startswith("o1") or model_id.startswith("o3"):
+                     models.append(ModelInfo(id=model_id))
+         return sorted(models, key=lambda x: x.id)
+
+ class AnthropicProvider(BaseProvider):
+     @property
+     def name(self) -> str:
+         return "anthropic"
+
+     def validate_key(self, api_key: str) -> bool:
+         if not api_key.startswith("sk-ant-"):
+             return False
+         # Anthropic validation usually requires a full request, but we'll check the prefix for now
+         # or do a no-op call if possible.
+         return True
+
+     def list_models(self, api_key: str) -> List[ModelInfo]:
+         # Anthropic doesn't have a public models list API like OpenAI
+         return [
+             ModelInfo(id="claude-3-5-sonnet-latest", name="Claude 3.5 Sonnet (Latest)"),
+             ModelInfo(id="claude-3-5-haiku-latest", name="Claude 3.5 Haiku (Latest)"),
+             ModelInfo(id="claude-3-opus-20240229", name="Claude 3 Opus"),
+             ModelInfo(id="claude-3-sonnet-20240229", name="Claude 3 Sonnet"),
+             ModelInfo(id="claude-3-haiku-20240307", name="Claude 3 Haiku"),
+         ]
+
+ class OllamaProvider(BaseProvider):
+     @property
+     def name(self) -> str:
+         return "ollama"
+
+     def validate_key(self, api_key: str) -> bool:
+         # Ollama doesn't use keys by default; we just check if it's reachable
+         try:
+             response = requests.get("http://localhost:11434/api/tags", timeout=2)
+             return response.status_code == 200
+         except Exception:
+             return False
+
+     def list_models(self, api_key: str) -> List[ModelInfo]:
+         response = requests.get("http://localhost:11434/api/tags", timeout=5)
+         response.raise_for_status()
+         data = response.json()
+         return [ModelInfo(id=m["name"]) for m in data["models"]]
+
+ class GeminiProvider(BaseProvider):
+     @property
+     def name(self) -> str:
+         return "gemini"
+
+     def validate_key(self, api_key: str) -> bool:
+         if not api_key.startswith("AIza"):
+             return False
+         return True
+
+     def list_models(self, api_key: str) -> List[ModelInfo]:
+         # Gemini API URL for listing models
+         url = f"https://generativelanguage.googleapis.com/v1beta/models?key={api_key}"
+         response = requests.get(url, timeout=10)
+         response.raise_for_status()
+         data = response.json()
+
+         exclude_keywords = ["nano", "banana", "vision", "embedding", "aqa", "learnlm"]
+
+         models = []
+         for m in data["models"]:
+             model_id = m["name"].split("/")[-1]
+             display_name = m["displayName"]
+
+             # Check that it supports generation and doesn't match excluded keywords
+             if "generateContent" in m["supportedGenerationMethods"]:
+                 if not any(kw in model_id.lower() or kw in display_name.lower() for kw in exclude_keywords):
+                     models.append(ModelInfo(id=model_id, name=display_name))
+
+         return models
+
+ class GroqProvider(BaseProvider):
+     @property
+     def name(self) -> str:
+         return "groq"
+
+     def validate_key(self, api_key: str) -> bool:
+         if not api_key.startswith("gsk_"):
+             return False
+         return True
+
+     def list_models(self, api_key: str) -> List[ModelInfo]:
+         # Groq uses an OpenAI-compatible models endpoint
+         response = requests.get(
+             "https://api.groq.com/openai/v1/models",
+             headers={"Authorization": f"Bearer {api_key}"},
+             timeout=10
+         )
+         response.raise_for_status()
+         data = response.json()
+
+         # Filter for text models
+         exclude_keywords = ["vision", "audio"]
+         models = []
+         for m in data["data"]:
+             model_id = m["id"]
+             if not any(kw in model_id.lower() for kw in exclude_keywords):
+                 models.append(ModelInfo(id=model_id))
+         return sorted(models, key=lambda x: x.id)
+
+ def detect_provider(api_key: str) -> Optional[BaseProvider]:
+     """
+     Attempts to detect the provider based on the API key or environment.
+     """
+     if api_key.lower() == "ollama":
+         return OllamaProvider()
+
+     if api_key.startswith("sk-ant-"):
+         return AnthropicProvider()
+
+     if api_key.startswith("gsk_"):
+         return GroqProvider()
+
+     if api_key.startswith("sk-"):
+         return OpenAIProvider()
+
+     if api_key.startswith("AIza"):
+         return GeminiProvider()
+
+     return None
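Note that detect_provider is order-sensitive: the "sk-ant-" check must run before the catch-all "sk-" check, exactly as written above. A minimal usage sketch against the names defined in this file (the module's import path is not shown in this diff; the key is an illustrative placeholder):

    key = "sk-ant-api03-example"        # placeholder, not a real key
    provider = detect_provider(key)     # -> AnthropicProvider
    if provider is not None and provider.validate_key(key):
        for model in provider.list_models(key):
            print(provider.name, model.id, model.name)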
akita/core/trace.py ADDED
@@ -0,0 +1,18 @@
+ from typing import List, Dict, Any
+ from datetime import datetime
+ from pydantic import BaseModel, Field
+
+ class TraceStep(BaseModel):
+     timestamp: datetime = Field(default_factory=datetime.now)
+     action: str
+     details: str
+     metadata: Dict[str, Any] = Field(default_factory=dict)
+
+ class ReasoningTrace(BaseModel):
+     steps: List[TraceStep] = Field(default_factory=list)
+
+     def add_step(self, action: str, details: str, metadata: Dict[str, Any] = None):
+         self.steps.append(TraceStep(action=action, details=details, metadata=metadata or {}))
+
+     def __str__(self):
+         return "\n".join([f"[{s.timestamp.strftime('%H:%M:%S')}] {s.action}: {s.details}" for s in self.steps])
akita/models/base.py CHANGED
@@ -35,14 +35,19 @@ def get_model(model_name: Optional[str] = None) -> AIModel:
      """
      Get an AIModel instance based on config or provided name.
      """
+     provider = get_config_value("model", "provider", "openai")
+     api_key = get_config_value("model", "api_key")
+
      if model_name is None:
          model_name = get_config_value("model", "name", "gpt-4o-mini")

-     provider = get_config_value("model", "provider", "openai")
-
-     # LiteLLM usually wants "provider/model_name" for some providers
-     # but for OpenAI it handles "gpt-3.5-turbo" directly.
-     # If it's a custom provider, we might need to prepend it.
-     full_model_name = f"{provider}/{model_name}" if provider != "openai" else model_name
+     # LiteLLM wants "provider/model_name" for non-OpenAI providers
+     if provider == "openai":
+         full_model_name = model_name
+     elif provider == "gemini":
+         full_model_name = f"gemini/{model_name}"
+     else:
+         full_model_name = f"{provider}/{model_name}"

-     return AIModel(model_name=full_model_name)
+     # For Ollama, we might need a base_url, but for now we assume the default
+     return AIModel(model_name=full_model_name, api_key=api_key)
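The prefix logic follows LiteLLM's routing convention, where non-OpenAI providers are addressed as "provider/model_name". A hedged sketch of the equivalent direct LiteLLM calls (assuming litellm is the client behind AIModel, as the comments suggest; model names are illustrative):

    import litellm

    # OpenAI models are passed bare; other providers get a prefix.
    litellm.completion(model="gpt-4o-mini", messages=[{"role": "user", "content": "hi"}])
    litellm.completion(model="gemini/gemini-1.5-pro", messages=[{"role": "user", "content": "hi"}])
    litellm.completion(model="ollama/llama3", messages=[{"role": "user", "content": "hi"}])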
@@ -0,0 +1 @@
+ # Official AkitaLLM Plugins
akita/plugins/files.py ADDED
@@ -0,0 +1,34 @@
+ from akita.core.plugins import AkitaPlugin
+ from akita.tools.base import FileSystemTools
+ from typing import List, Dict, Any
+
+ class FilesPlugin(AkitaPlugin):
+     @property
+     def name(self) -> str:
+         return "files"
+
+     @property
+     def description(self) -> str:
+         return "Standard filesystem operations (read, write, list)."
+
+     def get_tools(self) -> List[Dict[str, Any]]:
+         return [
+             {
+                 "name": "read_file",
+                 "description": "Read content from a file.",
+                 "parameters": {"path": "string"},
+                 "func": FileSystemTools.read_file
+             },
+             {
+                 "name": "write_file",
+                 "description": "Write content to a file.",
+                 "parameters": {"path": "string", "content": "string"},
+                 "func": FileSystemTools.write_file
+             },
+             {
+                 "name": "list_dir",
+                 "description": "List files in a directory.",
+                 "parameters": {"path": "string"},
+                 "func": FileSystemTools.list_dir
+             }
+         ]
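Each tool entry pairs a name and a loose parameter schema with a plain callable, so a dispatcher can route calls by name; a minimal sketch under that assumption (file names are illustrative):

    plugin = FilesPlugin()
    tools = {t["name"]: t["func"] for t in plugin.get_tools()}

    # Route by tool name, passing parameters as keyword arguments.
    tools["write_file"](path="notes.txt", content="hello")
    print(tools["read_file"](path="notes.txt"))   # -> "hello"
    print(tools["list_dir"](path="."))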
akita/reasoning/engine.py CHANGED
@@ -1,16 +1,23 @@
  from typing import List, Dict, Any, Optional
  from akita.models.base import AIModel, get_model
- from akita.tools.base import ShellTools, FileSystemTools
+ from akita.tools.base import ShellTools
+ from akita.core.plugins import PluginManager
  from akita.tools.context import ContextBuilder
  from akita.schemas.review import ReviewResult
+ from akita.core.trace import ReasoningTrace
+ from akita.reasoning.session import ConversationSession
  import json
  from rich.console import Console

  console = Console()
-
+
  class ReasoningEngine:
      def __init__(self, model: AIModel):
          self.model = model
+         self.plugin_manager = PluginManager()
+         self.plugin_manager.discover_all()
+         self.trace = ReasoningTrace()
+         self.session: Optional[ConversationSession] = None

      def run_review(self, path: str) -> ReviewResult:
          """
@@ -91,27 +98,46 @@ class ReasoningEngine:
          ])
          return response.content

-     def run_solve(self, query: str, path: str = ".") -> str:
+     def run_solve(self, query: str, path: str = ".", session: Optional[ConversationSession] = None) -> str:
          """
          Generates a Unified Diff solution for the given query.
+         Supports iterative refinement if a session is provided.
          """
-         console.print(f"🔍 [bold]Building context for solution...[/]")
-         builder = ContextBuilder(path)
-         snapshot = builder.build()
-
-         files_str = "\n---\n".join([f"FILE: {f.path}\nCONTENT:\n{f.content}" for f in snapshot.files[:10]])  # Limit for solve
+         self.trace.add_step("Solve", f"Starting solve for query: {query}")

-         system_prompt = (
-             "You are an Expert Programmer. Solve the requested task by providing code changes in Unified Diff format. "
-             "Respond ONLY with the Diff block. Use +++ and --- with file paths relative to project root."
-         )
-         user_prompt = f"Task: {query}\n\nContext:\n{files_str}\n\nGenerate the Unified Diff."
+         if not session:
+             self.trace.add_step("Context", f"Building context for {path}")
+             builder = ContextBuilder(path)
+             snapshot = builder.build(query=query)
+
+             files_str = "\n---\n".join([f"FILE: {f.path}\nCONTENT:\n{f.content}" for f in snapshot.files[:10]])
+
+             rag_str = ""
+             if snapshot.rag_snippets:
+                 rag_str = "\n\nRELEVANT SNIPPETS (RAG):\n" + "\n".join([
+                     f"- {s['path']} ({s['name']}):\n{s['content']}" for s in snapshot.rag_snippets
+                 ])
+
+             tools_info = "\n".join([f"- {t['name']}: {t['description']}" for t in self.plugin_manager.get_all_tools()])
+
+             system_prompt = (
+                 "You are an Expert Programmer. Solve the requested task by providing code changes in Unified Diff format. "
+                 "Respond ONLY with the Diff block. Use +++ and --- with file paths relative to project root.\n\n"
+                 f"Available Tools:\n{tools_info}"
+             )
+
+             session = ConversationSession()
+             session.add_message("system", system_prompt)
+             session.add_message("user", f"Task: {query}\n\nContext:\n{files_str}{rag_str}")
+             self.session = session
+         else:
+             session.add_message("user", query)
+
+         console.print("🤖 [bold green]Thinking...[/]")
+         response = self.model.chat(session.get_messages_dict())
+         session.add_message("assistant", response.content)

-         console.print("🤖 [bold green]Generating solution...[/]")
-         response = self.model.chat([
-             {"role": "system", "content": system_prompt},
-             {"role": "user", "content": user_prompt}
-         ])
+         self.trace.add_step("LLM Response", "Received solution from model")
          return response.content

      def run_pipeline(self, task: str):
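The new session parameter is what enables iterative refinement: the first call builds context and seeds the session, and later calls just append user turns to it. A usage sketch (the queries are hypothetical, for illustration):

    engine = ReasoningEngine(get_model())
    first_diff = engine.run_solve("add a --verbose flag", path=".")
    # Refine in the same conversation, reusing the stored session:
    revised_diff = engine.run_solve("also update the README", session=engine.session)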
@@ -0,0 +1,15 @@
+ from typing import List, Dict, Any
+ from pydantic import BaseModel, Field
+
+ class ChatMessage(BaseModel):
+     role: str
+     content: str
+
+ class ConversationSession(BaseModel):
+     messages: List[ChatMessage] = Field(default_factory=list)
+
+     def add_message(self, role: str, content: str):
+         self.messages.append(ChatMessage(role=role, content=content))
+
+     def get_messages_dict(self) -> List[Dict[str, str]]:
+         return [m.model_dump() for m in self.messages]
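ConversationSession is a thin pydantic wrapper over the chat history; get_messages_dict produces the role/content list most chat APIs expect:

    session = ConversationSession()
    session.add_message("system", "You are an Expert Programmer.")
    session.add_message("user", "Add a --verbose flag.")
    print(session.get_messages_dict())
    # [{'role': 'system', 'content': 'You are an Expert Programmer.'},
    #  {'role': 'user', 'content': 'Add a --verbose flag.'}]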
akita/tools/base.py CHANGED
@@ -15,7 +15,12 @@ class FileSystemTools:
          return f.read()

      @staticmethod
-     def list_files(path: str) -> List[str]:
+     def write_file(path: str, content: str):
+         with open(path, 'w', encoding='utf-8') as f:
+             f.write(content)
+
+     @staticmethod
+     def list_dir(path: str) -> List[str]:
          return os.listdir(path)

  class ShellTools:
akita/tools/context.py CHANGED
@@ -1,16 +1,18 @@
  import os
  from pathlib import Path
- from typing import List, Dict, Optional
+ from typing import Any, List, Dict, Optional
  from pydantic import BaseModel

  class FileContext(BaseModel):
      path: str
      content: str
      extension: str
+     summary: Optional[str] = None  # New field for semantic summary

  class ContextSnapshot(BaseModel):
      files: List[FileContext]
      project_structure: List[str]
+     rag_snippets: Optional[List[Dict[str, Any]]] = None

  class ContextBuilder:
      def __init__(
@@ -19,26 +21,49 @@ class ContextBuilder:
          extensions: Optional[List[str]] = None,
          exclude_dirs: Optional[List[str]] = None,
          max_file_size_kb: int = 50,
-         max_files: int = 50
+         max_files: int = 50,
+         use_semantical_context: bool = True
      ):
          self.base_path = Path(base_path)
          self.extensions = extensions or [".py", ".js", ".ts", ".cpp", ".h", ".toml", ".md", ".json"]
          self.exclude_dirs = exclude_dirs or [".git", ".venv", "node_modules", "__pycache__", "dist", "build"]
          self.max_file_size_kb = max_file_size_kb
          self.max_files = max_files
+         self.use_semantical_context = use_semantical_context
+
+         if self.use_semantical_context:
+             try:
+                 from akita.core.ast_utils import ASTParser
+                 from akita.core.indexing import CodeIndexer
+                 self.ast_parser = ASTParser()
+                 self.indexer = CodeIndexer(str(self.base_path))
+             except ImportError:
+                 self.ast_parser = None
+                 self.indexer = None

-     def build(self) -> ContextSnapshot:
-         """Scan the path and build a context snapshot."""
+     def build(self, query: Optional[str] = None) -> ContextSnapshot:
+         """
+         Scan the path and build a context snapshot.
+         If a query is provided and an indexer is available, include RAG snippets.
+         """
          files_context = []
          project_structure = []
+         rag_snippets = None

+         if query and self.indexer:
+             try:
+                 # Ensure the index exists (lazy indexing for now)
+                 # In production, we'd have a separate command or check timestamps
+                 rag_snippets = self.indexer.search(query, n_results=10)
+             except Exception:
+                 pass
+
          if self.base_path.is_file():
              if self._should_include_file(self.base_path):
                  files_context.append(self._read_file(self.base_path))
                  project_structure.append(str(self.base_path.name))
          else:
              for root, dirs, files in os.walk(self.base_path):
-                 # Filter out excluded directories
                  dirs[:] = [d for d in dirs if d not in self.exclude_dirs]

                  rel_root = os.path.relpath(root, self.base_path)
@@ -54,7 +79,11 @@ class ContextBuilder:
                      files_context.append(context)
                      project_structure.append(os.path.join(rel_root, file))

-         return ContextSnapshot(files=files_context, project_structure=project_structure)
+         return ContextSnapshot(
+             files=files_context,
+             project_structure=project_structure,
+             rag_snippets=rag_snippets
+         )

      def _should_include_file(self, path: Path) -> bool:
          if path.name == ".env" or path.suffix == ".env":
@@ -66,8 +95,12 @@ class ContextBuilder:
          if not path.exists():
              return False

-         # Check size
-         if path.stat().st_size > self.max_file_size_kb * 1024:
+         # Check size (we can be more lenient if using semantic summaries)
+         size_limit = self.max_file_size_kb * 1024
+         if self.use_semantical_context:
+             size_limit *= 2  # Allow larger files if we can summarize them
+
+         if path.stat().st_size > size_limit:
              return False

          return True
@@ -76,10 +109,22 @@ class ContextBuilder:
          try:
              with open(path, 'r', encoding='utf-8') as f:
                  content = f.read()
+
+             summary = None
+             if self.use_semantical_context and self.ast_parser and path.suffix == ".py":
+                 try:
+                     defs = self.ast_parser.get_definitions(str(path))
+                     if defs:
+                         summary_lines = [f"{d['type'].upper()} {d['name']} (L{d['start_line']}-L{d['end_line']})" for d in defs]
+                         summary = "\n".join(summary_lines)
+                 except Exception:
+                     pass
+
              return FileContext(
                  path=str(path.relative_to(self.base_path) if self.base_path.is_dir() else path.name),
                  content=content,
-                 extension=path.suffix
+                 extension=path.suffix,
+                 summary=summary
              )
          except Exception:
              return None
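A usage sketch of the updated builder (the query string is illustrative; rag_snippets stays None when no indexer is available or no query is given):

    builder = ContextBuilder(".", use_semantical_context=True)
    snapshot = builder.build(query="where are unified diffs applied?")
    print(len(snapshot.files), "files scanned")
    for s in snapshot.rag_snippets or []:
        print(s["path"], s["name"])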
akita/tools/diff.py CHANGED
@@ -1,35 +1,110 @@
  import os
+ import shutil
+ import pathlib
  from pathlib import Path
- import re
+ import whatthepatch
+ from typing import List, Tuple, Optional

  class DiffApplier:
      @staticmethod
-     def apply_unified_diff(diff_text: str, base_path: str = "."):
+     def apply_unified_diff(diff_text: str, base_path: str = ".") -> bool:
          """
-         Simplistic Unified Diff applier.
-         In a real scenario, this would use a robust library like 'patch-py' or 'whatthepatch'.
-         For AkitaLLM, we keep it simple for now.
+         Applies a unified diff to files in the base_path.
+         Includes backup and rollback logic for atomicity.
          """
-         # Split by file
-         file_diffs = re.split(r'--- (.*?)\n\+\+\+ (.*?)\n', diff_text)
-
-         # Pattern extraction is tricky with regex, let's try a safer approach
-         lines = diff_text.splitlines()
-         current_file = None
-         new_content = []
-
-         # This is a VERY placeholder implementation for safety.
-         # Applying diffs manually is high risk without a dedicated library.
-         # For the MVP, we will log what would happen.
-
-         print(f"DEBUG: DiffApplier would process {len(lines)} lines of diff.")
-
-         # Real logic would:
-         # 1. Path identification (--- / +++)
-         # 2. Hunk identification (@@)
-         # 3. Line modification
-
-         return True
+         patches = list(whatthepatch.parse_patch(diff_text))
+         if not patches:
+             print("ERROR: No valid patches found in the diff text.")
+             return False
+
+         backups: List[Tuple[Path, Optional[Path]]] = []
+         base = Path(base_path)
+         backup_dir = base / ".akita" / "backups"
+         backup_dir.mkdir(parents=True, exist_ok=True)
+
+         try:
+             for patch in patches:
+                 if not patch.header:
+                     continue
+
+                 # whatthepatch identifies the target file in the header
+                 # We usually want the 'new' filename (the +++ part)
+                 rel_path = patch.header.new_path
+                 is_new = (patch.header.old_path == "/dev/null")
+                 is_delete = (patch.header.new_path == "/dev/null")
+
+                 if is_new:
+                     rel_path = patch.header.new_path
+                 elif is_delete:
+                     rel_path = patch.header.old_path
+                 else:
+                     rel_path = patch.header.new_path or patch.header.old_path
+
+                 if not rel_path or rel_path == "/dev/null":
+                     continue
+
+                 # Clean up the path (paths sometimes carry a/ or b/ prefixes)
+                 if rel_path.startswith("a/") or rel_path.startswith("b/"):
+                     rel_path = rel_path[2:]
+
+                 target_file = (base / rel_path).resolve()
+
+                 if not is_new and not target_file.exists():
+                     print(f"ERROR: Target file {target_file} does not exist for patching.")
+                     return False
+
+                 # 1. Create backup
+                 if target_file.exists():
+                     backup_file = backup_dir / f"{target_file.name}.bak"
+                     shutil.copy2(target_file, backup_file)
+                     backups.append((target_file, backup_file))
+                 else:
+                     backups.append((target_file, None))  # Mark for deletion on rollback if it's a new file
+
+                 # 2. Apply patch
+                 content = ""
+                 if target_file.exists():
+                     with open(target_file, "r", encoding="utf-8") as f:
+                         content = f.read()
+
+                 lines = content.splitlines()
+                 # whatthepatch apply_diff returns a generator of lines
+                 patched_lines = whatthepatch.apply_diff(patch, lines)
+
+                 if patched_lines is None:
+                     print(f"ERROR: Failed to apply patch to {rel_path}.")
+                     raise Exception(f"Patch failure on {rel_path}")
+
+                 # 3. Write new content
+                 target_file.parent.mkdir(parents=True, exist_ok=True)
+                 with open(target_file, "w", encoding="utf-8") as f:
+                     f.write("\n".join(patched_lines) + "\n")
+
+             print(f"SUCCESS: Applied {len(patches)} patches successfully.")
+
+             # 4. Pre-flight validation
+             # Run tests to ensure the patch didn't break anything
+             if (base / "tests").exists():
+                 print("🧪 Running pre-flight validation (pytest)...")
+                 import subprocess
+                 # Run pytest in the base_path
+                 result = subprocess.run(["pytest"], cwd=str(base), capture_output=True, text=True)
+                 if result.returncode != 0:
+                     print(f"❌ Validation FAILED:\n{result.stdout}")
+                     raise Exception("Pre-flight validation failed. Tests are broken.")
+                 else:
+                     print("✅ Pre-flight validation passed!")
+
+             return True
+
+         except Exception as e:
+             print(f"CRITICAL ERROR: {e}. Starting rollback...")
+             for target, backup in backups:
+                 if backup and backup.exists():
+                     shutil.move(str(backup), str(target))
+                 elif not backup and target.exists():
+                     target.unlink()  # Delete newly created file
+             return False

      @staticmethod
      def apply_whole_file(file_path: str, content: str):
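To close, a sketch of the apply/rollback flow with a trivial single-file patch (file contents are illustrative; on any failure, including a failing pytest run under ./tests, the backups in .akita/backups are restored):

    diff_lines = [
        "--- a/hello.py",
        "+++ b/hello.py",
        "@@ -1,1 +1,1 @@",
        '-print("hello")',
        '+print("hello, world")',
    ]
    ok = DiffApplier.apply_unified_diff("\n".join(diff_lines), base_path=".")
    print("applied" if ok else "rolled back")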