cognify-code 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_code_assistant/__init__.py +14 -0
- ai_code_assistant/agent/__init__.py +63 -0
- ai_code_assistant/agent/code_agent.py +461 -0
- ai_code_assistant/agent/code_generator.py +388 -0
- ai_code_assistant/agent/code_reviewer.py +365 -0
- ai_code_assistant/agent/diff_engine.py +308 -0
- ai_code_assistant/agent/file_manager.py +300 -0
- ai_code_assistant/agent/intent_classifier.py +284 -0
- ai_code_assistant/chat/__init__.py +11 -0
- ai_code_assistant/chat/agent_session.py +156 -0
- ai_code_assistant/chat/session.py +165 -0
- ai_code_assistant/cli.py +1571 -0
- ai_code_assistant/config.py +149 -0
- ai_code_assistant/editor/__init__.py +8 -0
- ai_code_assistant/editor/diff_handler.py +270 -0
- ai_code_assistant/editor/file_editor.py +350 -0
- ai_code_assistant/editor/prompts.py +146 -0
- ai_code_assistant/generator/__init__.py +7 -0
- ai_code_assistant/generator/code_gen.py +265 -0
- ai_code_assistant/generator/prompts.py +114 -0
- ai_code_assistant/git/__init__.py +6 -0
- ai_code_assistant/git/commit_generator.py +130 -0
- ai_code_assistant/git/manager.py +203 -0
- ai_code_assistant/llm.py +111 -0
- ai_code_assistant/providers/__init__.py +23 -0
- ai_code_assistant/providers/base.py +124 -0
- ai_code_assistant/providers/cerebras.py +97 -0
- ai_code_assistant/providers/factory.py +148 -0
- ai_code_assistant/providers/google.py +103 -0
- ai_code_assistant/providers/groq.py +111 -0
- ai_code_assistant/providers/ollama.py +86 -0
- ai_code_assistant/providers/openai.py +114 -0
- ai_code_assistant/providers/openrouter.py +130 -0
- ai_code_assistant/py.typed +0 -0
- ai_code_assistant/refactor/__init__.py +20 -0
- ai_code_assistant/refactor/analyzer.py +189 -0
- ai_code_assistant/refactor/change_plan.py +172 -0
- ai_code_assistant/refactor/multi_file_editor.py +346 -0
- ai_code_assistant/refactor/prompts.py +175 -0
- ai_code_assistant/retrieval/__init__.py +19 -0
- ai_code_assistant/retrieval/chunker.py +215 -0
- ai_code_assistant/retrieval/indexer.py +236 -0
- ai_code_assistant/retrieval/search.py +239 -0
- ai_code_assistant/reviewer/__init__.py +7 -0
- ai_code_assistant/reviewer/analyzer.py +278 -0
- ai_code_assistant/reviewer/prompts.py +113 -0
- ai_code_assistant/utils/__init__.py +18 -0
- ai_code_assistant/utils/file_handler.py +155 -0
- ai_code_assistant/utils/formatters.py +259 -0
- cognify_code-0.2.0.dist-info/METADATA +383 -0
- cognify_code-0.2.0.dist-info/RECORD +55 -0
- cognify_code-0.2.0.dist-info/WHEEL +5 -0
- cognify_code-0.2.0.dist-info/entry_points.txt +3 -0
- cognify_code-0.2.0.dist-info/licenses/LICENSE +22 -0
- cognify_code-0.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,300 @@
|
|
|
1
|
+
"""File Context Manager for reading/writing project files."""
|
|
2
|
+
|
|
3
|
+
import fnmatch
|
|
4
|
+
import os
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Dict, List, Optional, Set
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# Glob patterns (fnmatch syntax) that are skipped by default. Most directories
# appear twice: once as the bare name and once as "name/*".
DEFAULT_IGNORE_PATTERNS = [
    # Version control and Python bytecode
    ".git",
    ".git/*",
    "__pycache__",
    "__pycache__/*",
    "*.pyc",
    # Virtual environments
    ".venv",
    ".venv/*",
    "venv",
    "venv/*",
    "env",
    "env/*",
    # JavaScript tooling
    "node_modules",
    "node_modules/*",
    ".next",
    ".next/*",
    # Build output and caches
    "dist",
    "dist/*",
    "build",
    "build/*",
    ".cache",
    ".cache/*",
    # Python packaging metadata
    "*.egg-info",
    "*.egg-info/*",
    ".eggs",
    ".eggs/*",
    # Test and type-checker caches
    ".pytest_cache",
    ".pytest_cache/*",
    ".mypy_cache",
    ".mypy_cache/*",
    # Tox and coverage artefacts
    ".tox",
    ".tox/*",
    ".coverage",
    "htmlcov",
    "htmlcov/*",
    # Logs, temporary files and OS metadata
    "*.log",
    "*.tmp",
    "*.temp",
    ".DS_Store",
    "Thumbs.db",
    # Minified assets and source maps
    "*.min.js",
    "*.min.css",
    "*.map",
]
|
|
22
|
+
|
|
23
|
+
# Lower-case file suffixes (leading dot included) that count as "code" files.
CODE_EXTENSIONS = {
    # General-purpose programming languages
    ".py", ".js", ".ts", ".jsx", ".tsx",
    ".java", ".go", ".rs",
    ".c", ".cpp", ".h", ".hpp", ".cs",
    ".rb", ".php", ".swift", ".kt", ".scala", ".r",
    # Query and shell languages
    ".sql", ".sh", ".bash", ".zsh",
    # Configuration and data formats
    ".yaml", ".yml", ".json", ".toml", ".xml",
    # Markup and stylesheets
    ".html", ".css", ".scss", ".sass", ".less",
    # Documentation and plain text
    ".md", ".rst", ".txt",
}
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# Extension -> language name. Built once at import time rather than on every
# ``FileInfo.language`` access; extensions missing from the table are "text".
_EXTENSION_LANGUAGES = {
    ".py": "python", ".js": "javascript", ".ts": "typescript",
    ".jsx": "javascript", ".tsx": "typescript", ".java": "java",
    ".go": "go", ".rs": "rust", ".c": "c", ".cpp": "cpp",
    ".h": "c", ".hpp": "cpp", ".cs": "csharp", ".rb": "ruby",
    ".php": "php", ".swift": "swift", ".kt": "kotlin",
    ".scala": "scala", ".r": "r", ".sql": "sql",
    ".sh": "bash", ".bash": "bash", ".zsh": "zsh",
    ".yaml": "yaml", ".yml": "yaml", ".json": "json",
    ".toml": "toml", ".xml": "xml", ".html": "html",
    ".css": "css", ".scss": "scss", ".md": "markdown",
}


@dataclass
class FileInfo:
    """Information about a file."""

    # Location of the file on disk.
    path: Path
    # The path as a string; FileContextManager stores it relative to its root.
    relative_path: str
    # File suffix including the leading dot (e.g. ".py").
    extension: str
    # Size in bytes.
    size: int
    # Whether the suffix is one of the recognised code extensions.
    is_code: bool

    @property
    def language(self) -> str:
        """Return the language name for ``extension``.

        The lookup is case-insensitive, so ".PY" and ".py" both give
        "python". Extensions that are not in the table yield "text".
        """
        return _EXTENSION_LANGUAGES.get(self.extension.lower(), "text")
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@dataclass
class ProjectContext:
    """Snapshot of a project: its root, its files and summary statistics."""

    # Directory the project lives in.
    root_path: Path
    # One entry per file recorded for the project.
    files: List[FileInfo] = field(default_factory=list)
    # Text rendering of the directory tree.
    structure: str = ""
    # Language names seen among the project's code files.
    languages: Set[str] = field(default_factory=set)
    # Number of files recorded / number of those that are code files.
    total_files: int = 0
    total_code_files: int = 0

    def get_files_by_language(self, language: str) -> List[FileInfo]:
        """Return the files whose ``language`` equals *language*."""
        selected = []
        for info in self.files:
            if info.language == language:
                selected.append(info)
        return selected

    def get_files_by_pattern(self, pattern: str) -> List[FileInfo]:
        """Return the files whose relative path matches the glob *pattern*."""
        selected = []
        for info in self.files:
            if fnmatch.fnmatch(info.relative_path, pattern):
                selected.append(info)
        return selected
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class FileContextManager:
    """Manages file operations and project context.

    Relative paths given to the public methods are interpreted against
    ``root_path``. The project scan is cached and the cache is dropped
    whenever a file is written or deleted through this object.
    """

    def __init__(self, root_path: Optional[Path] = None,
                 ignore_patterns: Optional[List[str]] = None):
        """Create a manager rooted at *root_path*.

        Args:
            root_path: Project root; defaults to the current working directory.
            ignore_patterns: fnmatch-style patterns to skip. ``None`` selects
                ``DEFAULT_IGNORE_PATTERNS``; an empty list ignores nothing.
        """
        self.root_path = Path(root_path or os.getcwd()).resolve()
        # Test against None so an explicit [] is honoured, and copy the
        # defaults so per-instance edits never leak into the module constant.
        if ignore_patterns is None:
            self.ignore_patterns = list(DEFAULT_IGNORE_PATTERNS)
        else:
            self.ignore_patterns = ignore_patterns
        self._context: Optional["ProjectContext"] = None

    def _relative_str(self, path: Path) -> str:
        """Return *path* relative to the root, or unchanged if it lies outside."""
        try:
            return str(path.relative_to(self.root_path))
        except ValueError:
            # relative_to() raises when path is not below root_path.
            return str(path)

    def _should_ignore(self, path: Path) -> bool:
        """Return True if the file name or root-relative path matches a pattern."""
        rel_path = self._relative_str(path)
        name = path.name
        return any(
            fnmatch.fnmatch(name, pattern) or fnmatch.fnmatch(rel_path, pattern)
            for pattern in self.ignore_patterns
        )

    def _is_code_file(self, path: Path) -> bool:
        """Return True if the lower-cased suffix is in ``CODE_EXTENSIONS``."""
        return path.suffix.lower() in CODE_EXTENSIONS

    def get_project_context(self, refresh: bool = False) -> "ProjectContext":
        """Return the project context, scanning the tree when needed.

        Args:
            refresh: Rescan even if a cached context exists.

        Returns:
            A ProjectContext with the non-ignored files, a rendered tree,
            the languages of the code files and the file counters.
        """
        if self._context is not None and not refresh:
            return self._context

        context = ProjectContext(root_path=self.root_path)
        structure_lines = []

        for root, dirs, files in os.walk(self.root_path):
            root_path = Path(root)

            # Prune ignored directories in place so os.walk never enters
            # them; sort so the traversal order is deterministic.
            dirs[:] = sorted(
                d for d in dirs if not self._should_ignore(root_path / d)
            )

            depth = len(root_path.relative_to(self.root_path).parts)
            # One level is two spaces: the directory line below strips
            # exactly two characters to sit one level above its files.
            indent = "  " * depth

            if depth > 0:
                structure_lines.append(f"{indent[:-2]}📁 {root_path.name}/")

            for file in sorted(files):
                file_path = root_path / file

                if self._should_ignore(file_path):
                    continue

                is_code = self._is_code_file(file_path)

                try:
                    size = file_path.stat().st_size
                except OSError:
                    # Unreadable or vanished file: keep it listed, size unknown.
                    size = 0

                file_info = FileInfo(
                    path=file_path,
                    relative_path=str(file_path.relative_to(self.root_path)),
                    extension=file_path.suffix.lower(),
                    size=size,
                    is_code=is_code,
                )

                context.files.append(file_info)
                context.total_files += 1

                if is_code:
                    context.total_code_files += 1
                    context.languages.add(file_info.language)

                icon = "📄" if is_code else "📋"
                structure_lines.append(f"{indent}{icon} {file}")

        context.structure = "\n".join(structure_lines)
        self._context = context
        return context

    def read_file(self, path: str) -> Optional[str]:
        """Return the UTF-8 text of *path*, or None if missing or unreadable."""
        file_path = self._resolve_path(path)

        if not file_path.exists():
            return None

        try:
            return file_path.read_text(encoding="utf-8")
        except Exception:
            # Deliberate best-effort: any read/decode failure is reported
            # to the caller as None.
            return None

    def write_file(self, path: str, content: str, create_dirs: bool = True) -> bool:
        """Write *content* to *path* as UTF-8.

        Args:
            path: Target file, relative to the root or absolute.
            content: Text to write; replaces any existing content.
            create_dirs: Create missing parent directories first.

        Returns:
            True on success, False if anything went wrong.
        """
        file_path = self._resolve_path(path)

        try:
            if create_dirs:
                file_path.parent.mkdir(parents=True, exist_ok=True)

            file_path.write_text(content, encoding="utf-8")
        except Exception:
            # Deliberate best-effort: failure is reported through the result.
            return False

        # The cached scan no longer reflects the disk.
        self._context = None
        return True

    def file_exists(self, path: str) -> bool:
        """Return True if *path* exists."""
        return self._resolve_path(path).exists()

    def delete_file(self, path: str) -> bool:
        """Delete *path*; return True only if a file was actually removed."""
        file_path = self._resolve_path(path)

        try:
            if not file_path.exists():
                return False
            file_path.unlink()
        except Exception:
            # Deliberate best-effort: failure is reported through the result.
            return False

        # The cached scan no longer reflects the disk.
        self._context = None
        return True

    def find_files(self, pattern: str) -> List[str]:
        """Return relative paths of project files matching the glob *pattern*.

        A file matches when either its root-relative path or its bare name
        matches the pattern.
        """
        context = self.get_project_context()
        return [
            file_info.relative_path
            for file_info in context.files
            if fnmatch.fnmatch(file_info.relative_path, pattern)
            or fnmatch.fnmatch(file_info.path.name, pattern)
        ]

    def get_file_info(self, path: str) -> Optional["FileInfo"]:
        """Return a FileInfo for *path*, or None if it cannot be stat'ed."""
        file_path = self._resolve_path(path)

        # A single stat() answers both "does it exist" and "how big is it",
        # so the file cannot disappear between two separate checks.
        try:
            size = file_path.stat().st_size
        except (OSError, ValueError):
            return None

        return FileInfo(
            path=file_path,
            relative_path=self._relative_str(file_path),
            extension=file_path.suffix.lower(),
            size=size,
            is_code=self._is_code_file(file_path),
        )

    def get_related_files(self, path: str, max_files: int = 5) -> List[str]:
        """Return up to *max_files* non-ignored code files beside *path*.

        Only the file's own directory is inspected; siblings are returned
        in sorted order so the ``max_files`` cut is deterministic.
        """
        file_path = self._resolve_path(path)
        directory = file_path.parent

        try:
            # is_dir() rather than exists(): the parent may be a plain file.
            siblings = sorted(directory.iterdir()) if directory.is_dir() else []
        except OSError:
            siblings = []

        related = [
            self._relative_str(sibling)
            for sibling in siblings
            if sibling.is_file()
            and sibling != file_path
            and self._is_code_file(sibling)
            and not self._should_ignore(sibling)
        ]

        # TODO: Parse imports and find imported files

        return related[:max_files]

    def _resolve_path(self, path: str) -> Path:
        """Return *path* unchanged if absolute, otherwise joined to the root."""
        # NOTE(review): absolute paths and ".." segments are not confined to
        # root_path, so read/write/delete can reach outside the project.
        # Confirm whether callers pass untrusted (e.g. model-generated) paths.
        p = Path(path)
        if p.is_absolute():
            return p
        return self.root_path / path

    def get_structure_summary(self, max_depth: int = 3) -> str:
        """Return a compact directory tree, read directly from disk.

        At most *max_depth* levels below the root are shown, and at most ten
        directories and ten files per directory; the remainder is summarised
        as "... and N more". Ignored entries are left out.
        """
        lines = [f"📁 {self.root_path.name}/"]

        def walk_dir(dir_path: Path, depth: int = 1) -> None:
            if depth > max_depth:
                return

            try:
                # Directories sort before files (False < True), then by name.
                items = sorted(dir_path.iterdir(), key=lambda x: (x.is_file(), x.name))
            except OSError:
                # Unreadable or vanished directory: skip its contents.
                return

            dirs = []
            files = []

            for item in items:
                if self._should_ignore(item):
                    continue

                if item.is_dir():
                    dirs.append(item)
                elif item.is_file():
                    files.append(item)

            indent = " " * depth

            for d in dirs[:10]:  # Limit directories shown
                lines.append(f"{indent}📁 {d.name}/")
                walk_dir(d, depth + 1)

            if len(dirs) > 10:
                lines.append(f"{indent}... and {len(dirs) - 10} more directories")

            for f in files[:10]:  # Limit files shown
                icon = "📄" if self._is_code_file(f) else "📋"
                lines.append(f"{indent}{icon} {f.name}")

            if len(files) > 10:
                lines.append(f"{indent}... and {len(files) - 10} more files")

        walk_dir(self.root_path)
        return "\n".join(lines)
|
|
@@ -0,0 +1,284 @@
|
|
|
1
|
+
"""Intent Classifier for understanding user requests."""
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from enum import Enum
|
|
6
|
+
from typing import List, Optional, Tuple
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class IntentType(Enum):
    """Types of user intents."""

    # Create new code
    CODE_GENERATE = "code_generate"
    # Modify existing code
    CODE_EDIT = "code_edit"
    # Review code for issues
    CODE_REVIEW = "code_review"
    # Explain how code works
    CODE_EXPLAIN = "code_explain"
    # Refactor/improve code
    CODE_REFACTOR = "code_refactor"
    # Generate tests
    TEST_GENERATE = "test_generate"
    # Create a new file
    FILE_CREATE = "file_create"
    # Delete a file
    FILE_DELETE = "file_delete"
    # Get project information
    PROJECT_INFO = "project_info"
    # General conversation
    GENERAL_CHAT = "general_chat"
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass
class Intent:
    """Result of classifying one user message."""

    # The detected category and how confident the classifier is in it.
    type: IntentType
    confidence: float
    # File paths mentioned in the message.
    file_paths: List[str] = field(default_factory=list)
    # What the action applies to, and the action text itself.
    target_description: str = ""
    action_description: str = ""
    # Programming language detected in the message, if any.
    language: Optional[str] = None

    @property
    def requires_file_context(self) -> bool:
        """True for intents that operate on existing code."""
        needs_context = {
            IntentType.CODE_EDIT,
            IntentType.CODE_REVIEW,
            IntentType.CODE_EXPLAIN,
            IntentType.CODE_REFACTOR,
            IntentType.TEST_GENERATE,
        }
        return self.type in needs_context

    @property
    def modifies_files(self) -> bool:
        """True for intents that will create, change or delete files."""
        mutating = {
            IntentType.CODE_GENERATE,
            IntentType.CODE_EDIT,
            IntentType.CODE_REFACTOR,
            IntentType.TEST_GENERATE,
            IntentType.FILE_CREATE,
            IntentType.FILE_DELETE,
        }
        return self.type in mutating
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# Regular expressions used to detect each intent. IntentClassifier compiles
# them case-insensitively and scores an intent by the share of its patterns
# that match the message.
INTENT_PATTERNS = {
    # "create/write/... <kind of code>"
    IntentType.CODE_GENERATE: [
        r"\b(create|write|generate|make|build|implement|add)\b.*\b(function|class|method|api|endpoint|component|module|service)\b",
        r"\b(create|write|generate|make|build)\b.*\b(code|script|program)\b",
        r"\bnew\b.*\b(function|class|file|component)\b",
        r"\bimplement\b",
    ],
    # "edit/fix/change ... <existing code>"
    IntentType.CODE_EDIT: [
        r"\b(edit|modify|change|update|fix|patch)\b.*\b(code|file|function|class|line)\b",
        r"\b(add|insert|append)\b.*\b(to|in|into)\b",
        r"\b(remove|delete)\b.*\b(from|in)\b.*\b(code|file|function)\b",
        r"\bfix\b.*\b(bug|error|issue|problem)\b",
        r"\bchange\b.*\bto\b",
    ],
    # "review/check/audit ..."
    IntentType.CODE_REVIEW: [
        r"\b(review|check|analyze|audit|inspect)\b",
        r"\b(find|look for|check for)\b.*\b(issues|bugs|problems|errors|vulnerabilities)\b",
        r"\b(security|performance)\b.*\b(review|check|audit)\b",
        r"\bcode\s*review\b",
        r"\breview\b.*\.(py|js|ts|java|go|rs|rb|cpp|c)\b",
    ],
    # "explain/describe/what does ..."
    IntentType.CODE_EXPLAIN: [
        r"\b(explain|describe)\b",
        r"\b(what does|how does|what is|how is)\b.*\b(code|function|class|file|this|it|work)\b",
        r"\bwalk\s*(me\s*)?through\b",
        r"\bunderstand\b.*\b(code|function|class)\b",
        r"\bwhat\b.*\b(doing|happening|mean)\b",
        r"\bexplain\b.*\.(py|js|ts|java|go|rs|rb|cpp|c)\b",
    ],
    # "refactor/improve/make ... better"
    IntentType.CODE_REFACTOR: [
        r"\b(refactor|improve|optimize|clean|simplify|restructure)\b",
        r"\bmake\b.*\b(better|cleaner|faster|more efficient|readable)\b",
        r"\b(reduce|remove)\b.*\b(duplication|complexity)\b",
    ],
    # "write tests ..." or a test framework name
    IntentType.TEST_GENERATE: [
        r"\b(create|write|generate|add)\b.*\b(test|tests|unit test|spec)\b",
        r"\btest\b.*\b(for|coverage)\b",
        r"\b(pytest|unittest|jest|mocha)\b",
    ],
    # "create/make/new ... file/directory/folder"
    IntentType.FILE_CREATE: [
        r"\b(create|make|new)\b.*\b(file|directory|folder)\b",
        r"\btouch\b",
    ],
    # "delete/remove/rm ... file/directory/folder"
    IntentType.FILE_DELETE: [
        r"\b(delete|remove|rm)\b.*\b(file|directory|folder)\b",
    ],
    # "show/list ... files/structure/project"
    IntentType.PROJECT_INFO: [
        r"\b(show|list|what)\b.*\b(files|structure|project)\b",
        r"\bproject\b.*\b(structure|info|overview)\b",
        r"\bwhat\s*(files|languages)\b",
    ],
}
|
|
110
|
+
|
|
111
|
+
# Regular expressions that pull file paths out of a message. Each pattern has
# exactly one capturing group, which holds the path.
FILE_PATH_PATTERNS = [
    # Quoted paths with extension
    r'["\']([^"\']+\.[a-zA-Z]{1,10})["\']',
    # Unquoted paths ending in a known extension
    r'\b(\S+\.(?:py|js|ts|jsx|tsx|java|go|rs|c|cpp|h|rb|php|swift|kt|scala|sql|yaml|yml|json|toml|xml|html|css|md))\b',
    # src/ paths
    r'\b(src/\S+)\b',
    # test/ and tests/ paths
    r'\b(tests?/\S+)\b',
    # lib/ paths
    r'\b(lib/\S+)\b',
    # app/ paths
    r'\b(app/\S+)\b',
]
|
|
120
|
+
|
|
121
|
+
# Per-language regular expressions: the language's name or short name, its
# file extension and related framework/tool keywords. Languages are tried
# in the order listed here.
LANGUAGE_PATTERNS = {
    "python": [
        r"\bpython\b",
        r"\bpy\b",
        r"\.py\b",
        r"\bdjango\b",
        r"\bflask\b",
        r"\bfastapi\b",
    ],
    "javascript": [
        r"\bjavascript\b",
        r"\bjs\b",
        r"\.js\b",
        r"\bnode\b",
        r"\breact\b",
        r"\bvue\b",
    ],
    "typescript": [
        r"\btypescript\b",
        r"\bts\b",
        r"\.ts\b",
        r"\.tsx\b",
    ],
    "java": [
        r"\bjava\b",
        r"\.java\b",
        r"\bspring\b",
    ],
    "go": [
        r"\bgo\b",
        r"\bgolang\b",
        r"\.go\b",
    ],
    "rust": [
        r"\brust\b",
        r"\.rs\b",
        r"\bcargo\b",
    ],
    "ruby": [
        r"\bruby\b",
        r"\.rb\b",
        r"\brails\b",
    ],
    "php": [
        r"\bphp\b",
        r"\.php\b",
        r"\blaravel\b",
    ],
}
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class IntentClassifier:
    """Classifies user messages into intents.

    ``classify`` is purely rule-based (regular expressions);
    ``classify_with_llm`` additionally asks the configured LLM and falls
    back to the rule-based answer when the LLM is missing, fails or returns
    something that is not a known category.
    """

    # Action verbs removed by _extract_target. Matched as whole words only,
    # so "prefix", "latest" or "tests" are left intact.
    _ACTION_WORDS_RE = re.compile(
        r"\b(?:create|write|generate|make|build|implement|"
        r"edit|modify|change|update|fix|review|"
        r"explain|refactor|improve|optimize|test)\b"
    )

    # Wrapping characters stripped from an LLM answer before it is looked up.
    _LABEL_WRAPPERS = "`'\"*.:,;!()[] \t\r\n"

    def __init__(self, llm_manager=None):
        """Store the optional LLM and compile all patterns.

        Args:
            llm_manager: Object whose ``invoke(prompt)`` returns the answer
                text; ``None`` disables LLM-based refinement.
        """
        self.llm = llm_manager
        self._compile_patterns()

    def _compile_patterns(self):
        """Compile the module-level pattern tables once per instance."""
        self._intent_patterns = {
            intent: [re.compile(p, re.IGNORECASE) for p in patterns]
            for intent, patterns in INTENT_PATTERNS.items()
        }
        # File-path patterns are compiled case-sensitively.
        self._file_patterns = [re.compile(p) for p in FILE_PATH_PATTERNS]
        self._language_patterns = {
            lang: [re.compile(p, re.IGNORECASE) for p in patterns]
            for lang, patterns in LANGUAGE_PATTERNS.items()
        }

    def classify(self, message: str) -> "Intent":
        """Classify *message* with the rule-based patterns.

        Each intent scores the fraction of its patterns that match. The
        highest score wins, and on a tie the intent listed first wins. With
        no match at all the result is GENERAL_CHAT at confidence 0.5.

        Returns:
            An Intent carrying the extracted file paths, target text and
            detected language.
        """
        file_paths = self._extract_file_paths(message)
        language = self._detect_language(message)

        scores: List[Tuple["IntentType", float]] = []
        for intent_type, patterns in self._intent_patterns.items():
            hits = sum(1 for pattern in patterns if pattern.search(message))
            if hits:
                # Normalize by number of patterns
                scores.append((intent_type, hits / len(patterns)))

        if scores:
            # max() keeps the first of equal scores, i.e. table order.
            best_intent, confidence = max(scores, key=lambda item: item[1])
        else:
            best_intent = IntentType.GENERAL_CHAT
            confidence = 0.5

        # A mentioned file makes a code_* intent more plausible.
        if file_paths and best_intent.value.startswith("code_"):
            confidence = min(confidence + 0.2, 1.0)

        return Intent(
            type=best_intent,
            confidence=confidence,
            file_paths=file_paths,
            target_description=self._extract_target(message, best_intent),
            action_description=message,
            language=language,
        )

    def classify_with_llm(self, message: str) -> "Intent":
        """Classify *message*, letting the LLM override the rule-based type.

        The rule-based result is always computed first. If the LLM answers
        with a known category name (surrounding quotes, backticks and
        punctuation are tolerated), the type is replaced and the confidence
        set to 0.9; otherwise the rule-based result is returned unchanged.
        """
        if not self.llm:
            return self.classify(message)

        # First do rule-based classification
        rule_intent = self.classify(message)

        # Use LLM to refine
        prompt = f"""Classify this user request into one of these categories:
- code_generate: Create new code (function, class, API, etc.)
- code_edit: Modify existing code
- code_review: Review code for issues/bugs
- code_explain: Explain how code works
- code_refactor: Improve/optimize code
- test_generate: Create tests
- file_create: Create new file
- file_delete: Delete file
- project_info: Get project information
- general_chat: General conversation

User request: "{message}"

Respond with ONLY the category name, nothing else."""

        try:
            response = self.llm.invoke(prompt).strip().lower()
        except Exception:
            # Deliberate best-effort boundary: any LLM failure falls back
            # to the rule-based classification.
            return rule_intent

        # Models often decorate the answer ("`code_edit`", "code_edit.").
        label = response.strip(self._LABEL_WRAPPERS)
        try:
            # The category names in the prompt are exactly the enum values.
            refined_type = IntentType(label)
        except ValueError:
            return rule_intent

        rule_intent.type = refined_type
        rule_intent.confidence = 0.9
        return rule_intent

    def _extract_file_paths(self, message: str) -> List[str]:
        """Return the file paths found in *message*, first occurrence first."""
        found = []
        for pattern in self._file_patterns:
            found.extend(pattern.findall(message))

        # dict preserves insertion order, which gives an ordered de-duplication.
        return list(dict.fromkeys(found))

    def _detect_language(self, message: str) -> Optional[str]:
        """Return the first language whose patterns match *message*, else None."""
        # NOTE(review): some keywords are ordinary English words ("go",
        # "node", "spring"), so prose can trigger a language; languages are
        # tried in table order.
        for lang, patterns in self._language_patterns.items():
            if any(pattern.search(message) for pattern in patterns):
                return lang
        return None

    def _extract_target(self, message: str, intent: "IntentType") -> str:
        """Return *message* lower-cased with the action verbs removed.

        Verbs are removed only as whole words and the remaining whitespace
        is collapsed to single spaces. *intent* is currently unused.
        """
        # Simple extraction - could be improved with NLP
        without_verbs = self._ACTION_WORDS_RE.sub(" ", message.lower())
        return " ".join(without_verbs.split())
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
"""Chat module for AI Code Assistant."""
|
|
2
|
+
|
|
3
|
+
from ai_code_assistant.chat.session import ChatSession, Message
|
|
4
|
+
from ai_code_assistant.chat.agent_session import AgentChatSession, AgentMessage
|
|
5
|
+
|
|
6
|
+
# Names exported by ``from ai_code_assistant.chat import *``.
__all__ = [
    # Imported from ai_code_assistant.chat.session
    "ChatSession",
    "Message",
    # Imported from ai_code_assistant.chat.agent_session
    "AgentChatSession",
    "AgentMessage",
]
|