yuho 5.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- yuho/__init__.py +16 -0
- yuho/ast/__init__.py +196 -0
- yuho/ast/builder.py +926 -0
- yuho/ast/constant_folder.py +280 -0
- yuho/ast/dead_code.py +199 -0
- yuho/ast/exhaustiveness.py +503 -0
- yuho/ast/nodes.py +907 -0
- yuho/ast/overlap.py +291 -0
- yuho/ast/reachability.py +293 -0
- yuho/ast/scope_analysis.py +490 -0
- yuho/ast/transformer.py +490 -0
- yuho/ast/type_check.py +471 -0
- yuho/ast/type_inference.py +425 -0
- yuho/ast/visitor.py +239 -0
- yuho/cli/__init__.py +14 -0
- yuho/cli/commands/__init__.py +1 -0
- yuho/cli/commands/api.py +431 -0
- yuho/cli/commands/ast_viz.py +334 -0
- yuho/cli/commands/check.py +218 -0
- yuho/cli/commands/config.py +311 -0
- yuho/cli/commands/contribute.py +122 -0
- yuho/cli/commands/diff.py +487 -0
- yuho/cli/commands/explain.py +240 -0
- yuho/cli/commands/fmt.py +253 -0
- yuho/cli/commands/generate.py +316 -0
- yuho/cli/commands/graph.py +410 -0
- yuho/cli/commands/init.py +120 -0
- yuho/cli/commands/library.py +656 -0
- yuho/cli/commands/lint.py +503 -0
- yuho/cli/commands/lsp.py +36 -0
- yuho/cli/commands/preview.py +377 -0
- yuho/cli/commands/repl.py +444 -0
- yuho/cli/commands/serve.py +44 -0
- yuho/cli/commands/test.py +528 -0
- yuho/cli/commands/transpile.py +121 -0
- yuho/cli/commands/wizard.py +370 -0
- yuho/cli/completions.py +182 -0
- yuho/cli/error_formatter.py +193 -0
- yuho/cli/main.py +1064 -0
- yuho/config/__init__.py +46 -0
- yuho/config/loader.py +235 -0
- yuho/config/mask.py +194 -0
- yuho/config/schema.py +147 -0
- yuho/library/__init__.py +84 -0
- yuho/library/index.py +328 -0
- yuho/library/install.py +699 -0
- yuho/library/lockfile.py +330 -0
- yuho/library/package.py +421 -0
- yuho/library/resolver.py +791 -0
- yuho/library/signature.py +335 -0
- yuho/llm/__init__.py +45 -0
- yuho/llm/config.py +75 -0
- yuho/llm/factory.py +123 -0
- yuho/llm/prompts.py +146 -0
- yuho/llm/providers.py +383 -0
- yuho/llm/utils.py +470 -0
- yuho/lsp/__init__.py +14 -0
- yuho/lsp/code_action_handler.py +518 -0
- yuho/lsp/completion_handler.py +85 -0
- yuho/lsp/diagnostics.py +100 -0
- yuho/lsp/hover_handler.py +130 -0
- yuho/lsp/server.py +1425 -0
- yuho/mcp/__init__.py +10 -0
- yuho/mcp/server.py +1452 -0
- yuho/parser/__init__.py +8 -0
- yuho/parser/source_location.py +108 -0
- yuho/parser/wrapper.py +311 -0
- yuho/testing/__init__.py +48 -0
- yuho/testing/coverage.py +274 -0
- yuho/testing/fixtures.py +263 -0
- yuho/transpile/__init__.py +52 -0
- yuho/transpile/alloy_transpiler.py +546 -0
- yuho/transpile/base.py +100 -0
- yuho/transpile/blocks_transpiler.py +338 -0
- yuho/transpile/english_transpiler.py +470 -0
- yuho/transpile/graphql_transpiler.py +404 -0
- yuho/transpile/json_transpiler.py +217 -0
- yuho/transpile/jsonld_transpiler.py +250 -0
- yuho/transpile/latex_preamble.py +161 -0
- yuho/transpile/latex_transpiler.py +406 -0
- yuho/transpile/latex_utils.py +206 -0
- yuho/transpile/mermaid_transpiler.py +357 -0
- yuho/transpile/registry.py +275 -0
- yuho/verify/__init__.py +43 -0
- yuho/verify/alloy.py +352 -0
- yuho/verify/combined.py +218 -0
- yuho/verify/z3_solver.py +1155 -0
- yuho-5.0.0.dist-info/METADATA +186 -0
- yuho-5.0.0.dist-info/RECORD +91 -0
- yuho-5.0.0.dist-info/WHEEL +4 -0
- yuho-5.0.0.dist-info/entry_points.txt +2 -0
yuho/llm/prompts.py
ADDED
@@ -0,0 +1,146 @@
+"""
+Prompt templates for LLM interactions.
+"""
+
+STATUTE_EXPLANATION_PROMPT = """You are a legal expert explaining statute provisions to a general audience.
+
+Your task is to explain the following legal statute in clear, accessible language that a non-lawyer could understand. Keep the legal accuracy but make it readable.
+
+Guidelines:
+1. Start with a one-sentence summary of what the statute covers
+2. Explain each element of the offense in plain language
+3. Describe the penalties in concrete terms
+4. Provide a simple example to illustrate (if applicable)
+5. Note any important exceptions or defenses
+
+Yuho Statute Code:
+```
+{yuho_code}
+```
+
+Structured Analysis:
+{english_explanation}
+
+Please provide a clear, accessible explanation:
+"""
+
+STATUTE_TO_YUHO_PROMPT = """You are an expert in the Yuho legal DSL (domain-specific language). Your task is to convert natural language statute text into valid Yuho code.
+
+## Yuho Grammar Summary
+
+A Yuho statute has this structure:
+
+```
+statute SECTION_NUMBER "Title" {
+    definitions {
+        term := "definition";
+    }
+
+    elements {
+        actus_reus element_name := "physical act description";
+        mens_rea element_name := "mental state description";
+        circumstance element_name := "circumstantial requirements";
+    }
+
+    penalty {
+        imprisonment := DURATION;
+        fine := MONEY_AMOUNT;
+        supplementary := "additional punishment";
+    }
+
+    illustration LABEL {
+        "Example scenario"
+    }
+}
+```
+
+## Types
+- int: integers (e.g., 42)
+- float: decimals (e.g., 3.14)
+- bool: TRUE or FALSE
+- string: "text in quotes"
+- money: $100.00 or SGD1000
+- percent: 50%
+- date: 2024-01-15 (ISO format)
+- duration: 3 years, 6 months, 10 days
+
+## Match Expressions
+For conditional elements:
+```
+match (condition) {
+    case pattern := consequence result;
+    case _ := consequence default_result;
+}
+```
+
+## Example Statute
+
+```
+statute 415 "Cheating" {
+    definitions {
+        deceive := "to induce a person to believe as true something which is false";
+        dishonestly := "with intention to cause wrongful gain or loss";
+    }
+
+    elements {
+        actus_reus deception := "deceives any person";
+        actus_reus inducement := "fraudulently or dishonestly induces the person deceived to deliver any property";
+        mens_rea intention := "intention to cause wrongful gain or loss";
+    }
+
+    penalty {
+        imprisonment := 0 years .. 10 years;
+        fine := $0 .. $50,000;
+    }
+}
+```
+
+## Natural Language Statute to Convert
+
+{statute_text}
+
+## Instructions
+1. Identify the section number and title
+2. Extract key definitions
+3. Identify actus reus (physical) elements
+4. Identify mens rea (mental) elements
+5. Parse the penalty provisions
+6. Generate valid Yuho code
+
+Please output only the Yuho code, no explanations:
+"""
+
+ANALYZE_COVERAGE_PROMPT = """You are a legal analyst reviewing statute coverage.
+
+Given the following Yuho statutes, analyze:
+1. What offenses are covered?
+2. Are there any gaps in coverage (common offenses not addressed)?
+3. Are the penalty ranges consistent across similar offenses?
+4. Any overlapping definitions that might cause ambiguity?
+
+Statutes:
+{statutes_json}
+
+Please provide your analysis:
+"""
+
+COMPARE_STATUTES_PROMPT = """You are a legal comparator.
+
+Compare the following two statute representations:
+
+Original Natural Language:
+{original_text}
+
+Yuho Code:
+{yuho_code}
+
+English Transpilation:
+{english_output}
+
+Analyze:
+1. Does the Yuho code accurately capture all elements?
+2. Are there any discrepancies between the original and the transpilation?
+3. Are any elements missing or misrepresented?
+
+Please provide your comparison:
+"""
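These templates use str.format-style placeholders ({yuho_code}, {english_explanation}, {statute_text}, {statutes_json}, and so on). A minimal sketch of rendering one of them before handing it to a provider is below; the sample values are hypothetical and not part of the package, and STATUTE_EXPLANATION_PROMPT is used because it contains no literal braces, so a plain str.format call is safe for it.

from yuho.llm.prompts import STATUTE_EXPLANATION_PROMPT

# Hypothetical inputs; in the package these would presumably come from the
# parser and the English transpiler rather than hard-coded strings.
yuho_code = 'statute 415 "Cheating" { ... }'
english_explanation = "Element-by-element analysis of the statute."

# Fill the placeholders and inspect the rendered prompt.
prompt = STATUTE_EXPLANATION_PROMPT.format(
    yuho_code=yuho_code,
    english_explanation=english_explanation,
)
print(prompt.splitlines()[0])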
yuho/llm/providers.py
ADDED
@@ -0,0 +1,383 @@
+"""
+LLM provider implementations.
+"""
+
+from abc import ABC, abstractmethod
+from typing import Optional, Iterator
+import logging
+
+from yuho.llm.config import LLMConfig
+
+logger = logging.getLogger(__name__)
+
+
+class LLMProvider(ABC):
+    """
+    Abstract base class for LLM providers.
+
+    All providers must implement the generate() method.
+    """
+
+    def __init__(self, config: LLMConfig):
+        self.config = config
+
+    @abstractmethod
+    def generate(self, prompt: str, max_tokens: Optional[int] = None) -> str:
+        """
+        Generate text from a prompt.
+
+        Args:
+            prompt: The input prompt
+            max_tokens: Maximum tokens to generate (uses config default if None)
+
+        Returns:
+            Generated text
+        """
+        pass
+
+    def stream(self, prompt: str, max_tokens: Optional[int] = None) -> Iterator[str]:
+        """
+        Stream generated text token by token.
+
+        Default implementation just yields the full response.
+        Override for true streaming support.
+        """
+        yield self.generate(prompt, max_tokens)
+
+    @abstractmethod
+    def is_available(self) -> bool:
+        """Check if the provider is available and configured."""
+        pass
+
+
+class OllamaProvider(LLMProvider):
+    """
+    Ollama provider for local LLM inference.
+
+    Uses the Ollama HTTP API at configurable host:port.
+    """
+
+    def __init__(self, config: LLMConfig):
+        super().__init__(config)
+        self._client = None
+
+    def _get_client(self):
+        """Lazy-initialize HTTP client."""
+        if self._client is None:
+            try:
+                import httpx
+                self._client = httpx.Client(
+                    base_url=self.config.ollama_url,
+                    timeout=120.0,
+                )
+            except ImportError:
+                raise ImportError("httpx required for Ollama. Install with: pip install httpx")
+        return self._client
+
+    def generate(self, prompt: str, max_tokens: Optional[int] = None) -> str:
+        """Generate text using Ollama API."""
+        client = self._get_client()
+
+        max_tokens = max_tokens or self.config.max_tokens
+
+        try:
+            response = client.post(
+                "/api/generate",
+                json={
+                    "model": self.config.model_name,
+                    "prompt": prompt,
+                    "stream": False,
+                    "options": {
+                        "num_predict": max_tokens,
+                        "temperature": self.config.temperature,
+                    },
+                },
+            )
+            response.raise_for_status()
+            data = response.json()
+            return data.get("response", "")
+
+        except Exception as e:
+            logger.error(f"Ollama generation error: {e}")
+            raise RuntimeError(f"Ollama error: {e}") from e
+
+    def stream(self, prompt: str, max_tokens: Optional[int] = None) -> Iterator[str]:
+        """Stream generated text from Ollama."""
+        client = self._get_client()
+
+        max_tokens = max_tokens or self.config.max_tokens
+
+        try:
+            with client.stream(
+                "POST",
+                "/api/generate",
+                json={
+                    "model": self.config.model_name,
+                    "prompt": prompt,
+                    "stream": True,
+                    "options": {
+                        "num_predict": max_tokens,
+                        "temperature": self.config.temperature,
+                    },
+                },
+            ) as response:
+                response.raise_for_status()
+                import json as json_module
+                for line in response.iter_lines():
+                    if line:
+                        data = json_module.loads(line)
+                        if "response" in data:
+                            yield data["response"]
+
+        except Exception as e:
+            logger.error(f"Ollama streaming error: {e}")
+            raise RuntimeError(f"Ollama error: {e}") from e
+
+    def is_available(self) -> bool:
+        """Check if Ollama server is available."""
+        try:
+            client = self._get_client()
+            response = client.get("/api/tags")
+            return response.status_code == 200
+        except Exception:
+            return False
+
+    def check_model(self) -> bool:
+        """Check if the configured model is available."""
+        try:
+            client = self._get_client()
+            response = client.get("/api/tags")
+            if response.status_code != 200:
+                return False
+
+            data = response.json()
+            models = [m.get("name", "").split(":")[0] for m in data.get("models", [])]
+            return self.config.model_name.split(":")[0] in models
+
+        except Exception:
+            return False
+
+    def pull_model(self) -> bool:
+        """Pull the configured model if missing."""
+        try:
+            client = self._get_client()
+            logger.info(f"Pulling Ollama model: {self.config.model_name}")
+
+            response = client.post(
+                "/api/pull",
+                json={"name": self.config.model_name},
+                timeout=600.0,  # 10 minute timeout for model download
+            )
+            return response.status_code == 200
+
+        except Exception as e:
+            logger.error(f"Failed to pull model: {e}")
+            return False
+
+
+class HuggingFaceProvider(LLMProvider):
+    """
+    HuggingFace Transformers provider for local inference.
+
+    Supports automatic model download and device selection.
+    """
+
+    def __init__(self, config: LLMConfig):
+        super().__init__(config)
+        self._model = None
+        self._tokenizer = None
+        self._device = None
+
+    def _load_model(self):
+        """Lazy-load the model and tokenizer."""
+        if self._model is not None:
+            return
+
+        try:
+            import torch
+            from transformers import AutoModelForCausalLM, AutoTokenizer
+
+            # Device selection
+            if torch.cuda.is_available():
+                self._device = "cuda"
+            elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+                self._device = "mps"
+            else:
+                self._device = "cpu"
+
+            logger.info(f"Loading HuggingFace model {self.config.model_name} on {self._device}")
+
+            # Set cache directory if configured
+            cache_dir = self.config.huggingface_cache
+
+            self._tokenizer = AutoTokenizer.from_pretrained(
+                self.config.model_name,
+                cache_dir=cache_dir,
+            )
+
+            self._model = AutoModelForCausalLM.from_pretrained(
+                self.config.model_name,
+                cache_dir=cache_dir,
+                torch_dtype=torch.float16 if self._device != "cpu" else torch.float32,
+                device_map="auto" if self._device == "cuda" else None,
+            )
+
+            if self._device != "cuda":
+                self._model = self._model.to(self._device)
+
+        except ImportError:
+            raise ImportError(
+                "transformers and torch required for HuggingFace. "
+                "Install with: pip install transformers torch"
+            )
+
+    def generate(self, prompt: str, max_tokens: Optional[int] = None) -> str:
+        """Generate text using HuggingFace Transformers."""
+        self._load_model()
+
+        max_tokens = max_tokens or self.config.max_tokens
+
+        try:
+            import torch
+
+            inputs = self._tokenizer(prompt, return_tensors="pt").to(self._device)
+
+            with torch.no_grad():
+                outputs = self._model.generate(
+                    **inputs,
+                    max_new_tokens=max_tokens,
+                    temperature=self.config.temperature,
+                    do_sample=self.config.temperature > 0,
+                    pad_token_id=self._tokenizer.eos_token_id,
+                )
+
+            # Decode only the new tokens
+            generated = self._tokenizer.decode(
+                outputs[0][inputs["input_ids"].shape[1]:],
+                skip_special_tokens=True,
+            )
+
+            return generated
+
+        except Exception as e:
+            logger.error(f"HuggingFace generation error: {e}")
+            raise RuntimeError(f"HuggingFace error: {e}") from e
+
+    def is_available(self) -> bool:
+        """Check if transformers and torch are available."""
+        try:
+            import torch
+            from transformers import AutoModelForCausalLM
+            return True
+        except ImportError:
+            return False
+
+    def download_model(self) -> bool:
+        """Download and cache the model."""
+        try:
+            from huggingface_hub import snapshot_download
+
+            cache_dir = self.config.huggingface_cache
+            logger.info(f"Downloading HuggingFace model: {self.config.model_name}")
+
+            snapshot_download(
+                self.config.model_name,
+                cache_dir=cache_dir,
+            )
+            return True
+
+        except Exception as e:
+            logger.error(f"Failed to download model: {e}")
+            return False
+
+
+class OpenAIProvider(LLMProvider):
+    """
+    OpenAI API provider.
+
+    Also works with OpenAI-compatible APIs via base_url.
+    """
+
+    def __init__(self, config: LLMConfig):
+        super().__init__(config)
+        self._client = None
+
+    def _get_client(self):
+        """Lazy-initialize OpenAI client."""
+        if self._client is None:
+            try:
+                from openai import OpenAI
+
+                self._client = OpenAI(
+                    api_key=self.config.api_key,
+                    base_url=self.config.base_url,
+                )
+            except ImportError:
+                raise ImportError("openai required. Install with: pip install openai")
+        return self._client
+
+    def generate(self, prompt: str, max_tokens: Optional[int] = None) -> str:
+        """Generate text using OpenAI API."""
+        client = self._get_client()
+
+        max_tokens = max_tokens or self.config.max_tokens
+
+        try:
+            response = client.chat.completions.create(
+                model=self.config.model_name,
+                messages=[{"role": "user", "content": prompt}],
+                max_tokens=max_tokens,
+                temperature=self.config.temperature,
+            )
+            return response.choices[0].message.content or ""
+
+        except Exception as e:
+            logger.error(f"OpenAI generation error: {e}")
+            raise RuntimeError(f"OpenAI error: {e}") from e
+
+    def is_available(self) -> bool:
+        """Check if API key is configured."""
+        return bool(self.config.api_key)
+
+
+class AnthropicProvider(LLMProvider):
+    """
+    Anthropic API provider.
+    """
+
+    def __init__(self, config: LLMConfig):
+        super().__init__(config)
+        self._client = None
+
+    def _get_client(self):
+        """Lazy-initialize Anthropic client."""
+        if self._client is None:
+            try:
+                from anthropic import Anthropic
+
+                self._client = Anthropic(api_key=self.config.api_key)
+            except ImportError:
+                raise ImportError("anthropic required. Install with: pip install anthropic")
+        return self._client
+
+    def generate(self, prompt: str, max_tokens: Optional[int] = None) -> str:
+        """Generate text using Anthropic API."""
+        client = self._get_client()
+
+        max_tokens = max_tokens or self.config.max_tokens
+
+        try:
+            response = client.messages.create(
+                model=self.config.model_name,
+                max_tokens=max_tokens,
+                messages=[{"role": "user", "content": prompt}],
+            )
+            return response.content[0].text
+
+        except Exception as e:
+            logger.error(f"Anthropic generation error: {e}")
+            raise RuntimeError(f"Anthropic error: {e}") from e
+
+    def is_available(self) -> bool:
+        """Check if API key is configured."""
+        return bool(self.config.api_key)