akitallm 0.1.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in the supported public registries. It is provided for informational purposes only.
- akita/__init__.py +1 -1
- akita/cli/main.py +153 -24
- akita/core/ast_utils.py +77 -0
- akita/core/config.py +12 -2
- akita/core/indexing.py +94 -0
- akita/core/plugins.py +81 -0
- akita/core/providers.py +181 -0
- akita/core/trace.py +18 -0
- akita/models/base.py +12 -7
- akita/plugins/__init__.py +1 -0
- akita/plugins/files.py +34 -0
- akita/reasoning/engine.py +44 -18
- akita/reasoning/session.py +15 -0
- akita/tools/base.py +6 -1
- akita/tools/context.py +54 -9
- akita/tools/diff.py +100 -25
- akita/tools/git.py +79 -0
- {akitallm-0.1.1.dist-info → akitallm-1.1.0.dist-info}/METADATA +8 -11
- akitallm-1.1.0.dist-info/RECORD +24 -0
- akitallm-1.1.0.dist-info/entry_points.txt +5 -0
- akitallm-0.1.1.dist-info/RECORD +0 -15
- akitallm-0.1.1.dist-info/entry_points.txt +0 -2
- {akitallm-0.1.1.dist-info → akitallm-1.1.0.dist-info}/WHEEL +0 -0
- {akitallm-0.1.1.dist-info → akitallm-1.1.0.dist-info}/licenses/LICENSE +0 -0
- {akitallm-0.1.1.dist-info → akitallm-1.1.0.dist-info}/top_level.txt +0 -0
akita/core/providers.py
ADDED
@@ -0,0 +1,181 @@
+from abc import ABC, abstractmethod
+from typing import List, Optional
+import requests
+from pydantic import BaseModel
+
+class ModelInfo(BaseModel):
+    id: str
+    name: Optional[str] = None
+
+class BaseProvider(ABC):
+    @property
+    @abstractmethod
+    def name(self) -> str:
+        pass
+
+    @abstractmethod
+    def validate_key(self, api_key: str) -> bool:
+        pass
+
+    @abstractmethod
+    def list_models(self, api_key: str) -> List[ModelInfo]:
+        pass
+
+class OpenAIProvider(BaseProvider):
+    @property
+    def name(self) -> str:
+        return "openai"
+
+    def validate_key(self, api_key: str) -> bool:
+        if not api_key.startswith("sk-"):
+            return False
+        # Simple validation request
+        try:
+            response = requests.get(
+                "https://api.openai.com/v1/models",
+                headers={"Authorization": f"Bearer {api_key}"},
+                timeout=5
+            )
+            return response.status_code == 200
+        except Exception:
+            return False
+
+    def list_models(self, api_key: str) -> List[ModelInfo]:
+        response = requests.get(
+            "https://api.openai.com/v1/models",
+            headers={"Authorization": f"Bearer {api_key}"},
+            timeout=10
+        )
+        response.raise_for_status()
+        data = response.json()
+        exclude_keywords = ["vision", "instruct", "audio", "realtime", "tts", "dall-e", "embedding", "moderation", "davinci", "babbage", "curie", "ada"]
+
+        models = []
+        for m in data["data"]:
+            model_id = m["id"]
+            if not any(kw in model_id.lower() for kw in exclude_keywords):
+                if model_id.startswith("gpt-") or model_id.startswith("o1") or model_id.startswith("o3"):
+                    models.append(ModelInfo(id=model_id))
+        return sorted(models, key=lambda x: x.id)
+
+class AnthropicProvider(BaseProvider):
+    @property
+    def name(self) -> str:
+        return "anthropic"
+
+    def validate_key(self, api_key: str) -> bool:
+        if not api_key.startswith("sk-ant-"):
+            return False
+        # Anthropic validation usually requires a full request, but we'll check prefix for now
+        # or do a no-op call if possible.
+        return True
+
+    def list_models(self, api_key: str) -> List[ModelInfo]:
+        # Anthropic doesn't have a public models list API like OpenAI
+        return [
+            ModelInfo(id="claude-3-5-sonnet-latest", name="Claude 3.5 Sonnet (Latest)"),
+            ModelInfo(id="claude-3-5-haiku-latest", name="Claude 3.5 Haiku (Latest)"),
+            ModelInfo(id="claude-3-opus-20240229", name="Claude 3 Opus"),
+            ModelInfo(id="claude-3-sonnet-20240229", name="Claude 3 Sonnet"),
+            ModelInfo(id="claude-3-haiku-20240307", name="Claude 3 Haiku"),
+        ]
+
+class OllamaProvider(BaseProvider):
+    @property
+    def name(self) -> str:
+        return "ollama"
+
+    def validate_key(self, api_key: str) -> bool:
+        # Ollama doesn't use keys by default, we just check if it's reachable
+        try:
+            response = requests.get("http://localhost:11434/api/tags", timeout=2)
+            return response.status_code == 200
+        except Exception:
+            return False
+
+    def list_models(self, api_key: str) -> List[ModelInfo]:
+        response = requests.get("http://localhost:11434/api/tags", timeout=5)
+        response.raise_for_status()
+        data = response.json()
+        return [ModelInfo(id=m["name"]) for m in data["models"]]
+
+class GeminiProvider(BaseProvider):
+    @property
+    def name(self) -> str:
+        return "gemini"
+
+    def validate_key(self, api_key: str) -> bool:
+        if not api_key.startswith("AIza"):
+            return False
+        return True
+
+    def list_models(self, api_key: str) -> List[ModelInfo]:
+        # Gemini API URL for listing models
+        url = f"https://generativelanguage.googleapis.com/v1beta/models?key={api_key}"
+        response = requests.get(url, timeout=10)
+        response.raise_for_status()
+        data = response.json()
+
+        exclude_keywords = ["nano", "banana", "vision", "embedding", "aqa", "learnlm"]
+
+        models = []
+        for m in data["models"]:
+            model_id = m["name"].split("/")[-1]
+            display_name = m["displayName"]
+
+            # Check if it supports generation and doesn't have excluded keywords
+            if "generateContent" in m["supportedGenerationMethods"]:
+                if not any(kw in model_id.lower() or kw in display_name.lower() for kw in exclude_keywords):
+                    models.append(ModelInfo(id=model_id, name=display_name))
+
+        return models
+
+class GroqProvider(BaseProvider):
+    @property
+    def name(self) -> str:
+        return "groq"
+
+    def validate_key(self, api_key: str) -> bool:
+        if not api_key.startswith("gsk_"):
+            return False
+        return True
+
+    def list_models(self, api_key: str) -> List[ModelInfo]:
+        # Groq uses OpenAI-compatible models endpoint
+        response = requests.get(
+            "https://api.groq.com/openai/v1/models",
+            headers={"Authorization": f"Bearer {api_key}"},
+            timeout=10
+        )
+        response.raise_for_status()
+        data = response.json()
+
+        # Filter for text models
+        exclude_keywords = ["vision", "audio"]
+        models = []
+        for m in data["data"]:
+            model_id = m["id"]
+            if not any(kw in model_id.lower() for kw in exclude_keywords):
+                models.append(ModelInfo(id=model_id))
+        return sorted(models, key=lambda x: x.id)
+
+def detect_provider(api_key: str) -> Optional[BaseProvider]:
+    """
+    Attempts to detect the provider based on the API key or environment.
+    """
+    if api_key.lower() == "ollama":
+        return OllamaProvider()
+
+    if api_key.startswith("sk-ant-"):
+        return AnthropicProvider()
+
+    if api_key.startswith("gsk_"):
+        return GroqProvider()
+
+    if api_key.startswith("sk-"):
+        return OpenAIProvider()
+
+    if api_key.startswith("AIza"):
+        return GeminiProvider()
+
+    return None
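Taken together, these classes let the CLI route a raw key string to a provider by prefix. A minimal usage sketch; the key below is a placeholder, not a real credential:

# Hypothetical usage of the new provider layer (placeholder key).
from akita.core.providers import detect_provider

api_key = "gsk_placeholder"  # the "gsk_" prefix routes to GroqProvider
provider = detect_provider(api_key)

if provider is None:
    print("Unrecognized key format; ask the user to choose a provider.")
elif provider.validate_key(api_key):
    for model in provider.list_models(api_key):
        print(model.name or model.id)  # ModelInfo.name is optional

Note that the prefix ordering in detect_provider matters: "sk-ant-" must be tested before the more general "sk-".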
akita/core/trace.py
ADDED
@@ -0,0 +1,18 @@
+from typing import List, Dict, Any
+from datetime import datetime
+from pydantic import BaseModel, Field
+
+class TraceStep(BaseModel):
+    timestamp: datetime = Field(default_factory=datetime.now)
+    action: str
+    details: str
+    metadata: Dict[str, Any] = Field(default_factory=dict)
+
+class ReasoningTrace(BaseModel):
+    steps: List[TraceStep] = Field(default_factory=list)
+
+    def add_step(self, action: str, details: str, metadata: Dict[str, Any] = None):
+        self.steps.append(TraceStep(action=action, details=details, metadata=metadata or {}))
+
+    def __str__(self):
+        return "\n".join([f"[{s.timestamp.strftime('%H:%M:%S')}] {s.action}: {s.details}" for s in self.steps])
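ReasoningTrace is a plain Pydantic accumulator; a quick sketch of recording and rendering steps (the step text is illustrative):

from akita.core.trace import ReasoningTrace

trace = ReasoningTrace()
trace.add_step("Solve", "Starting solve for query: add logging")
trace.add_step("Context", "Building context for .", metadata={"files": 12})

# __str__ renders one "[HH:MM:SS] action: details" line per step
print(trace)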
akita/models/base.py
CHANGED
@@ -35,14 +35,19 @@ def get_model(model_name: Optional[str] = None) -> AIModel:
     """
     Get an AIModel instance based on config or provided name.
     """
+    provider = get_config_value("model", "provider", "openai")
+    api_key = get_config_value("model", "api_key")
+
     if model_name is None:
         model_name = get_config_value("model", "name", "gpt-4o-mini")
 
-
-
-
-
-
-
+    # LiteLLM wants "provider/model_name" for non-OpenAI providers
+    if provider == "openai":
+        full_model_name = model_name
+    elif provider == "gemini":
+        full_model_name = f"gemini/{model_name}"
+    else:
+        full_model_name = f"{provider}/{model_name}"
 
-
+    # For Ollama, we might need a base_url, but for now we assume default
+    return AIModel(model_name=full_model_name, api_key=api_key)
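The new branch mirrors LiteLLM's provider/model naming convention. A minimal sketch of the resulting names; the provider/model pairs below are example values, not project defaults:

# Illustration of the naming rule in get_model() (example values only).
for provider, model_name in [
    ("openai", "gpt-4o-mini"),     # passed through unchanged
    ("gemini", "gemini-1.5-pro"),  # prefixed as "gemini/..."
    ("ollama", "llama3"),          # generic "provider/model" fallback
]:
    full = model_name if provider == "openai" else f"{provider}/{model_name}"
    print(f"{provider}: {full}")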
akita/plugins/__init__.py
ADDED
@@ -0,0 +1 @@
+# Official AkitaLLM Plugins
akita/plugins/files.py
ADDED
@@ -0,0 +1,34 @@
+from akita.core.plugins import AkitaPlugin
+from akita.tools.base import FileSystemTools
+from typing import List, Dict, Any
+
+class FilesPlugin(AkitaPlugin):
+    @property
+    def name(self) -> str:
+        return "files"
+
+    @property
+    def description(self) -> str:
+        return "Standard filesystem operations (read, write, list)."
+
+    def get_tools(self) -> List[Dict[str, Any]]:
+        return [
+            {
+                "name": "read_file",
+                "description": "Read content from a file.",
+                "parameters": {"path": "string"},
+                "func": FileSystemTools.read_file
+            },
+            {
+                "name": "write_file",
+                "description": "Write content to a file.",
+                "parameters": {"path": "string", "content": "string"},
+                "func": FileSystemTools.write_file
+            },
+            {
+                "name": "list_dir",
+                "description": "List files in a directory.",
+                "parameters": {"path": "string"},
+                "func": FileSystemTools.list_dir
+            }
+        ]
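Each tool entry pairs a name and parameter hints with the callable itself, so a dispatcher can invoke tools by name. A hedged sketch of that dispatch, assuming FilesPlugin is concrete as shown:

from akita.plugins.files import FilesPlugin

tools = {t["name"]: t for t in FilesPlugin().get_tools()}

# "func" holds the FileSystemTools staticmethod directly
print(tools["list_dir"]["func"]("."))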
akita/reasoning/engine.py
CHANGED
@@ -1,16 +1,23 @@
 from typing import List, Dict, Any, Optional
 from akita.models.base import AIModel, get_model
-from akita.tools.base import ShellTools
+from akita.tools.base import ShellTools
+from akita.core.plugins import PluginManager
 from akita.tools.context import ContextBuilder
 from akita.schemas.review import ReviewResult
+from akita.core.trace import ReasoningTrace
+from akita.reasoning.session import ConversationSession
 import json
 from rich.console import Console
 
 console = Console()
-
+
 class ReasoningEngine:
     def __init__(self, model: AIModel):
         self.model = model
+        self.plugin_manager = PluginManager()
+        self.plugin_manager.discover_all()
+        self.trace = ReasoningTrace()
+        self.session: Optional[ConversationSession] = None
 
     def run_review(self, path: str) -> ReviewResult:
         """
@@ -91,27 +98,46 @@ class ReasoningEngine:
         ])
         return response.content
 
-    def run_solve(self, query: str, path: str = ".") -> str:
+    def run_solve(self, query: str, path: str = ".", session: Optional[ConversationSession] = None) -> str:
         """
         Generates a Unified Diff solution for the given query.
+        Supports iterative refinement if a session is provided.
         """
-
-        builder = ContextBuilder(path)
-        snapshot = builder.build()
-
-        files_str = "\n---\n".join([f"FILE: {f.path}\nCONTENT:\n{f.content}" for f in snapshot.files[:10]]) # Limit for solve
+        self.trace.add_step("Solve", f"Starting solve for query: {query}")
 
-
-        "
-
-
-
+        if not session:
+            self.trace.add_step("Context", f"Building context for {path}")
+            builder = ContextBuilder(path)
+            snapshot = builder.build(query=query)
+
+            files_str = "\n---\n".join([f"FILE: {f.path}\nCONTENT:\n{f.content}" for f in snapshot.files[:10]])
+
+            rag_str = ""
+            if snapshot.rag_snippets:
+                rag_str = "\n\nRELEVANT SNIPPETS (RAG):\n" + "\n".join([
+                    f"- {s['path']} ({s['name']}):\n{s['content']}" for s in snapshot.rag_snippets
+                ])
+
+            tools_info = "\n".join([f"- {t['name']}: {t['description']}" for t in self.plugin_manager.get_all_tools()])
+
+            system_prompt = (
+                "You are an Expert Programmer. Solve the requested task by providing code changes in Unified Diff format. "
+                "Respond ONLY with the Diff block. Use +++ and --- with file paths relative to project root.\n\n"
+                f"Available Tools:\n{tools_info}"
+            )
+
+            session = ConversationSession()
+            session.add_message("system", system_prompt)
+            session.add_message("user", f"Task: {query}\n\nContext:\n{files_str}{rag_str}")
+            self.session = session
+        else:
+            session.add_message("user", query)
+
+        console.print("🤖 [bold green]Thinking...[/]")
+        response = self.model.chat(session.get_messages_dict())
+        session.add_message("assistant", response.content)
 
-
-        response = self.model.chat([
-            {"role": "system", "content": system_prompt},
-            {"role": "user", "content": user_prompt}
-        ])
+        self.trace.add_step("LLM Response", "Received solution from model")
         return response.content
 
     def run_pipeline(self, task: str):
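run_solve now threads a ConversationSession through, so a follow-up call refines the first diff instead of starting over. A sketch of the intended loop, assuming a configured model is available via get_model:

from akita.models.base import get_model
from akita.reasoning.engine import ReasoningEngine

engine = ReasoningEngine(get_model())

# First call: no session, so the engine builds context, RAG snippets,
# and the system prompt itself, then stores the session on the engine.
diff_v1 = engine.run_solve("add a --verbose flag", path=".")

# Second call: reuse engine.session; only the new user message is
# appended, and the earlier context rides along in the history.
diff_v2 = engine.run_solve("also update the README", session=engine.session)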
akita/reasoning/session.py
ADDED
@@ -0,0 +1,15 @@
+from typing import List, Dict, Any
+from pydantic import BaseModel, Field
+
+class ChatMessage(BaseModel):
+    role: str
+    content: str
+
+class ConversationSession(BaseModel):
+    messages: List[ChatMessage] = Field(default_factory=list)
+
+    def add_message(self, role: str, content: str):
+        self.messages.append(ChatMessage(role=role, content=content))
+
+    def get_messages_dict(self) -> List[Dict[str, str]]:
+        return [m.model_dump() for m in self.messages]
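ConversationSession is the minimal message log behind that loop; get_messages_dict() produces the role/content dicts that AIModel.chat() consumes in run_solve. A quick sketch (message text is illustrative):

from akita.reasoning.session import ConversationSession

session = ConversationSession()
session.add_message("system", "You are an Expert Programmer.")
session.add_message("user", "Task: add a --verbose flag")

# Yields [{"role": "system", "content": ...}, {"role": "user", "content": ...}]
print(session.get_messages_dict())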
akita/tools/base.py
CHANGED
@@ -15,7 +15,12 @@ class FileSystemTools:
         return f.read()
 
     @staticmethod
-    def
+    def write_file(path: str, content: str):
+        with open(path, 'w', encoding='utf-8') as f:
+            f.write(content)
+
+    @staticmethod
+    def list_dir(path: str) -> List[str]:
         return os.listdir(path)
 
 class ShellTools:
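A quick round-trip through the completed helpers, using a temporary directory so nothing real is touched:

import os
import tempfile
from akita.tools.base import FileSystemTools

with tempfile.TemporaryDirectory() as tmp:
    target = os.path.join(tmp, "note.txt")
    FileSystemTools.write_file(target, "hello")
    assert FileSystemTools.read_file(target) == "hello"
    print(FileSystemTools.list_dir(tmp))  # ['note.txt']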
akita/tools/context.py
CHANGED
@@ -1,16 +1,18 @@
 import os
 from pathlib import Path
-from typing import List, Dict, Optional
+from typing import Any, List, Dict, Optional
 from pydantic import BaseModel
 
 class FileContext(BaseModel):
     path: str
     content: str
     extension: str
+    summary: Optional[str] = None  # New field for semantic summary
 
 class ContextSnapshot(BaseModel):
     files: List[FileContext]
     project_structure: List[str]
+    rag_snippets: Optional[List[Dict[str, Any]]] = None
 
 class ContextBuilder:
     def __init__(
@@ -19,26 +21,49 @@ class ContextBuilder:
         extensions: Optional[List[str]] = None,
         exclude_dirs: Optional[List[str]] = None,
         max_file_size_kb: int = 50,
-        max_files: int = 50
+        max_files: int = 50,
+        use_semantical_context: bool = True
     ):
         self.base_path = Path(base_path)
         self.extensions = extensions or [".py", ".js", ".ts", ".cpp", ".h", ".toml", ".md", ".json"]
         self.exclude_dirs = exclude_dirs or [".git", ".venv", "node_modules", "__pycache__", "dist", "build"]
         self.max_file_size_kb = max_file_size_kb
         self.max_files = max_files
+        self.use_semantical_context = use_semantical_context
+
+        if self.use_semantical_context:
+            try:
+                from akita.core.ast_utils import ASTParser
+                from akita.core.indexing import CodeIndexer
+                self.ast_parser = ASTParser()
+                self.indexer = CodeIndexer(str(self.base_path))
+            except ImportError:
+                self.ast_parser = None
+                self.indexer = None
 
-    def build(self) -> ContextSnapshot:
-        """
+    def build(self, query: Optional[str] = None) -> ContextSnapshot:
+        """
+        Scan the path and build a context snapshot.
+        If a query is provided and indexer is available, it includes RAG snippets.
+        """
         files_context = []
         project_structure = []
+        rag_snippets = None
 
+        if query and self.indexer:
+            try:
+                # Ensure index exists (lazy indexing for now)
+                # In production, we'd have a separate command or check timestamps
+                rag_snippets = self.indexer.search(query, n_results=10)
+            except Exception:
+                pass
+
         if self.base_path.is_file():
            if self._should_include_file(self.base_path):
                 files_context.append(self._read_file(self.base_path))
             project_structure.append(str(self.base_path.name))
         else:
             for root, dirs, files in os.walk(self.base_path):
-                # Filter out excluded directories
                 dirs[:] = [d for d in dirs if d not in self.exclude_dirs]
 
                 rel_root = os.path.relpath(root, self.base_path)
@@ -54,7 +79,11 @@ class ContextBuilder:
                     files_context.append(context)
                     project_structure.append(os.path.join(rel_root, file))
 
-        return ContextSnapshot(
+        return ContextSnapshot(
+            files=files_context,
+            project_structure=project_structure,
+            rag_snippets=rag_snippets
+        )
 
     def _should_include_file(self, path: Path) -> bool:
         if path.name == ".env" or path.suffix == ".env":
@@ -66,8 +95,12 @@ class ContextBuilder:
         if not path.exists():
             return False
 
-        # Check size
-
+        # Check size (we can be more lenient if using semantic summaries)
+        size_limit = self.max_file_size_kb * 1024
+        if self.use_semantical_context:
+            size_limit *= 2  # Allow larger files if we can summarize them
+
+        if path.stat().st_size > size_limit:
             return False
 
         return True
@@ -76,10 +109,22 @@ class ContextBuilder:
         try:
             with open(path, 'r', encoding='utf-8') as f:
                 content = f.read()
+
+            summary = None
+            if self.use_semantical_context and self.ast_parser and path.suffix == ".py":
+                try:
+                    defs = self.ast_parser.get_definitions(str(path))
+                    if defs:
+                        summary_lines = [f"{d['type'].upper()} {d['name']} (L{d['start_line']}-L{d['end_line']})" for d in defs]
+                        summary = "\n".join(summary_lines)
+                except Exception:
+                    pass
+
             return FileContext(
                 path=str(path.relative_to(self.base_path) if self.base_path.is_dir() else path.name),
                 content=content,
-                extension=path.suffix
+                extension=path.suffix,
+                summary=summary
             )
         except Exception:
             return None
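With use_semantical_context enabled, build(query=...) attaches RAG snippets whenever the optional ast_utils/indexing modules import cleanly; otherwise rag_snippets stays None. A sketch under that assumption:

from akita.tools.context import ContextBuilder

builder = ContextBuilder(".", max_files=25)
snapshot = builder.build(query="where are diffs applied?")

print(len(snapshot.files), "files scanned")
if snapshot.rag_snippets:  # None without a query or a working indexer
    for s in snapshot.rag_snippets:
        print(s["path"], s["name"])  # snippet dicts also carry "content"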
akita/tools/diff.py
CHANGED
@@ -1,35 +1,110 @@
 import os
+import shutil
+import pathlib
 from pathlib import Path
-import
+import whatthepatch
+from typing import List, Tuple, Optional
 
 class DiffApplier:
     @staticmethod
-    def apply_unified_diff(diff_text: str, base_path: str = "."):
+    def apply_unified_diff(diff_text: str, base_path: str = ".") -> bool:
         """
-
-
-        For AkitaLLM, we keep it simple for now.
+        Applies a unified diff to files in the base_path.
+        Includes backup and rollback logic for atomicity.
         """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        patches = list(whatthepatch.parse_patch(diff_text))
+        if not patches:
+            print("ERROR: No valid patches found in the diff text.")
+            return False
+
+        backups: List[Tuple[Path, Path]] = []
+        base = Path(base_path)
+        backup_dir = base / ".akita" / "backups"
+        backup_dir.mkdir(parents=True, exist_ok=True)
+
+        try:
+            for patch in patches:
+                if not patch.header:
+                    continue
+
+                # whatthepatch identifies the target file in the header
+                # We usually want the 'new' filename (the +++ part)
+                rel_path = patch.header.new_path
+                is_new = (patch.header.old_path == "/dev/null")
+                is_delete = (patch.header.new_path == "/dev/null")
+
+                if is_new:
+                    rel_path = patch.header.new_path
+                elif is_delete:
+                    rel_path = patch.header.old_path
+                else:
+                    rel_path = patch.header.new_path or patch.header.old_path
+
+                if not rel_path or rel_path == "/dev/null":
+                    continue
+
+                # Clean up path (sometimes they have a/ or b/ prefixes)
+                if rel_path.startswith("a/") or rel_path.startswith("b/"):
+                    rel_path = rel_path[2:]
+
+                target_file = (base / rel_path).resolve()
+
+                if not is_new and not target_file.exists():
+                    print(f"ERROR: Target file {target_file} does not exist for patching.")
+                    return False
+
+                # 1. Create backup
+                if target_file.exists():
+                    backup_file = backup_dir / f"{target_file.name}.bak"
+                    shutil.copy2(target_file, backup_file)
+                    backups.append((target_file, backup_file))
+                else:
+                    backups.append((target_file, None))  # Mark for deletion on rollback if it's a new file
+
+                # 2. Apply patch
+                content = ""
+                if target_file.exists():
+                    with open(target_file, "r", encoding="utf-8") as f:
+                        content = f.read()
+
+                lines = content.splitlines()
+                # whatthepatch apply_diff returns a generator of lines
+                patched_lines = whatthepatch.apply_diff(patch, lines)
+
+                if patched_lines is None:
+                    print(f"ERROR: Failed to apply patch to {rel_path}.")
+                    raise Exception(f"Patch failure on {rel_path}")
+
+                # 3. Write new content
+                target_file.parent.mkdir(parents=True, exist_ok=True)
+                with open(target_file, "w", encoding="utf-8") as f:
+                    f.write("\n".join(patched_lines) + "\n")
+
+            print(f"SUCCESS: Applied {len(patches)} patches successfully.")
+
+            # 4. Pre-flight Validation
+            # Run tests to ensure the patch didn't break anything
+            if (base / "tests").exists():
+                print("🧪 Running pre-flight validation (pytest)...")
+                import subprocess
+                # Run pytest in the base_path
+                result = subprocess.run(["pytest"], cwd=str(base), capture_output=True, text=True)
+                if result.returncode != 0:
+                    print(f"❌ Validation FAILED:\n{result.stdout}")
+                    raise Exception("Pre-flight validation failed. Tests are broken.")
+                else:
+                    print("✅ Pre-flight validation passed!")
+
+            return True
+
+        except Exception as e:
+            print(f"CRITICAL ERROR: {e}. Starting rollback...")
+            for target, backup in backups:
+                if backup and backup.exists():
+                    shutil.move(str(backup), str(target))
+                elif not backup and target.exists():
+                    target.unlink()  # Delete newly created file
+            return False
 
     @staticmethod
     def apply_whole_file(file_path: str, content: str):
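The rewritten applier is all-or-nothing: a parse failure, a patch that does not apply, or a failing pytest run all route through the rollback branch. A sketch of calling it, with a toy patch against a hypothetical hello.txt (whatthepatch must be installed, and the file must already exist):

from akita.tools.diff import DiffApplier

# Toy one-line patch; the file name and content are illustrative only.
patch_text = """--- a/hello.txt
+++ b/hello.txt
@@ -1 +1 @@
-hello
+hello, world
"""

ok = DiffApplier.apply_unified_diff(patch_text, base_path=".")
print("applied" if ok else "rolled back")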
|