cognify_code-0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in the public registry.
- ai_code_assistant/__init__.py +14 -0
- ai_code_assistant/agent/__init__.py +63 -0
- ai_code_assistant/agent/code_agent.py +461 -0
- ai_code_assistant/agent/code_generator.py +388 -0
- ai_code_assistant/agent/code_reviewer.py +365 -0
- ai_code_assistant/agent/diff_engine.py +308 -0
- ai_code_assistant/agent/file_manager.py +300 -0
- ai_code_assistant/agent/intent_classifier.py +284 -0
- ai_code_assistant/chat/__init__.py +11 -0
- ai_code_assistant/chat/agent_session.py +156 -0
- ai_code_assistant/chat/session.py +165 -0
- ai_code_assistant/cli.py +1571 -0
- ai_code_assistant/config.py +149 -0
- ai_code_assistant/editor/__init__.py +8 -0
- ai_code_assistant/editor/diff_handler.py +270 -0
- ai_code_assistant/editor/file_editor.py +350 -0
- ai_code_assistant/editor/prompts.py +146 -0
- ai_code_assistant/generator/__init__.py +7 -0
- ai_code_assistant/generator/code_gen.py +265 -0
- ai_code_assistant/generator/prompts.py +114 -0
- ai_code_assistant/git/__init__.py +6 -0
- ai_code_assistant/git/commit_generator.py +130 -0
- ai_code_assistant/git/manager.py +203 -0
- ai_code_assistant/llm.py +111 -0
- ai_code_assistant/providers/__init__.py +23 -0
- ai_code_assistant/providers/base.py +124 -0
- ai_code_assistant/providers/cerebras.py +97 -0
- ai_code_assistant/providers/factory.py +148 -0
- ai_code_assistant/providers/google.py +103 -0
- ai_code_assistant/providers/groq.py +111 -0
- ai_code_assistant/providers/ollama.py +86 -0
- ai_code_assistant/providers/openai.py +114 -0
- ai_code_assistant/providers/openrouter.py +130 -0
- ai_code_assistant/py.typed +0 -0
- ai_code_assistant/refactor/__init__.py +20 -0
- ai_code_assistant/refactor/analyzer.py +189 -0
- ai_code_assistant/refactor/change_plan.py +172 -0
- ai_code_assistant/refactor/multi_file_editor.py +346 -0
- ai_code_assistant/refactor/prompts.py +175 -0
- ai_code_assistant/retrieval/__init__.py +19 -0
- ai_code_assistant/retrieval/chunker.py +215 -0
- ai_code_assistant/retrieval/indexer.py +236 -0
- ai_code_assistant/retrieval/search.py +239 -0
- ai_code_assistant/reviewer/__init__.py +7 -0
- ai_code_assistant/reviewer/analyzer.py +278 -0
- ai_code_assistant/reviewer/prompts.py +113 -0
- ai_code_assistant/utils/__init__.py +18 -0
- ai_code_assistant/utils/file_handler.py +155 -0
- ai_code_assistant/utils/formatters.py +259 -0
- cognify_code-0.2.0.dist-info/METADATA +383 -0
- cognify_code-0.2.0.dist-info/RECORD +55 -0
- cognify_code-0.2.0.dist-info/WHEEL +5 -0
- cognify_code-0.2.0.dist-info/entry_points.txt +3 -0
- cognify_code-0.2.0.dist-info/licenses/LICENSE +22 -0
- cognify_code-0.2.0.dist-info/top_level.txt +1 -0
ai_code_assistant/git/manager.py
ADDED
@@ -0,0 +1,203 @@
"""Git repository manager."""

import subprocess
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Tuple


@dataclass
class GitStatus:
    """Represents the current git status."""
    staged: List[str] = field(default_factory=list)
    modified: List[str] = field(default_factory=list)
    untracked: List[str] = field(default_factory=list)
    deleted: List[str] = field(default_factory=list)
    branch: str = ""
    remote: str = ""
    ahead: int = 0
    behind: int = 0

    @property
    def has_changes(self) -> bool:
        return bool(self.staged or self.modified or self.untracked or self.deleted)

    @property
    def has_staged(self) -> bool:
        return bool(self.staged)

    @property
    def total_changes(self) -> int:
        return len(self.staged) + len(self.modified) + len(self.untracked) + len(self.deleted)


@dataclass
class GitDiff:
    """Represents a git diff."""
    files_changed: int = 0
    insertions: int = 0
    deletions: int = 0
    diff_text: str = ""
    file_diffs: List[dict] = field(default_factory=list)


class GitManager:
    """Manages git operations."""

    def __init__(self, repo_path: Optional[Path] = None):
        self.repo_path = repo_path or Path.cwd()
        self._validate_repo()

    def _validate_repo(self) -> None:
        """Validate that we're in a git repository."""
        if not (self.repo_path / ".git").exists():
            raise ValueError(f"Not a git repository: {self.repo_path}")

    def _run_git(self, *args: str, check: bool = True) -> subprocess.CompletedProcess:
        """Run a git command."""
        result = subprocess.run(
            ["git", *args],
            cwd=self.repo_path,
            capture_output=True,
            text=True,
        )
        if check and result.returncode != 0:
            raise RuntimeError(f"Git command failed: {result.stderr}")
        return result

    def get_status(self) -> GitStatus:
        """Get current git status."""
        status = GitStatus()

        # Get branch name
        result = self._run_git("branch", "--show-current", check=False)
        status.branch = result.stdout.strip()

        # Get remote
        result = self._run_git("remote", check=False)
        status.remote = result.stdout.strip().split("\n")[0] if result.stdout else "origin"

        # Get ahead/behind
        result = self._run_git("rev-list", "--left-right", "--count", f"{status.remote}/{status.branch}...HEAD", check=False)
        if result.returncode == 0 and result.stdout.strip():
            parts = result.stdout.strip().split()
            if len(parts) == 2:
                status.behind, status.ahead = int(parts[0]), int(parts[1])

        # Get file status
        result = self._run_git("status", "--porcelain", check=False)
        for line in result.stdout.strip().split("\n"):
            if not line:
                continue
            status_code = line[:2]
            file_path = line[3:]

            if status_code[0] in "MADRC":
                status.staged.append(file_path)
            if status_code[1] == "M":
                status.modified.append(file_path)
            elif status_code[1] == "D":
                status.deleted.append(file_path)
            elif status_code == "??":
                status.untracked.append(file_path)

        return status

    def get_diff(self, staged: bool = True) -> GitDiff:
        """Get diff of changes."""
        diff = GitDiff()

        # Get diff text
        if staged:
            result = self._run_git("diff", "--cached", check=False)
        else:
            result = self._run_git("diff", check=False)

        diff.diff_text = result.stdout

        # Get stats
        if staged:
            result = self._run_git("diff", "--cached", "--stat", check=False)
        else:
            result = self._run_git("diff", "--stat", check=False)

        # Parse stats from last line
        lines = result.stdout.strip().split("\n")
        if lines and lines[-1]:
            last_line = lines[-1]
            # Parse "X files changed, Y insertions(+), Z deletions(-)"
            import re
            match = re.search(r"(\d+) files? changed", last_line)
            if match:
                diff.files_changed = int(match.group(1))
            match = re.search(r"(\d+) insertions?", last_line)
            if match:
                diff.insertions = int(match.group(1))
            match = re.search(r"(\d+) deletions?", last_line)
            if match:
                diff.deletions = int(match.group(1))

        return diff

    def stage_all(self) -> None:
        """Stage all changes."""
        self._run_git("add", "-A")

    def stage_files(self, files: List[str]) -> None:
        """Stage specific files."""
        self._run_git("add", *files)

    def commit(self, message: str) -> str:
        """Create a commit with the given message."""
        result = self._run_git("commit", "-m", message)
        # Extract commit hash from output
        import re
        match = re.search(r"\[[\w-]+ ([a-f0-9]+)\]", result.stdout)
        return match.group(1) if match else ""

    def push(self, remote: str = "origin", branch: Optional[str] = None,
             set_upstream: bool = False) -> Tuple[bool, str]:
        """Push to remote."""
        status = self.get_status()
        branch = branch or status.branch

        args = ["push"]
        if set_upstream:
            args.extend(["-u", remote, branch])
        else:
            args.extend([remote, branch])

        result = self._run_git(*args, check=False)
        success = result.returncode == 0
        output = result.stdout + result.stderr

        return success, output

    def get_recent_commits(self, count: int = 5) -> List[dict]:
        """Get recent commits."""
        result = self._run_git(
            "log", f"-{count}",
            "--pretty=format:%H|%h|%s|%an|%ar",
            check=False
        )

        commits = []
        for line in result.stdout.strip().split("\n"):
            if not line:
                continue
            parts = line.split("|")
            if len(parts) >= 5:
                commits.append({
                    "hash": parts[0],
                    "short_hash": parts[1],
                    "message": parts[2],
                    "author": parts[3],
                    "time": parts[4],
                })

        return commits

    def get_remote_url(self) -> str:
        """Get the remote URL."""
        result = self._run_git("remote", "get-url", "origin", check=False)
        return result.stdout.strip()
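The GitStatus/GitDiff dataclasses plus GitManager make up the package's git layer. Below is a minimal usage sketch, not taken from the package itself: it assumes the current directory is a git repository and uses an illustrative commit message.

from pathlib import Path

from ai_code_assistant.git.manager import GitManager

manager = GitManager(Path("."))  # raises ValueError if "." is not a git repo

status = manager.get_status()
print(f"On {status.branch}: {status.total_changes} change(s), ahead {status.ahead}, behind {status.behind}")

if status.has_changes:
    manager.stage_all()
    diff = manager.get_diff(staged=True)
    print(f"{diff.files_changed} file(s), +{diff.insertions}/-{diff.deletions}")
    commit_hash = manager.commit("chore: illustrative commit message")
    pushed, output = manager.push()
    print(commit_hash, "pushed" if pushed else output)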
ai_code_assistant/llm.py
ADDED
@@ -0,0 +1,111 @@
"""LLM Manager for multi-provider LLM integration via LangChain."""

from typing import Iterator, Optional

from langchain_core.language_models import BaseChatModel
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate

from ai_code_assistant.config import Config
from ai_code_assistant.providers.base import BaseProvider, ProviderConfig, ProviderType
from ai_code_assistant.providers.factory import get_provider, get_available_providers


class LLMManager:
    """Manages LLM interactions using LangChain with multiple provider support."""

    def __init__(self, config: Config):
        """Initialize LLM manager with configuration."""
        self.config = config
        self._provider: Optional[BaseProvider] = None

    @property
    def provider(self) -> BaseProvider:
        """Get or create the provider instance."""
        if self._provider is None:
            provider_config = ProviderConfig(
                provider=ProviderType(self.config.llm.provider),
                model=self.config.llm.model,
                api_key=self.config.llm.api_key,
                base_url=self.config.llm.base_url,
                temperature=self.config.llm.temperature,
                max_tokens=self.config.llm.max_tokens,
                timeout=self.config.llm.timeout,
            )
            self._provider = get_provider(provider_config)
        return self._provider

    @property
    def llm(self) -> BaseChatModel:
        """Get the underlying LLM instance for backward compatibility."""
        return self.provider.llm

    def invoke(self, prompt: str, system_prompt: Optional[str] = None) -> str:
        """Invoke the LLM with a prompt and optional system message."""
        return self.provider.invoke(prompt, system_prompt)

    def invoke_with_template(
        self,
        template: ChatPromptTemplate,
        **kwargs,
    ) -> str:
        """Invoke the LLM using a prompt template."""
        chain = template | self.llm
        response = chain.invoke(kwargs)
        return str(response.content)

    def stream(self, prompt: str, system_prompt: Optional[str] = None) -> Iterator[str]:
        """Stream LLM response for real-time output."""
        return self.provider.stream(prompt, system_prompt)

    def check_connection(self) -> bool:
        """Check if the LLM provider is accessible."""
        return self.provider.check_connection()

    def get_model_info(self) -> dict:
        """Get information about the current model configuration."""
        return self.provider.get_model_info()

    def validate_config(self) -> tuple[bool, str]:
        """Validate the current provider configuration."""
        return self.provider.validate_config()

    @staticmethod
    def list_providers() -> dict:
        """List all available providers and their models."""
        return get_available_providers()

    def switch_provider(
        self,
        provider: str,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
    ) -> None:
        """
        Switch to a different provider at runtime.

        Args:
            provider: Provider name (ollama, google, groq, cerebras, openrouter, openai)
            model: Optional model name (uses provider default if not specified)
            api_key: Optional API key (uses environment variable if not specified)
        """
        provider_type = ProviderType(provider)

        # Get default model for provider if not specified
        if model is None:
            from ai_code_assistant.providers.factory import PROVIDER_REGISTRY
            provider_class = PROVIDER_REGISTRY.get(provider_type)
            if provider_class:
                model = provider_class.default_model
            else:
                raise ValueError(f"Unknown provider: {provider}")

        provider_config = ProviderConfig(
            provider=provider_type,
            model=model,
            api_key=api_key,
            temperature=self.config.llm.temperature,
            max_tokens=self.config.llm.max_tokens,
            timeout=self.config.llm.timeout,
        )
        self._provider = get_provider(provider_config)
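A rough sketch of how LLMManager is typically driven (assumes Config() yields a usable default configuration; the prompts and the switch to Cerebras are illustrative):

from ai_code_assistant.config import Config
from ai_code_assistant.llm import LLMManager

config = Config()  # assumption: defaults are usable; real usage may load config.yaml first
manager = LLMManager(config)

ok, error = manager.validate_config()
if not ok:
    raise SystemExit(error)

# One-shot completion with an optional system prompt.
print(manager.invoke("Explain list comprehensions.", system_prompt="Be concise."))

# Stream tokens as they arrive.
for chunk in manager.stream("Write a haiku about version control."):
    print(chunk, end="", flush=True)

# Switch providers at runtime; the model falls back to that provider's default_model.
manager.switch_provider("cerebras")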
ai_code_assistant/providers/__init__.py
ADDED
@@ -0,0 +1,23 @@
"""LLM Provider implementations for Cognify AI."""

from ai_code_assistant.providers.base import BaseProvider, ProviderType
from ai_code_assistant.providers.ollama import OllamaProvider
from ai_code_assistant.providers.google import GoogleProvider
from ai_code_assistant.providers.groq import GroqProvider
from ai_code_assistant.providers.cerebras import CerebrasProvider
from ai_code_assistant.providers.openrouter import OpenRouterProvider
from ai_code_assistant.providers.openai import OpenAIProvider
from ai_code_assistant.providers.factory import get_provider, get_available_providers

__all__ = [
    "BaseProvider",
    "ProviderType",
    "OllamaProvider",
    "GoogleProvider",
    "GroqProvider",
    "CerebrasProvider",
    "OpenRouterProvider",
    "OpenAIProvider",
    "get_provider",
    "get_available_providers",
]
ai_code_assistant/providers/base.py
ADDED
@@ -0,0 +1,124 @@
"""Base provider class for LLM integrations."""

from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, Iterator, List, Optional

from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from pydantic import BaseModel


class ProviderType(str, Enum):
    """Supported LLM provider types."""
    OLLAMA = "ollama"
    GOOGLE = "google"
    GROQ = "groq"
    CEREBRAS = "cerebras"
    OPENROUTER = "openrouter"
    OPENAI = "openai"


class ProviderConfig(BaseModel):
    """Configuration for an LLM provider."""
    provider: ProviderType = ProviderType.OLLAMA
    model: str = "deepseek-coder:6.7b"
    api_key: Optional[str] = None
    base_url: Optional[str] = None
    temperature: float = 0.1
    max_tokens: int = 4096
    timeout: int = 120


class ModelInfo(BaseModel):
    """Information about a model."""
    name: str
    provider: ProviderType
    description: str = ""
    context_window: int = 4096
    is_free: bool = True


class BaseProvider(ABC):
    """Abstract base class for LLM providers."""

    # Provider metadata
    provider_type: ProviderType
    display_name: str
    requires_api_key: bool = True
    default_model: str
    free_tier: bool = True

    # Available models for this provider
    available_models: List[ModelInfo] = []

    def __init__(self, config: ProviderConfig):
        """Initialize the provider with configuration."""
        self.config = config
        self._llm: Optional[BaseChatModel] = None

    @property
    def llm(self) -> BaseChatModel:
        """Get or create the LLM instance."""
        if self._llm is None:
            self._llm = self._create_llm()
        return self._llm

    @abstractmethod
    def _create_llm(self) -> BaseChatModel:
        """Create the LangChain LLM instance. Must be implemented by subclasses."""
        pass

    @abstractmethod
    def validate_config(self) -> tuple[bool, str]:
        """Validate the provider configuration. Returns (is_valid, error_message)."""
        pass

    def invoke(self, prompt: str, system_prompt: Optional[str] = None) -> str:
        """Invoke the LLM with a prompt and optional system message."""
        messages: List[BaseMessage] = []
        if system_prompt:
            messages.append(SystemMessage(content=system_prompt))
        messages.append(HumanMessage(content=prompt))

        response = self.llm.invoke(messages)
        return str(response.content)

    def stream(self, prompt: str, system_prompt: Optional[str] = None) -> Iterator[str]:
        """Stream LLM response for real-time output."""
        messages: List[BaseMessage] = []
        if system_prompt:
            messages.append(SystemMessage(content=system_prompt))
        messages.append(HumanMessage(content=prompt))

        for chunk in self.llm.stream(messages):
            yield str(chunk.content)

    def check_connection(self) -> bool:
        """Check if the provider is accessible."""
        try:
            self.invoke("Say 'ok' and nothing else.")
            return True
        except Exception:
            return False

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the current model configuration."""
        return {
            "provider": self.provider_type.value,
            "model": self.config.model,
            "temperature": self.config.temperature,
            "max_tokens": self.config.max_tokens,
            "base_url": self.config.base_url,
            "free_tier": self.free_tier,
        }

    @classmethod
    def get_available_models(cls) -> List[ModelInfo]:
        """Get list of available models for this provider."""
        return cls.available_models

    @classmethod
    def get_setup_instructions(cls) -> str:
        """Get setup instructions for this provider."""
        return f"Configure {cls.display_name} in your config.yaml or set environment variables."
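Concrete providers only need to implement _create_llm() and validate_config(); invoke, stream, check_connection, and get_model_info are inherited. The subclass below is hypothetical (not part of the package) and assumes langchain-core's FakeListChatModel test helper is available; it exists purely to show the contract.

from langchain_core.language_models import BaseChatModel
from langchain_core.language_models.fake_chat_models import FakeListChatModel  # assumed test helper

from ai_code_assistant.providers.base import BaseProvider, ProviderConfig, ProviderType


class EchoProvider(BaseProvider):
    """Hypothetical provider illustrating the BaseProvider contract."""

    provider_type = ProviderType.OLLAMA  # reuses an existing enum member for the sketch
    display_name = "Echo (example)"
    requires_api_key = False
    default_model = "echo-model"

    def _create_llm(self) -> BaseChatModel:
        # A real provider returns ChatOllama, ChatOpenAI, etc.
        return FakeListChatModel(responses=["ok"])

    def validate_config(self) -> tuple[bool, str]:
        return (True, "") if self.config.model else (False, "Model name is required")


provider = EchoProvider(ProviderConfig(model="echo-model"))
print(provider.invoke("ping"))  # -> "ok"
print(provider.get_model_info()["provider"])  # -> "ollama"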
ai_code_assistant/providers/cerebras.py
ADDED
@@ -0,0 +1,97 @@
"""Cerebras provider for fast LLM inference."""

import os
from langchain_core.language_models import BaseChatModel

from ai_code_assistant.providers.base import (
    BaseProvider,
    ModelInfo,
    ProviderConfig,
    ProviderType,
)


class CerebrasProvider(BaseProvider):
    """Provider for Cerebras - fast inference with large models."""

    provider_type = ProviderType.CEREBRAS
    display_name = "Cerebras (Fast Inference)"
    requires_api_key = True
    default_model = "llama3.1-8b"
    free_tier = True

    available_models = [
        ModelInfo(
            name="llama3.1-8b",
            provider=ProviderType.CEREBRAS,
            description="Llama 3.1 8B - Fast and efficient",
            context_window=8192,
            is_free=True,
        ),
        ModelInfo(
            name="llama3.1-70b",
            provider=ProviderType.CEREBRAS,
            description="Llama 3.1 70B - More capable",
            context_window=8192,
            is_free=True,
        ),
        ModelInfo(
            name="llama-3.3-70b",
            provider=ProviderType.CEREBRAS,
            description="Llama 3.3 70B - Latest Llama model",
            context_window=8192,
            is_free=True,
        ),
    ]

    def _create_llm(self) -> BaseChatModel:
        """Create Cerebras LLM instance using OpenAI-compatible API."""
        try:
            from langchain_openai import ChatOpenAI
        except ImportError:
            raise ImportError(
                "langchain-openai is required for Cerebras provider. "
                "Install with: pip install langchain-openai"
            )

        api_key = self.config.api_key or os.getenv("CEREBRAS_API_KEY")
        if not api_key:
            raise ValueError(
                "Cerebras API key is required. Set CEREBRAS_API_KEY environment variable "
                "or provide api_key in config. Get your key at: https://cloud.cerebras.ai/"
            )

        return ChatOpenAI(
            model=self.config.model,
            api_key=api_key,
            base_url="https://api.cerebras.ai/v1",
            temperature=self.config.temperature,
            max_tokens=self.config.max_tokens,
            timeout=self.config.timeout,
        )

    def validate_config(self) -> tuple[bool, str]:
        """Validate Cerebras configuration."""
        api_key = self.config.api_key or os.getenv("CEREBRAS_API_KEY")
        if not api_key:
            return False, "CEREBRAS_API_KEY environment variable or api_key config is required"
        if not self.config.model:
            return False, "Model name is required"
        return True, ""

    @classmethod
    def get_setup_instructions(cls) -> str:
        """Get Cerebras setup instructions."""
        return """
Cerebras Setup Instructions:
1. Go to https://cloud.cerebras.ai/
2. Create a free account and generate an API key
3. Set environment variable: export CEREBRAS_API_KEY="your-key"
4. Or add to config.yaml:
   llm:
     provider: cerebras
     api_key: "your-key"
     model: llama3.1-8b

Free tier: 14,400 requests/day, very fast inference
"""