enhanced_git-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- enhanced_git-1.0.0.dist-info/METADATA +349 -0
- enhanced_git-1.0.0.dist-info/RECORD +18 -0
- enhanced_git-1.0.0.dist-info/WHEEL +4 -0
- enhanced_git-1.0.0.dist-info/entry_points.txt +2 -0
- enhanced_git-1.0.0.dist-info/licenses/LICENSE +21 -0
- gitai/__init__.py +3 -0
- gitai/changelog.py +251 -0
- gitai/cli.py +166 -0
- gitai/commit.py +338 -0
- gitai/config.py +120 -0
- gitai/constants.py +134 -0
- gitai/diff.py +167 -0
- gitai/hook.py +81 -0
- gitai/providers/__init__.py +1 -0
- gitai/providers/base.py +71 -0
- gitai/providers/ollama_provider.py +86 -0
- gitai/providers/openai_provider.py +78 -0
- gitai/util.py +137 -0
gitai/constants.py
ADDED
@@ -0,0 +1,134 @@
"""Constants and templates for GitAI."""

from typing import Final

# Commit message prompts
COMMIT_SYSTEM_PROMPT: Final[
    str
] = """You are an expert release engineer. Output plain text only, no code fences.
Follow Conventional Commits strictly. Do not invent changes."""

COMMIT_USER_PROMPT_SINGLE: Final[
    str
] = """Generate a Conventional Commit message from this unified diff.

Rules:
- Subject < 70 chars: type(optional-scope): message
- Sentence case, no trailing period in subject
- Include a short body with bullet points only if helpful
- Wrap body at 72 columns
- Do not invent changes; reflect only what the diff shows

Diff:
{diff}"""

COMMIT_USER_PROMPT_MERGE: Final[
    str
] = """Given these chunk summaries of staged changes, produce one Conventional Commit message.
Follow the same rules as above. Prefer the most representative type(scope).

Chunk summaries:
{chunk_summaries}"""

# Changelog prompts
CHANGELOG_SYSTEM_PROMPT: Final[str] = (
    """You are an expert technical writer. Output markdown bullets only. Do not invent facts."""
)

CHANGELOG_USER_PROMPT: Final[
    str
] = """Polish these grouped commit bullets for a changelog. Improve clarity and brevity.
Do not add or invent items. Keep issue/PR references unchanged.

Grouped bullets:
{grouped_bullets}"""

# Type mappings for heuristics
TYPE_HINTS_PATH: Final[dict[str, str]] = {
    "tests/": "test",
    "_test.": "test",
    "docs/": "docs",
    ".md": "docs",
    "mkdocs.yml": "docs",
    ".github/workflows/": "ci",
    "Dockerfile": "ci",
    "docker-compose": "ci",
    "Makefile": "build",
    "pyproject.toml": "build",
    "setup.py": "build",
    "requirements": "build",
}

TYPE_HINTS_CONTENT: Final[dict[str, str]] = {
    "add": "feat",
    "new": "feat",
    "create": "feat",
    "remove": "feat",
    "delete": "feat",
    "fix": "fix",
    "bug": "fix",
    "error": "fix",
    "issue": "fix",
    "refactor": "refactor",
    "rename": "chore",
    "move": "chore",
    "update": "feat",
    "improve": "feat",
    "enhance": "feat",
    "optimize": "perf",
    "performance": "perf",
    "speed": "perf",
    "test": "test",
    "testing": "test",
    "doc": "docs",
    "document": "docs",
    "readme": "docs",
    "ci": "ci",
    "build": "ci",
    "config": "ci",
    "chore": "chore",
    "maintenance": "chore",
}

# Conventional Commit types in display order
CHANGELOG_SECTIONS: Final[list[str]] = [
    "Features",
    "Fixes",
    "Performance",
    "Documentation",
    "Refactoring",
    "CI",
    "Build",
    "Tests",
    "Chore",
    "Other",
]

# Type to section mapping
TYPE_TO_SECTION: Final[dict[str, str]] = {
    "feat": "Features",
    "fix": "Fixes",
    "perf": "Performance",
    "docs": "Documentation",
    "refactor": "Refactoring",
    "ci": "CI",
    "build": "Build",
    "test": "Tests",
    "chore": "Chore",
}

# File size limits
MAX_CHUNK_SIZE: Final[int] = 6000  # characters
MAX_SUBJECT_LENGTH: Final[int] = 100
BODY_WRAP_WIDTH: Final[int] = 72
CHANGELOG_WRAP_WIDTH: Final[int] = 100

# Default timeouts
DEFAULT_TIMEOUT: Final[int] = 45

# Git hook content
HOOK_CONTENT: Final[
    str
] = """#!/usr/bin/env sh
git-ai commit --hook "$1" || exit 0
"""
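The two hint tables above drive heuristic commit-type detection: path fragments are matched against changed file names, keywords against the diff text. A minimal sketch of how they could be combined, assuming path hints take priority over content hints; guess_commit_type is illustrative only and is not the function the package ships (that logic lives in gitai/commit.py, listed above but not shown in this diff).

# hypothetical combination of the two hint tables from gitai/constants.py
from gitai.constants import TYPE_HINTS_CONTENT, TYPE_HINTS_PATH

def guess_commit_type(files: list[str], diff_text: str) -> str:
    # path hints first: a touched test or docs file is a strong signal
    for path in files:
        for fragment, commit_type in TYPE_HINTS_PATH.items():
            if fragment in path:
                return commit_type
    # fall back to keyword hints found anywhere in the diff text
    lowered = diff_text.lower()
    for keyword, commit_type in TYPE_HINTS_CONTENT.items():
        if keyword in lowered:
            return commit_type
    return "chore"  # assumed default when nothing matches

print(guess_commit_type(["docs/guide.md"], ""))  # -> "docs"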
gitai/diff.py
ADDED
@@ -0,0 +1,167 @@
"""Diff collection and chunking functionality."""

import re
from typing import Any

from .constants import MAX_CHUNK_SIZE
from .util import run


class DiffChunk:
    """Represents a chunk of diff content."""

    def __init__(self, content: str, files: list[str]):
        self.content = content
        self.files = files
        self.size = len(content)

    def __str__(self) -> str:
        return f"DiffChunk({len(self.files)} files, {self.size} chars)"


class StagedDiff:
    """Represents staged changes in a git repository."""

    def __init__(self, raw_diff: str):
        self.raw_diff = raw_diff
        self.files = self._extract_files()
        self.stats = self._extract_stats()

    def _extract_files(self) -> list[str]:
        """Extract file paths from diff."""
        files = []
        lines = self.raw_diff.split("\n")

        for line in lines:
            if line.startswith("diff --git"):
                # extract file paths from diff --git a/path b/path
                match = re.search(r"diff --git a/(.+) b/(.+)", line)
                if match:
                    file_path = match.group(2)  # use b/path (new path)
                    files.append(file_path)

        return files

    def _extract_stats(self) -> dict[str, Any]:
        """Extract statistics from diff."""
        stats = {
            "files_changed": len(self.files),
            "additions": 0,
            "deletions": 0,
            "renames": 0,
            "new_files": 0,
        }

        lines = self.raw_diff.split("\n")
        for line in lines:
            if line.startswith("+++ b/") and line.endswith("(new file)"):
                stats["new_files"] += 1
            elif line.startswith("rename"):
                stats["renames"] += 1
            elif line.startswith("+") and not line.startswith("+++"):
                stats["additions"] += 1
            elif line.startswith("-") and not line.startswith("---"):
                stats["deletions"] += 1

        return stats

    def is_empty(self) -> bool:
        """Check if diff is empty."""
        return not self.raw_diff.strip()

    def chunk_by_files(self, max_size: int = MAX_CHUNK_SIZE) -> list[DiffChunk]:
        """Chunk diff by files to stay within size limits."""
        if self.is_empty():
            return []

        chunks = []
        current_chunk = ""
        current_files: list[str] = []

        # split diff by file boundaries
        file_sections = re.split(r"(?=diff --git)", self.raw_diff)

        for section in file_sections:
            if not section.strip():
                continue

            # extract file path from this section
            file_match = re.search(r"diff --git a/(.+) b/(.+)", section)
            if file_match:
                file_path = file_match.group(2)

                # check if adding this file would exceed the chunk size
                if len(current_chunk + section) > max_size and current_chunk:
                    # create chunk with current content
                    chunks.append(DiffChunk(current_chunk.strip(), current_files))
                    current_chunk = section
                    current_files = [file_path]
                else:
                    current_chunk += section
                    current_files.append(file_path)

        # add remaining content as final chunk
        if current_chunk.strip():
            chunks.append(DiffChunk(current_chunk.strip(), current_files))

        return chunks


def get_staged_diff() -> StagedDiff:
    """Get staged diff from git."""
    try:
        raw_diff = run(["git", "diff", "--staged", "-U0"])
        return StagedDiff(raw_diff)
    except SystemExit:
        # return empty diff if command fails
        return StagedDiff("")


def get_diff_between_refs(since_ref: str, to_ref: str = "HEAD") -> str:
    """Get diff between two git references."""
    try:
        # a commit-log listing is not a diff; ask git for the actual diff
        return run(["git", "diff", f"{since_ref}..{to_ref}"])
    except SystemExit:
        return ""


def get_commit_history(since_ref: str, to_ref: str = "HEAD") -> list[dict[str, str]]:
    """Get commit history between two references."""
    try:
        log_output = run(
            [
                "git",
                "log",
                "--pretty=format:%H%n%s%n%b---END---",
                f"{since_ref}..{to_ref}",
            ]
        )

        commits = []
        sections = log_output.split("---END---")

        for section in sections:
            lines = section.strip().split("\n")
            if len(lines) >= 2:
                commit_hash = lines[0]
                subject = lines[1]
                body = "\n".join(lines[2:]) if len(lines) > 2 else ""

                commits.append(
                    {
                        "hash": commit_hash,
                        "subject": subject,
                        "body": body,
                    }
                )

        return commits
    except SystemExit:
        return []
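A short usage sketch for the classes above, fed a hand-written unified diff so it runs without a git repository; the file names and contents are made up for illustration.

from gitai.diff import StagedDiff

raw = (
    "diff --git a/app.py b/app.py\n"
    "+++ b/app.py\n"
    "+print('hello')\n"
    "diff --git a/README.md b/README.md\n"
    "+++ b/README.md\n"
    "+# Title\n"
)
diff = StagedDiff(raw)
print(diff.files)                # ['app.py', 'README.md']
print(diff.stats["additions"])   # 2: '+' lines, excluding '+++' headers
for chunk in diff.chunk_by_files(max_size=40):
    print(chunk)                 # two chunks here, one file each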
gitai/hook.py
ADDED
@@ -0,0 +1,81 @@
"""Git hook installation functionality."""

from .constants import HOOK_CONTENT
from .util import (
    backup_file,
    find_git_root,
    make_executable,
    print_success,
    print_warning,
)


def install_commit_msg_hook(force: bool = False) -> None:
    """Install the commit-msg Git hook."""
    git_root = find_git_root()
    hooks_dir = git_root / ".git" / "hooks"

    if not hooks_dir.exists():
        print_warning(".git/hooks directory not found. Is this a valid git repository?")
        return

    hook_path = hooks_dir / "commit-msg"

    # check if hook already exists
    if hook_path.exists() and not force:
        content = hook_path.read_text().strip()
        if content == HOOK_CONTENT.strip():
            print_success("Git-AI commit-msg hook is already installed")
            return

        print_warning("Existing commit-msg hook found:")
        print_warning(str(hook_path))
        print_warning("Use --force to overwrite it, or manually merge the content.")

        # show current content
        print_warning("Current hook content:")
        print_warning(content)
        print_warning("---")
        print_warning("Git-AI hook content:")
        print_warning(HOOK_CONTENT.strip())

        # create a backup of the existing hook
        backup = backup_file(hook_path)
        if backup:
            print_success(f"Created backup: {backup}")

        return

    # install the hook
    try:
        hook_path.write_text(HOOK_CONTENT)
        make_executable(hook_path)
        print_success("Installed Git-AI commit-msg hook")
        print_success("The hook will generate commit messages for your commits.")
        print_success("To remove it, delete or modify .git/hooks/commit-msg")

    except Exception as e:
        print_warning(f"Failed to install hook: {e}")


def uninstall_commit_msg_hook() -> None:
    """Remove the Git-AI commit-msg hook."""
    git_root = find_git_root()
    hook_path = git_root / ".git" / "hooks" / "commit-msg"

    if not hook_path.exists():
        print_warning("Git-AI commit-msg hook not found")
        return

    # check if it's our hook
    content = hook_path.read_text().strip()
    if content != HOOK_CONTENT.strip():
        print_warning("Existing hook doesn't match Git-AI hook content")
        print_warning("Refusing to remove potentially modified hook")
        return

    try:
        hook_path.unlink()
        print_success("Removed Git-AI commit-msg hook")
    except Exception as e:
        print_warning(f"Failed to remove hook: {e}")
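Because HOOK_CONTENT ends in `|| exit 0`, the installed hook never blocks a commit even when git-ai fails. A brief usage sketch of the two entry points above, run from Python inside a repository:

# run from the root of a git repository; both helpers locate .git/hooks via
# gitai.util.find_git_root and print their outcome instead of raising
from gitai.hook import install_commit_msg_hook, uninstall_commit_msg_hook

install_commit_msg_hook()            # refuses to clobber a foreign hook
install_commit_msg_hook(force=True)  # overwrites whatever hook is present
uninstall_commit_msg_hook()          # removes only a byte-identical Git-AI hook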
gitai/providers/__init__.py
ADDED
@@ -0,0 +1 @@
"""LLM providers for GitAI."""
gitai/providers/base.py
ADDED
@@ -0,0 +1,71 @@
"""Base protocol for LLM providers."""

from abc import ABC, abstractmethod
from typing import Any, Protocol

from ..config import Config


class LLMProvider(Protocol):
    """Protocol for LLM providers."""

    @abstractmethod
    def generate(
        self,
        system: str,
        user: str,
        *,
        max_tokens: int | None = None,
        temperature: float = 0.0,
        timeout: int = 60,
    ) -> str:
        """Generate text using the LLM provider.

        Args:
            system: System prompt
            user: User prompt
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            timeout: Request timeout in seconds

        Returns:
            Generated text

        Raises:
            Exception: If generation fails
        """


class BaseProvider(ABC):
    """Base class for LLM providers with common functionality."""

    def __init__(self, timeout: int = 60, config: Config | None = None):
        # load config lazily; a Config.load() default would run at import time
        if config is None:
            config = Config.load()
        self.timeout = timeout
        self.debug_mode = config.debug_settings.debug_mode

    @abstractmethod
    def generate(
        self,
        system: str,
        user: str,
        *,
        max_tokens: int | None = None,
        temperature: float = 0.0,
        timeout: int = 60,
    ) -> str:
        """Generate text using the LLM provider."""
        pass


def create_provider(provider_name: str, **kwargs: Any) -> LLMProvider:
    """Factory function to create LLM providers."""
    if provider_name == "openai":
        from .openai_provider import OpenAIProvider

        return OpenAIProvider(**kwargs)
    elif provider_name == "ollama":
        from .ollama_provider import OllamaProvider

        return OllamaProvider(**kwargs)
    else:
        raise ValueError(f"Unknown provider: {provider_name}")
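A sketch of the factory above in use; the keyword arguments are forwarded verbatim to the selected provider's __init__, so they must match its signature. This assumes a local Ollama server is reachable on the default port.

from gitai.providers.base import create_provider

provider = create_provider("ollama", model="qwen2.5-coder:3b", timeout=30)
message = provider.generate(
    "You are an expert release engineer.",
    "Generate a Conventional Commit message from this diff: ...",
    max_tokens=256,
)
print(message)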
gitai/providers/ollama_provider.py
ADDED
@@ -0,0 +1,86 @@
"""Ollama local LLM provider."""

import json

import requests
from rich.console import Console

from .base import BaseProvider

console = Console()


class OllamaProvider(BaseProvider):
    """Ollama provider for local LLM models."""

    def __init__(
        self,
        base_url: str = "http://localhost:11434",
        model: str = "qwen2.5-coder:3b",
        timeout: int = 60,
    ):
        super().__init__(timeout)
        self.base_url = base_url.rstrip("/")
        self.model = model

    def generate(
        self,
        system: str,
        user: str,
        *,
        max_tokens: int | None = None,
        temperature: float = 0.0,
        timeout: int = 60,
    ) -> str:
        """Generate text using Ollama API."""
        # combine system and user prompts
        full_prompt = f"{system}\n\n{user}"

        payload = {
            "model": self.model,
            "prompt": full_prompt,
            "stream": False,
            "options": {
                "temperature": temperature,
                "num_predict": max_tokens,
            },
        }

        if self.debug_mode:
            print(
                f"Ollama API Call - Model: {self.model}, URL: {self.base_url}/api/generate"
            )
            print(f"Payload: {payload}")
            print("-" * 50)

        with console.status(
            f"[bold green]Generating with Ollama... using {self.model}", spinner="dots"
        ):
            try:
                response = requests.post(
                    f"{self.base_url}/api/generate",
                    json=payload,
                    timeout=timeout,
                )
                response.raise_for_status()

                result = response.json()
                content = result.get("response", "").strip()

                if not content:
                    raise ValueError("Empty response from Ollama API")
                content_str: str = str(content)

                if self.debug_mode:
                    print(f"Ollama Response received ({len(content_str)} chars)")
                    print(f"Response: {content_str}")
                    print("-" * 50)

                return content_str

            except requests.exceptions.RequestException as e:
                print(f"Ollama API Error: {e}")
                raise RuntimeError(f"Ollama API error: {e}") from e
            except json.JSONDecodeError as e:
                print(f"Ollama JSON Error: {e}")
                raise RuntimeError(f"Invalid JSON response from Ollama: {e}") from e
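The provider boils down to a single POST against Ollama's /api/generate endpoint with streaming disabled. The same exchange made directly with requests looks like this, assuming a default local server and an already-pulled model:

import requests

resp = requests.post(
    "http://localhost:11434/api/generate",
    json={
        "model": "qwen2.5-coder:3b",
        "prompt": "Say hello in one word.",
        "stream": False,
        "options": {"temperature": 0.0, "num_predict": 16},
    },
    timeout=60,
)
resp.raise_for_status()
# non-streaming responses carry the full completion in the "response" field
print(resp.json().get("response", "").strip())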
gitai/providers/openai_provider.py
ADDED
@@ -0,0 +1,78 @@
"""OpenAI-compatible LLM provider."""

import os

from openai import OpenAI
from rich.console import Console

from .base import BaseProvider

console = Console()


class OpenAIProvider(BaseProvider):
    """OpenAI-compatible LLM provider using the OpenAI Python client."""

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        model: str = "gpt-4o-mini",
        timeout: int = 60,
    ):
        super().__init__(timeout)
        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
        self.base_url = base_url
        self.model = model
        if not self.api_key:
            raise ValueError("OPENAI_API_KEY environment variable is required")
        self.client = OpenAI(
            api_key=self.api_key,
            base_url=self.base_url,
            timeout=timeout,
        )

    def generate(
        self,
        system: str,
        user: str,
        *,
        max_tokens: int | None = None,
        temperature: float = 0.0,
        timeout: int = 60,
    ) -> str:
        """Generate text using OpenAI API."""
        if self.debug_mode:
            print(
                f"OpenAI API Call - Model: {self.model}, Max tokens: {max_tokens}, Temp: {temperature}"
            )

        # create and start spinner
        with console.status("[bold blue]Generating with OpenAI...", spinner="dots"):
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {"role": "system", "content": system},
                        {"role": "user", "content": user},
                    ],
                    max_tokens=max_tokens,
                    temperature=temperature,
                    timeout=timeout,
                )

                content = response.choices[0].message.content
                if content is None:
                    raise ValueError("Empty response from OpenAI API")
                content_str: str = str(content)

                if self.debug_mode:
                    print(f"OpenAI Response received ({len(content_str)} chars)")
                    print(f"Response: {content_str}")
                    print("-" * 50)

                return content_str.strip()

            except Exception as e:
                print(f"OpenAI API Error: {e}")
                raise RuntimeError(f"OpenAI API error: {e}") from e
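A usage sketch for the provider above; it requires a valid OPENAI_API_KEY in the environment (the constructor raises otherwise), and base_url may point at any OpenAI-compatible server, falling back to the client default when None:

import os

from gitai.providers.openai_provider import OpenAIProvider

# assumes the key is already exported; do not hard-code real credentials
assert os.getenv("OPENAI_API_KEY"), "export OPENAI_API_KEY first"
provider = OpenAIProvider(model="gpt-4o-mini", timeout=30)
print(provider.generate("Output plain text only.", "Say hello.", max_tokens=16))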