cognify_code-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. ai_code_assistant/__init__.py +14 -0
  2. ai_code_assistant/agent/__init__.py +63 -0
  3. ai_code_assistant/agent/code_agent.py +461 -0
  4. ai_code_assistant/agent/code_generator.py +388 -0
  5. ai_code_assistant/agent/code_reviewer.py +365 -0
  6. ai_code_assistant/agent/diff_engine.py +308 -0
  7. ai_code_assistant/agent/file_manager.py +300 -0
  8. ai_code_assistant/agent/intent_classifier.py +284 -0
  9. ai_code_assistant/chat/__init__.py +11 -0
  10. ai_code_assistant/chat/agent_session.py +156 -0
  11. ai_code_assistant/chat/session.py +165 -0
  12. ai_code_assistant/cli.py +1571 -0
  13. ai_code_assistant/config.py +149 -0
  14. ai_code_assistant/editor/__init__.py +8 -0
  15. ai_code_assistant/editor/diff_handler.py +270 -0
  16. ai_code_assistant/editor/file_editor.py +350 -0
  17. ai_code_assistant/editor/prompts.py +146 -0
  18. ai_code_assistant/generator/__init__.py +7 -0
  19. ai_code_assistant/generator/code_gen.py +265 -0
  20. ai_code_assistant/generator/prompts.py +114 -0
  21. ai_code_assistant/git/__init__.py +6 -0
  22. ai_code_assistant/git/commit_generator.py +130 -0
  23. ai_code_assistant/git/manager.py +203 -0
  24. ai_code_assistant/llm.py +111 -0
  25. ai_code_assistant/providers/__init__.py +23 -0
  26. ai_code_assistant/providers/base.py +124 -0
  27. ai_code_assistant/providers/cerebras.py +97 -0
  28. ai_code_assistant/providers/factory.py +148 -0
  29. ai_code_assistant/providers/google.py +103 -0
  30. ai_code_assistant/providers/groq.py +111 -0
  31. ai_code_assistant/providers/ollama.py +86 -0
  32. ai_code_assistant/providers/openai.py +114 -0
  33. ai_code_assistant/providers/openrouter.py +130 -0
  34. ai_code_assistant/py.typed +0 -0
  35. ai_code_assistant/refactor/__init__.py +20 -0
  36. ai_code_assistant/refactor/analyzer.py +189 -0
  37. ai_code_assistant/refactor/change_plan.py +172 -0
  38. ai_code_assistant/refactor/multi_file_editor.py +346 -0
  39. ai_code_assistant/refactor/prompts.py +175 -0
  40. ai_code_assistant/retrieval/__init__.py +19 -0
  41. ai_code_assistant/retrieval/chunker.py +215 -0
  42. ai_code_assistant/retrieval/indexer.py +236 -0
  43. ai_code_assistant/retrieval/search.py +239 -0
  44. ai_code_assistant/reviewer/__init__.py +7 -0
  45. ai_code_assistant/reviewer/analyzer.py +278 -0
  46. ai_code_assistant/reviewer/prompts.py +113 -0
  47. ai_code_assistant/utils/__init__.py +18 -0
  48. ai_code_assistant/utils/file_handler.py +155 -0
  49. ai_code_assistant/utils/formatters.py +259 -0
  50. cognify_code-0.2.0.dist-info/METADATA +383 -0
  51. cognify_code-0.2.0.dist-info/RECORD +55 -0
  52. cognify_code-0.2.0.dist-info/WHEEL +5 -0
  53. cognify_code-0.2.0.dist-info/entry_points.txt +3 -0
  54. cognify_code-0.2.0.dist-info/licenses/LICENSE +22 -0
  55. cognify_code-0.2.0.dist-info/top_level.txt +1 -0
ai_code_assistant/git/manager.py
@@ -0,0 +1,203 @@
+"""Git repository manager."""
+
+import subprocess
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import List, Optional, Tuple
+
+
+@dataclass
+class GitStatus:
+    """Represents the current git status."""
+    staged: List[str] = field(default_factory=list)
+    modified: List[str] = field(default_factory=list)
+    untracked: List[str] = field(default_factory=list)
+    deleted: List[str] = field(default_factory=list)
+    branch: str = ""
+    remote: str = ""
+    ahead: int = 0
+    behind: int = 0
+
+    @property
+    def has_changes(self) -> bool:
+        return bool(self.staged or self.modified or self.untracked or self.deleted)
+
+    @property
+    def has_staged(self) -> bool:
+        return bool(self.staged)
+
+    @property
+    def total_changes(self) -> int:
+        return len(self.staged) + len(self.modified) + len(self.untracked) + len(self.deleted)
+
+
+@dataclass
+class GitDiff:
+    """Represents a git diff."""
+    files_changed: int = 0
+    insertions: int = 0
+    deletions: int = 0
+    diff_text: str = ""
+    file_diffs: List[dict] = field(default_factory=list)
+
+
+class GitManager:
+    """Manages git operations."""
+
+    def __init__(self, repo_path: Optional[Path] = None):
+        self.repo_path = repo_path or Path.cwd()
+        self._validate_repo()
+
+    def _validate_repo(self) -> None:
+        """Validate that we're in a git repository."""
+        if not (self.repo_path / ".git").exists():
+            raise ValueError(f"Not a git repository: {self.repo_path}")
+
+    def _run_git(self, *args: str, check: bool = True) -> subprocess.CompletedProcess:
+        """Run a git command."""
+        result = subprocess.run(
+            ["git", *args],
+            cwd=self.repo_path,
+            capture_output=True,
+            text=True,
+        )
+        if check and result.returncode != 0:
+            raise RuntimeError(f"Git command failed: {result.stderr}")
+        return result
+
+    def get_status(self) -> GitStatus:
+        """Get current git status."""
+        status = GitStatus()
+
+        # Get branch name
+        result = self._run_git("branch", "--show-current", check=False)
+        status.branch = result.stdout.strip()
+
+        # Get remote
+        result = self._run_git("remote", check=False)
+        status.remote = result.stdout.strip().split("\n")[0] if result.stdout else "origin"
+
+        # Get ahead/behind
+        result = self._run_git("rev-list", "--left-right", "--count", f"{status.remote}/{status.branch}...HEAD", check=False)
+        if result.returncode == 0 and result.stdout.strip():
+            parts = result.stdout.strip().split()
+            if len(parts) == 2:
+                status.behind, status.ahead = int(parts[0]), int(parts[1])
+
+        # Get file status
+        result = self._run_git("status", "--porcelain", check=False)
+        for line in result.stdout.strip().split("\n"):
+            if not line:
+                continue
+            status_code = line[:2]
+            file_path = line[3:]
+
+            if status_code[0] in "MADRC":
+                status.staged.append(file_path)
+            if status_code[1] == "M":
+                status.modified.append(file_path)
+            elif status_code[1] == "D":
+                status.deleted.append(file_path)
+            elif status_code == "??":
+                status.untracked.append(file_path)
+
+        return status
+
+    def get_diff(self, staged: bool = True) -> GitDiff:
+        """Get diff of changes."""
+        diff = GitDiff()
+
+        # Get diff text
+        if staged:
+            result = self._run_git("diff", "--cached", check=False)
+        else:
+            result = self._run_git("diff", check=False)
+
+        diff.diff_text = result.stdout
+
+        # Get stats
+        if staged:
+            result = self._run_git("diff", "--cached", "--stat", check=False)
+        else:
+            result = self._run_git("diff", "--stat", check=False)
+
+        # Parse stats from last line
+        lines = result.stdout.strip().split("\n")
+        if lines and lines[-1]:
+            last_line = lines[-1]
+            # Parse "X files changed, Y insertions(+), Z deletions(-)"
+            import re
+            match = re.search(r"(\d+) files? changed", last_line)
+            if match:
+                diff.files_changed = int(match.group(1))
+            match = re.search(r"(\d+) insertions?", last_line)
+            if match:
+                diff.insertions = int(match.group(1))
+            match = re.search(r"(\d+) deletions?", last_line)
+            if match:
+                diff.deletions = int(match.group(1))
+
+        return diff
+
+    def stage_all(self) -> None:
+        """Stage all changes."""
+        self._run_git("add", "-A")
+
+    def stage_files(self, files: List[str]) -> None:
+        """Stage specific files."""
+        self._run_git("add", *files)
+
+    def commit(self, message: str) -> str:
+        """Create a commit with the given message."""
+        result = self._run_git("commit", "-m", message)
+        # Extract commit hash from output
+        import re
+        match = re.search(r"\[[\w-]+ ([a-f0-9]+)\]", result.stdout)
+        return match.group(1) if match else ""
+
+    def push(self, remote: str = "origin", branch: Optional[str] = None,
+             set_upstream: bool = False) -> Tuple[bool, str]:
+        """Push to remote."""
+        status = self.get_status()
+        branch = branch or status.branch
+
+        args = ["push"]
+        if set_upstream:
+            args.extend(["-u", remote, branch])
+        else:
+            args.extend([remote, branch])
+
+        result = self._run_git(*args, check=False)
+        success = result.returncode == 0
+        output = result.stdout + result.stderr
+
+        return success, output
+
+    def get_recent_commits(self, count: int = 5) -> List[dict]:
+        """Get recent commits."""
+        result = self._run_git(
+            "log", f"-{count}",
+            "--pretty=format:%H|%h|%s|%an|%ar",
+            check=False
+        )
+
+        commits = []
+        for line in result.stdout.strip().split("\n"):
+            if not line:
+                continue
+            parts = line.split("|")
+            if len(parts) >= 5:
+                commits.append({
+                    "hash": parts[0],
+                    "short_hash": parts[1],
+                    "message": parts[2],
+                    "author": parts[3],
+                    "time": parts[4],
+                })
+
+        return commits
+
+    def get_remote_url(self) -> str:
+        """Get the remote URL."""
+        result = self._run_git("remote", "get-url", "origin", check=False)
+        return result.stdout.strip()
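For context, a minimal usage sketch of the GitManager API added above (the sketch itself is not part of the package); it assumes the working directory is a git repository with uncommitted changes.

from pathlib import Path

from ai_code_assistant.git.manager import GitManager

# Raises ValueError if the directory is not a git repository.
manager = GitManager(Path.cwd())

status = manager.get_status()
if status.has_changes:
    manager.stage_all()
    diff = manager.get_diff(staged=True)
    print(f"{diff.files_changed} files changed, +{diff.insertions}/-{diff.deletions}")
    commit_hash = manager.commit("chore: example commit")
    print(f"Created commit {commit_hash} on {status.branch}")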
ai_code_assistant/llm.py
@@ -0,0 +1,111 @@
+"""LLM Manager for multi-provider LLM integration via LangChain."""
+
+from typing import Iterator, Optional
+
+from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_core.prompts import ChatPromptTemplate
+
+from ai_code_assistant.config import Config
+from ai_code_assistant.providers.base import BaseProvider, ProviderConfig, ProviderType
+from ai_code_assistant.providers.factory import get_provider, get_available_providers
+
+
+class LLMManager:
+    """Manages LLM interactions using LangChain with multiple provider support."""
+
+    def __init__(self, config: Config):
+        """Initialize LLM manager with configuration."""
+        self.config = config
+        self._provider: Optional[BaseProvider] = None
+
+    @property
+    def provider(self) -> BaseProvider:
+        """Get or create the provider instance."""
+        if self._provider is None:
+            provider_config = ProviderConfig(
+                provider=ProviderType(self.config.llm.provider),
+                model=self.config.llm.model,
+                api_key=self.config.llm.api_key,
+                base_url=self.config.llm.base_url,
+                temperature=self.config.llm.temperature,
+                max_tokens=self.config.llm.max_tokens,
+                timeout=self.config.llm.timeout,
+            )
+            self._provider = get_provider(provider_config)
+        return self._provider
+
+    @property
+    def llm(self) -> BaseChatModel:
+        """Get the underlying LLM instance for backward compatibility."""
+        return self.provider.llm
+
+    def invoke(self, prompt: str, system_prompt: Optional[str] = None) -> str:
+        """Invoke the LLM with a prompt and optional system message."""
+        return self.provider.invoke(prompt, system_prompt)
+
+    def invoke_with_template(
+        self,
+        template: ChatPromptTemplate,
+        **kwargs,
+    ) -> str:
+        """Invoke the LLM using a prompt template."""
+        chain = template | self.llm
+        response = chain.invoke(kwargs)
+        return str(response.content)
+
+    def stream(self, prompt: str, system_prompt: Optional[str] = None) -> Iterator[str]:
+        """Stream LLM response for real-time output."""
+        return self.provider.stream(prompt, system_prompt)
+
+    def check_connection(self) -> bool:
+        """Check if the LLM provider is accessible."""
+        return self.provider.check_connection()
+
+    def get_model_info(self) -> dict:
+        """Get information about the current model configuration."""
+        return self.provider.get_model_info()
+
+    def validate_config(self) -> tuple[bool, str]:
+        """Validate the current provider configuration."""
+        return self.provider.validate_config()
+
+    @staticmethod
+    def list_providers() -> dict:
+        """List all available providers and their models."""
+        return get_available_providers()
+
+    def switch_provider(
+        self,
+        provider: str,
+        model: Optional[str] = None,
+        api_key: Optional[str] = None,
+    ) -> None:
+        """
+        Switch to a different provider at runtime.
+
+        Args:
+            provider: Provider name (ollama, google, groq, cerebras, openrouter, openai)
+            model: Optional model name (uses provider default if not specified)
+            api_key: Optional API key (uses environment variable if not specified)
+        """
+        provider_type = ProviderType(provider)
+
+        # Get default model for provider if not specified
+        if model is None:
+            from ai_code_assistant.providers.factory import PROVIDER_REGISTRY
+            provider_class = PROVIDER_REGISTRY.get(provider_type)
+            if provider_class:
+                model = provider_class.default_model
+            else:
+                raise ValueError(f"Unknown provider: {provider}")
+
+        provider_config = ProviderConfig(
+            provider=provider_type,
+            model=model,
+            api_key=api_key,
+            temperature=self.config.llm.temperature,
+            max_tokens=self.config.llm.max_tokens,
+            timeout=self.config.llm.timeout,
+        )
+        self._provider = get_provider(provider_config)
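A hedged usage sketch of LLMManager (not part of the diff): it assumes Config() can be constructed with defaults, which this diff does not expand, and that the target provider's API key is already exported in the environment.

from ai_code_assistant.config import Config
from ai_code_assistant.llm import LLMManager

# Assumption: Config() loads sensible defaults; its definition lives in
# ai_code_assistant/config.py, which is not expanded in this diff.
manager = LLMManager(Config())

ok, error = manager.validate_config()
if not ok:
    raise SystemExit(f"Provider misconfigured: {error}")

# Switch providers at runtime; with model=None the provider default is used.
manager.switch_provider("cerebras", model=None)
for chunk in manager.stream("Summarize what a git rebase does in one sentence."):
    print(chunk, end="", flush=True)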
ai_code_assistant/providers/__init__.py
@@ -0,0 +1,23 @@
+"""LLM Provider implementations for Cognify AI."""
+
+from ai_code_assistant.providers.base import BaseProvider, ProviderType
+from ai_code_assistant.providers.ollama import OllamaProvider
+from ai_code_assistant.providers.google import GoogleProvider
+from ai_code_assistant.providers.groq import GroqProvider
+from ai_code_assistant.providers.cerebras import CerebrasProvider
+from ai_code_assistant.providers.openrouter import OpenRouterProvider
+from ai_code_assistant.providers.openai import OpenAIProvider
+from ai_code_assistant.providers.factory import get_provider, get_available_providers
+
+__all__ = [
+    "BaseProvider",
+    "ProviderType",
+    "OllamaProvider",
+    "GoogleProvider",
+    "GroqProvider",
+    "CerebrasProvider",
+    "OpenRouterProvider",
+    "OpenAIProvider",
+    "get_provider",
+    "get_available_providers",
+]
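The package's factory entry point, get_provider, can be exercised as in the sketch below (not part of the package); it assumes a local Ollama server with the deepseek-coder:6.7b model available, since that provider needs no API key.

from ai_code_assistant.providers import get_provider
from ai_code_assistant.providers.base import ProviderConfig, ProviderType

# Assumes a running local Ollama server with deepseek-coder:6.7b pulled.
config = ProviderConfig(provider=ProviderType.OLLAMA, model="deepseek-coder:6.7b")
provider = get_provider(config)

if provider.check_connection():
    print(provider.invoke("Write a one-line docstring for a fibonacci function."))
else:
    print("Provider is not reachable.")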
ai_code_assistant/providers/base.py
@@ -0,0 +1,124 @@
+"""Base provider class for LLM integrations."""
+
+from abc import ABC, abstractmethod
+from enum import Enum
+from typing import Any, Dict, Iterator, List, Optional
+
+from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
+from pydantic import BaseModel
+
+
+class ProviderType(str, Enum):
+    """Supported LLM provider types."""
+    OLLAMA = "ollama"
+    GOOGLE = "google"
+    GROQ = "groq"
+    CEREBRAS = "cerebras"
+    OPENROUTER = "openrouter"
+    OPENAI = "openai"
+
+
+class ProviderConfig(BaseModel):
+    """Configuration for an LLM provider."""
+    provider: ProviderType = ProviderType.OLLAMA
+    model: str = "deepseek-coder:6.7b"
+    api_key: Optional[str] = None
+    base_url: Optional[str] = None
+    temperature: float = 0.1
+    max_tokens: int = 4096
+    timeout: int = 120
+
+
+class ModelInfo(BaseModel):
+    """Information about a model."""
+    name: str
+    provider: ProviderType
+    description: str = ""
+    context_window: int = 4096
+    is_free: bool = True
+
+
+class BaseProvider(ABC):
+    """Abstract base class for LLM providers."""
+
+    # Provider metadata
+    provider_type: ProviderType
+    display_name: str
+    requires_api_key: bool = True
+    default_model: str
+    free_tier: bool = True
+
+    # Available models for this provider
+    available_models: List[ModelInfo] = []
+
+    def __init__(self, config: ProviderConfig):
+        """Initialize the provider with configuration."""
+        self.config = config
+        self._llm: Optional[BaseChatModel] = None
+
+    @property
+    def llm(self) -> BaseChatModel:
+        """Get or create the LLM instance."""
+        if self._llm is None:
+            self._llm = self._create_llm()
+        return self._llm
+
+    @abstractmethod
+    def _create_llm(self) -> BaseChatModel:
+        """Create the LangChain LLM instance. Must be implemented by subclasses."""
+        pass
+
+    @abstractmethod
+    def validate_config(self) -> tuple[bool, str]:
+        """Validate the provider configuration. Returns (is_valid, error_message)."""
+        pass
+
+    def invoke(self, prompt: str, system_prompt: Optional[str] = None) -> str:
+        """Invoke the LLM with a prompt and optional system message."""
+        messages: List[BaseMessage] = []
+        if system_prompt:
+            messages.append(SystemMessage(content=system_prompt))
+        messages.append(HumanMessage(content=prompt))
+
+        response = self.llm.invoke(messages)
+        return str(response.content)
+
+    def stream(self, prompt: str, system_prompt: Optional[str] = None) -> Iterator[str]:
+        """Stream LLM response for real-time output."""
+        messages: List[BaseMessage] = []
+        if system_prompt:
+            messages.append(SystemMessage(content=system_prompt))
+        messages.append(HumanMessage(content=prompt))
+
+        for chunk in self.llm.stream(messages):
+            yield str(chunk.content)
+
+    def check_connection(self) -> bool:
+        """Check if the provider is accessible."""
+        try:
+            self.invoke("Say 'ok' and nothing else.")
+            return True
+        except Exception:
+            return False
+
+    def get_model_info(self) -> Dict[str, Any]:
+        """Get information about the current model configuration."""
+        return {
+            "provider": self.provider_type.value,
+            "model": self.config.model,
+            "temperature": self.config.temperature,
+            "max_tokens": self.config.max_tokens,
+            "base_url": self.config.base_url,
+            "free_tier": self.free_tier,
+        }
+
+    @classmethod
+    def get_available_models(cls) -> List[ModelInfo]:
+        """Get list of available models for this provider."""
+        return cls.available_models
+
+    @classmethod
+    def get_setup_instructions(cls) -> str:
+        """Get setup instructions for this provider."""
+        return f"Configure {cls.display_name} in your config.yaml or set environment variables."
ai_code_assistant/providers/cerebras.py
@@ -0,0 +1,97 @@
+"""Cerebras provider for fast LLM inference."""
+
+import os
+from langchain_core.language_models import BaseChatModel
+
+from ai_code_assistant.providers.base import (
+    BaseProvider,
+    ModelInfo,
+    ProviderConfig,
+    ProviderType,
+)
+
+
+class CerebrasProvider(BaseProvider):
+    """Provider for Cerebras - fast inference with large models."""
+
+    provider_type = ProviderType.CEREBRAS
+    display_name = "Cerebras (Fast Inference)"
+    requires_api_key = True
+    default_model = "llama3.1-8b"
+    free_tier = True
+
+    available_models = [
+        ModelInfo(
+            name="llama3.1-8b",
+            provider=ProviderType.CEREBRAS,
+            description="Llama 3.1 8B - Fast and efficient",
+            context_window=8192,
+            is_free=True,
+        ),
+        ModelInfo(
+            name="llama3.1-70b",
+            provider=ProviderType.CEREBRAS,
+            description="Llama 3.1 70B - More capable",
+            context_window=8192,
+            is_free=True,
+        ),
+        ModelInfo(
+            name="llama-3.3-70b",
+            provider=ProviderType.CEREBRAS,
+            description="Llama 3.3 70B - Latest Llama model",
+            context_window=8192,
+            is_free=True,
+        ),
+    ]
+
+    def _create_llm(self) -> BaseChatModel:
+        """Create Cerebras LLM instance using OpenAI-compatible API."""
+        try:
+            from langchain_openai import ChatOpenAI
+        except ImportError:
+            raise ImportError(
+                "langchain-openai is required for Cerebras provider. "
+                "Install with: pip install langchain-openai"
+            )
+
+        api_key = self.config.api_key or os.getenv("CEREBRAS_API_KEY")
+        if not api_key:
+            raise ValueError(
+                "Cerebras API key is required. Set CEREBRAS_API_KEY environment variable "
+                "or provide api_key in config. Get your key at: https://cloud.cerebras.ai/"
+            )
+
+        return ChatOpenAI(
+            model=self.config.model,
+            api_key=api_key,
+            base_url="https://api.cerebras.ai/v1",
+            temperature=self.config.temperature,
+            max_tokens=self.config.max_tokens,
+            timeout=self.config.timeout,
+        )
+
+    def validate_config(self) -> tuple[bool, str]:
+        """Validate Cerebras configuration."""
+        api_key = self.config.api_key or os.getenv("CEREBRAS_API_KEY")
+        if not api_key:
+            return False, "CEREBRAS_API_KEY environment variable or api_key config is required"
+        if not self.config.model:
+            return False, "Model name is required"
+        return True, ""
+
+    @classmethod
+    def get_setup_instructions(cls) -> str:
+        """Get Cerebras setup instructions."""
+        return """
+        Cerebras Setup Instructions:
+        1. Go to https://cloud.cerebras.ai/
+        2. Create a free account and generate an API key
+        3. Set environment variable: export CEREBRAS_API_KEY="your-key"
+        4. Or add to config.yaml:
+           llm:
+             provider: cerebras
+             api_key: "your-key"
+             model: llama3.1-8b
+
+        Free tier: 14,400 requests/day, very fast inference
+        """