eval-ai-library: eval_ai_library-0.2.2-py3-none-any.whl → eval_ai_library-0.3.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of eval-ai-library might be problematic. Click here for more details.
- eval_ai_library-0.3.0.dist-info/METADATA +1042 -0
- eval_ai_library-0.3.0.dist-info/RECORD +34 -0
- eval_lib/__init__.py +19 -6
- eval_lib/agent_metrics/knowledge_retention_metric/knowledge_retention.py +8 -3
- eval_lib/agent_metrics/role_adherence_metric/role_adherence.py +12 -4
- eval_lib/agent_metrics/task_success_metric/task_success_rate.py +23 -23
- eval_lib/agent_metrics/tools_correctness_metric/tool_correctness.py +8 -2
- eval_lib/datagenerator/datagenerator.py +208 -12
- eval_lib/datagenerator/document_loader.py +29 -29
- eval_lib/evaluate.py +0 -22
- eval_lib/llm_client.py +223 -78
- eval_lib/metric_pattern.py +208 -152
- eval_lib/metrics/answer_precision_metric/answer_precision.py +8 -3
- eval_lib/metrics/answer_relevancy_metric/answer_relevancy.py +7 -2
- eval_lib/metrics/bias_metric/bias.py +12 -2
- eval_lib/metrics/contextual_precision_metric/contextual_precision.py +9 -4
- eval_lib/metrics/contextual_recall_metric/contextual_recall.py +7 -3
- eval_lib/metrics/contextual_relevancy_metric/contextual_relevancy.py +8 -2
- eval_lib/metrics/custom_metric/custom_eval.py +237 -204
- eval_lib/metrics/faithfulness_metric/faithfulness.py +7 -2
- eval_lib/metrics/geval/geval.py +8 -2
- eval_lib/metrics/restricted_refusal_metric/restricted_refusal.py +7 -3
- eval_lib/metrics/toxicity_metric/toxicity.py +8 -2
- eval_lib/utils.py +44 -29
- eval_ai_library-0.2.2.dist-info/METADATA +0 -779
- eval_ai_library-0.2.2.dist-info/RECORD +0 -34
- {eval_ai_library-0.2.2.dist-info → eval_ai_library-0.3.0.dist-info}/WHEEL +0 -0
- {eval_ai_library-0.2.2.dist-info → eval_ai_library-0.3.0.dist-info}/licenses/LICENSE +0 -0
- {eval_ai_library-0.2.2.dist-info → eval_ai_library-0.3.0.dist-info}/top_level.txt +0 -0
|
@@ -2,22 +2,41 @@
|
|
|
2
2
|
from __future__ import annotations
|
|
3
3
|
from pathlib import Path
|
|
4
4
|
from typing import List
|
|
5
|
-
|
|
6
5
|
from langchain_core.documents import Document
|
|
7
6
|
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
|
8
|
-
|
|
9
|
-
# LangChain loaders (оставляем существующие)
|
|
10
7
|
from langchain_community.document_loaders import PyPDFLoader
|
|
11
8
|
from langchain_community.document_loaders import Docx2txtLoader
|
|
12
9
|
from langchain_community.document_loaders import TextLoader
|
|
13
10
|
|
|
14
11
|
import html2text
|
|
15
12
|
import markdown
|
|
16
|
-
|
|
13
|
+
from pptx import Presentation
|
|
14
|
+
from striprtf.striprtf import rtf_to_text
|
|
15
|
+
from pypdf import PdfReader
|
|
16
|
+
import fitz
|
|
17
|
+
import zipfile
|
|
18
|
+
from xml.etree import ElementTree as ET
|
|
19
|
+
import pandas as pd
|
|
20
|
+
import yaml
|
|
21
|
+
import pytesseract
|
|
22
|
+
from PIL import Image
|
|
23
|
+
import io as _io
|
|
24
|
+
import pytesseract
|
|
25
|
+
from PIL import Image
|
|
26
|
+
import io as _io
|
|
27
|
+
import docx # python-docx
|
|
28
|
+
import mammoth
|
|
17
29
|
import io
|
|
18
30
|
import json
|
|
19
31
|
import zipfile
|
|
20
32
|
|
|
33
|
+
try:
|
|
34
|
+
import textract
|
|
35
|
+
HAS_TEXTRACT = True
|
|
36
|
+
except ImportError:
|
|
37
|
+
HAS_TEXTRACT = False
|
|
38
|
+
textract = None
|
|
39
|
+
|
|
21
40
|
# ---------------------------
|
|
22
41
|
# Helper functions
|
|
23
42
|
# ---------------------------
|
|
@@ -33,7 +52,7 @@ def _read_bytes(p: Path) -> bytes:
|
|
|
33
52
|
|
|
34
53
|
def _csv_tsv_to_text(p: Path) -> str:
|
|
35
54
|
try:
|
|
36
|
-
|
|
55
|
+
|
|
37
56
|
sep = "," if p.suffix.lower() == ".csv" else "\t"
|
|
38
57
|
df = pd.read_csv(str(p), dtype=str, sep=sep,
|
|
39
58
|
encoding="utf-8", engine="python")
|
|
@@ -50,7 +69,6 @@ def _csv_tsv_to_text(p: Path) -> str:
|
|
|
50
69
|
|
|
51
70
|
def _xlsx_to_text(p: Path) -> str:
|
|
52
71
|
try:
|
|
53
|
-
import pandas as pd
|
|
54
72
|
df = pd.read_excel(str(p), dtype=str, engine="openpyxl")
|
|
55
73
|
df = df.fillna("")
|
|
56
74
|
buf = io.StringIO()
|
|
@@ -62,7 +80,6 @@ def _xlsx_to_text(p: Path) -> str:
|
|
|
62
80
|
|
|
63
81
|
def _pptx_to_text(p: Path) -> str:
|
|
64
82
|
try:
|
|
65
|
-
from pptx import Presentation
|
|
66
83
|
prs = Presentation(str(p))
|
|
67
84
|
texts = []
|
|
68
85
|
for slide in prs.slides:
|
|
@@ -96,7 +113,6 @@ def _json_to_text(p: Path) -> str:
|
|
|
96
113
|
|
|
97
114
|
def _yaml_to_text(p: Path) -> str:
|
|
98
115
|
try:
|
|
99
|
-
import yaml
|
|
100
116
|
data = yaml.safe_load(_read_text(p))
|
|
101
117
|
return json.dumps(data, ensure_ascii=False, indent=2)
|
|
102
118
|
except Exception:
|
|
@@ -105,7 +121,6 @@ def _yaml_to_text(p: Path) -> str:
|
|
|
105
121
|
|
|
106
122
|
def _xml_to_text(p: Path) -> str:
|
|
107
123
|
try:
|
|
108
|
-
from xml.etree import ElementTree as ET
|
|
109
124
|
tree = ET.parse(str(p))
|
|
110
125
|
root = tree.getroot()
|
|
111
126
|
lines = []
|
|
@@ -125,7 +140,7 @@ def _xml_to_text(p: Path) -> str:
|
|
|
125
140
|
|
|
126
141
|
def _rtf_to_text(p: Path) -> str:
|
|
127
142
|
try:
|
|
128
|
-
|
|
143
|
+
|
|
129
144
|
return rtf_to_text(_read_text(p))
|
|
130
145
|
except Exception:
|
|
131
146
|
return ""
|
|
@@ -153,7 +168,6 @@ def _odt_to_text(p: Path) -> str:
|
|
|
153
168
|
|
|
154
169
|
def _pdf_text_pypdf(p: Path) -> str:
|
|
155
170
|
try:
|
|
156
|
-
from pypdf import PdfReader # <- именно pypdf
|
|
157
171
|
reader = PdfReader(str(p))
|
|
158
172
|
texts = []
|
|
159
173
|
for page in reader.pages:
|
|
@@ -167,7 +181,6 @@ def _pdf_text_pypdf(p: Path) -> str:
|
|
|
167
181
|
|
|
168
182
|
def _pdf_text_pymupdf(p: Path) -> str:
|
|
169
183
|
try:
|
|
170
|
-
import fitz # PyMuPDF
|
|
171
184
|
text_parts = []
|
|
172
185
|
with fitz.open(str(p)) as doc:
|
|
173
186
|
for page in doc:
|
|
@@ -182,11 +195,6 @@ def _pdf_text_pymupdf(p: Path) -> str:
|
|
|
182
195
|
def _pdf_ocr_via_pymupdf(p: Path) -> str:
|
|
183
196
|
"""Render pages via PyMuPDF and OCR pytesseract. Will work if pytesseract + tesseract are installed."""
|
|
184
197
|
try:
|
|
185
|
-
import fitz # PyMuPDF
|
|
186
|
-
import pytesseract
|
|
187
|
-
from PIL import Image
|
|
188
|
-
import io as _io
|
|
189
|
-
|
|
190
198
|
texts = []
|
|
191
199
|
zoom = 2.0
|
|
192
200
|
mat = fitz.Matrix(zoom, zoom)
|
|
@@ -208,9 +216,6 @@ def _pdf_ocr_via_pymupdf(p: Path) -> str:
|
|
|
208
216
|
|
|
209
217
|
def _ocr_image_bytes(img_bytes: bytes) -> str:
|
|
210
218
|
try:
|
|
211
|
-
import pytesseract
|
|
212
|
-
from PIL import Image
|
|
213
|
-
import io as _io
|
|
214
219
|
img = Image.open(_io.BytesIO(img_bytes))
|
|
215
220
|
return pytesseract.image_to_string(img) or ""
|
|
216
221
|
except Exception:
|
|
@@ -223,7 +228,6 @@ def _ocr_image_bytes(img_bytes: bytes) -> str:
|
|
|
223
228
|
|
|
224
229
|
def _docx_to_text_python_docx(p: Path) -> str:
|
|
225
230
|
try:
|
|
226
|
-
import docx # python-docx
|
|
227
231
|
d = docx.Document(str(p))
|
|
228
232
|
parts = []
|
|
229
233
|
for para in d.paragraphs:
|
|
@@ -242,7 +246,6 @@ def _docx_to_text_python_docx(p: Path) -> str:
|
|
|
242
246
|
|
|
243
247
|
def _docx_to_text_mammoth(p: Path) -> str:
|
|
244
248
|
try:
|
|
245
|
-
import mammoth
|
|
246
249
|
with open(str(p), "rb") as f:
|
|
247
250
|
result = mammoth.extract_raw_text(f)
|
|
248
251
|
return (result.value or "").strip()
|
|
@@ -253,19 +256,16 @@ def _docx_to_text_mammoth(p: Path) -> str:
|
|
|
253
256
|
def _docx_to_text_zipxml(p: Path) -> str:
|
|
254
257
|
"""Без зависимостей: читаем word/document.xml и вытаскиваем все w:t."""
|
|
255
258
|
try:
|
|
256
|
-
|
|
257
|
-
from xml.etree import ElementTree as ET
|
|
259
|
+
|
|
258
260
|
texts = []
|
|
259
261
|
with zipfile.ZipFile(str(p)) as z:
|
|
260
|
-
# основной документ
|
|
261
262
|
if "word/document.xml" in z.namelist():
|
|
262
263
|
with z.open("word/document.xml") as f:
|
|
263
264
|
root = ET.parse(f).getroot()
|
|
264
265
|
for el in root.iter():
|
|
265
|
-
tag = el.tag.rsplit("}", 1)[-1]
|
|
266
|
+
tag = el.tag.rsplit("}", 1)[-1]
|
|
266
267
|
if tag == "t" and el.text and el.text.strip():
|
|
267
268
|
texts.append(el.text.strip())
|
|
268
|
-
# заголовки/футеры тоже могут содержать текст
|
|
269
269
|
for name in z.namelist():
|
|
270
270
|
if name.startswith("word/header") and name.endswith(".xml"):
|
|
271
271
|
with z.open(name) as f:
|
|
@@ -287,9 +287,9 @@ def _docx_to_text_zipxml(p: Path) -> str:
|
|
|
287
287
|
|
|
288
288
|
|
|
289
289
|
def _doc_to_text_textract(p: Path) -> str:
|
|
290
|
-
|
|
290
|
+
if not HAS_TEXTRACT:
|
|
291
|
+
return ""
|
|
291
292
|
try:
|
|
292
|
-
import textract
|
|
293
293
|
return textract.process(str(p)).decode("utf-8", errors="ignore")
|
|
294
294
|
except Exception:
|
|
295
295
|
return ""
|
eval_lib/evaluate.py
CHANGED
|
@@ -183,18 +183,6 @@ async def evaluate(
|
|
|
183
183
|
_print_summary(results, total_cost, total_time,
|
|
184
184
|
total_passed, total_tests)
|
|
185
185
|
|
|
186
|
-
# Print detailed results if requested
|
|
187
|
-
if verbose:
|
|
188
|
-
print(f"{Colors.BOLD}Detailed Results:{Colors.ENDC}")
|
|
189
|
-
for idx, (meta, tc_list) in enumerate(results, 1):
|
|
190
|
-
print(f"\n{Colors.DIM}{'─'*70}{Colors.ENDC}")
|
|
191
|
-
print(f"{Colors.BOLD}Test Case {idx}:{Colors.ENDC}")
|
|
192
|
-
for tc in tc_list:
|
|
193
|
-
tc_dict = asdict(tc)
|
|
194
|
-
# Pretty print with indentation
|
|
195
|
-
print(json.dumps(tc_dict, indent=2, ensure_ascii=False))
|
|
196
|
-
print(f"{Colors.DIM}{'─'*70}{Colors.ENDC}\n")
|
|
197
|
-
|
|
198
186
|
return results
|
|
199
187
|
|
|
200
188
|
|
|
@@ -322,14 +310,4 @@ async def evaluate_conversations(
|
|
|
322
310
|
_print_summary(results, total_cost, total_time,
|
|
323
311
|
total_passed, total_conversations)
|
|
324
312
|
|
|
325
|
-
# Print detailed results if requested
|
|
326
|
-
if verbose:
|
|
327
|
-
print(f"{Colors.BOLD}Detailed Results:{Colors.ENDC}")
|
|
328
|
-
for idx, (_, conv_list) in enumerate(results, 1):
|
|
329
|
-
print(f"\n{Colors.DIM}{'─'*70}{Colors.ENDC}")
|
|
330
|
-
print(f"{Colors.BOLD}Conversation {idx}:{Colors.ENDC}")
|
|
331
|
-
for conv in conv_list:
|
|
332
|
-
print(json.dumps(asdict(conv), indent=2, ensure_ascii=False))
|
|
333
|
-
print(f"{Colors.DIM}{'─'*70}{Colors.ENDC}\n")
|
|
334
|
-
|
|
335
313
|
return results
|
eval_lib/llm_client.py
CHANGED
|
@@ -13,6 +13,11 @@ from types import SimpleNamespace
|
|
|
13
13
|
from .price import model_pricing
|
|
14
14
|
|
|
15
15
|
|
|
16
|
+
class LLMConfigurationError(Exception):
|
|
17
|
+
"""Raised when LLM client configuration is missing or invalid."""
|
|
18
|
+
pass
|
|
19
|
+
|
|
20
|
+
|
|
16
21
|
class Provider(str, Enum):
|
|
17
22
|
OPENAI = "openai"
|
|
18
23
|
AZURE = "azure"
|
|
@@ -45,12 +50,59 @@ class LLMDescriptor:
|
|
|
45
50
|
return f"{self.provider}:{self.model}"
|
|
46
51
|
|
|
47
52
|
|
|
53
|
+
def _check_env_var(var_name: str, provider: str, required: bool = True) -> Optional[str]:
|
|
54
|
+
"""
|
|
55
|
+
Check if environment variable is set and return its value.
|
|
56
|
+
|
|
57
|
+
Args:
|
|
58
|
+
var_name: Name of the environment variable
|
|
59
|
+
provider: Provider name for error message
|
|
60
|
+
required: Whether this variable is required
|
|
61
|
+
|
|
62
|
+
Returns:
|
|
63
|
+
Value of the environment variable or None if not required
|
|
64
|
+
|
|
65
|
+
Raises:
|
|
66
|
+
LLMConfigurationError: If required variable is missing
|
|
67
|
+
"""
|
|
68
|
+
value = os.getenv(var_name)
|
|
69
|
+
if required and not value:
|
|
70
|
+
raise LLMConfigurationError(
|
|
71
|
+
f"❌ Missing {provider} configuration!\n\n"
|
|
72
|
+
f"Environment variable '{var_name}' is not set.\n\n"
|
|
73
|
+
f"To fix this, set the environment variable:\n"
|
|
74
|
+
f" export {var_name}='your-api-key-here'\n\n"
|
|
75
|
+
f"Or add it to your .env file:\n"
|
|
76
|
+
f" {var_name}=your-api-key-here\n\n"
|
|
77
|
+
f"📖 Documentation: https://github.com/meshkovQA/Eval-ai-library#environment-variables"
|
|
78
|
+
)
|
|
79
|
+
return value
|
|
80
|
+
|
|
81
|
+
|
|
48
82
|
@functools.cache
|
|
49
83
|
def _get_client(provider: Provider):
|
|
84
|
+
"""
|
|
85
|
+
Get or create LLM client for the specified provider.
|
|
86
|
+
|
|
87
|
+
Args:
|
|
88
|
+
provider: LLM provider enum
|
|
89
|
+
|
|
90
|
+
Returns:
|
|
91
|
+
Configured client instance
|
|
92
|
+
|
|
93
|
+
Raises:
|
|
94
|
+
LLMConfigurationError: If required configuration is missing
|
|
95
|
+
ValueError: If provider is not supported
|
|
96
|
+
"""
|
|
50
97
|
if provider == Provider.OPENAI:
|
|
98
|
+
_check_env_var("OPENAI_API_KEY", "OpenAI")
|
|
51
99
|
return openai.AsyncOpenAI()
|
|
52
100
|
|
|
53
101
|
if provider == Provider.AZURE:
|
|
102
|
+
_check_env_var("AZURE_OPENAI_API_KEY", "Azure OpenAI")
|
|
103
|
+
_check_env_var("AZURE_OPENAI_ENDPOINT", "Azure OpenAI")
|
|
104
|
+
# AZURE_OPENAI_DEPLOYMENT проверяется при вызове, не обязателен здесь
|
|
105
|
+
|
|
54
106
|
return AsyncAzureOpenAI(
|
|
55
107
|
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
|
|
56
108
|
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
|
|
@@ -58,20 +110,28 @@ def _get_client(provider: Provider):
|
|
|
58
110
|
)
|
|
59
111
|
|
|
60
112
|
if provider == Provider.GOOGLE:
|
|
113
|
+
_check_env_var("GOOGLE_API_KEY", "Google Gemini")
|
|
61
114
|
return genai.Client(api_key=os.getenv("GOOGLE_API_KEY"))
|
|
62
115
|
|
|
63
116
|
if provider == Provider.OLLAMA:
|
|
117
|
+
# Ollama может работать без ключа (локальный сервер)
|
|
118
|
+
api_key = _check_env_var(
|
|
119
|
+
"OLLAMA_API_KEY", "Ollama", required=False) or "ollama"
|
|
120
|
+
base_url = _check_env_var(
|
|
121
|
+
"OLLAMA_API_BASE_URL", "Ollama", required=False) or "http://localhost:11434/v1"
|
|
122
|
+
|
|
64
123
|
return openai.AsyncOpenAI(
|
|
65
|
-
api_key=
|
|
66
|
-
base_url=
|
|
124
|
+
api_key=api_key,
|
|
125
|
+
base_url=base_url
|
|
67
126
|
)
|
|
68
127
|
|
|
69
128
|
if provider == Provider.ANTHROPIC:
|
|
129
|
+
_check_env_var("ANTHROPIC_API_KEY", "Anthropic Claude")
|
|
70
130
|
return anthropic.AsyncAnthropic(
|
|
71
131
|
api_key=os.getenv("ANTHROPIC_API_KEY"),
|
|
72
132
|
)
|
|
73
133
|
|
|
74
|
-
raise ValueError(f"Unsupported provider {provider}")
|
|
134
|
+
raise ValueError(f"Unsupported provider: {provider}")
|
|
75
135
|
|
|
76
136
|
|
|
77
137
|
async def _openai_chat_complete(
|
|
@@ -80,17 +140,25 @@ async def _openai_chat_complete(
|
|
|
80
140
|
messages: list[dict[str, str]],
|
|
81
141
|
temperature: float,
|
|
82
142
|
):
|
|
83
|
-
"""
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
143
|
+
"""OpenAI chat completion."""
|
|
144
|
+
try:
|
|
145
|
+
response = await client.chat.completions.create(
|
|
146
|
+
model=llm.model,
|
|
147
|
+
messages=messages,
|
|
148
|
+
temperature=temperature,
|
|
149
|
+
)
|
|
150
|
+
text = response.choices[0].message.content.strip()
|
|
151
|
+
cost = _calculate_cost(llm, response.usage)
|
|
152
|
+
return text, cost
|
|
153
|
+
except Exception as e:
|
|
154
|
+
if "API key" in str(e) or "authentication" in str(e).lower():
|
|
155
|
+
raise LLMConfigurationError(
|
|
156
|
+
f"❌ OpenAI API authentication failed!\n\n"
|
|
157
|
+
f"Error: {str(e)}\n\n"
|
|
158
|
+
f"Please check that your OPENAI_API_KEY is valid.\n"
|
|
159
|
+
f"Get your API key at: https://platform.openai.com/api-keys"
|
|
160
|
+
)
|
|
161
|
+
raise
|
|
94
162
|
|
|
95
163
|
|
|
96
164
|
async def _azure_chat_complete(
|
|
@@ -99,17 +167,36 @@ async def _azure_chat_complete(
|
|
|
99
167
|
messages: list[dict[str, str]],
|
|
100
168
|
temperature: float,
|
|
101
169
|
):
|
|
102
|
-
|
|
170
|
+
"""Azure OpenAI chat completion."""
|
|
103
171
|
deployment_name = os.getenv("AZURE_OPENAI_DEPLOYMENT") or llm.model
|
|
104
172
|
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
173
|
+
if not deployment_name:
|
|
174
|
+
raise LLMConfigurationError(
|
|
175
|
+
f"❌ Missing Azure OpenAI deployment name!\n\n"
|
|
176
|
+
f"Please set AZURE_OPENAI_DEPLOYMENT environment variable.\n"
|
|
177
|
+
f"Example: export AZURE_OPENAI_DEPLOYMENT='gpt-4o'"
|
|
178
|
+
)
|
|
179
|
+
|
|
180
|
+
try:
|
|
181
|
+
response = await client.chat.completions.create(
|
|
182
|
+
model=deployment_name,
|
|
183
|
+
messages=messages,
|
|
184
|
+
temperature=temperature,
|
|
185
|
+
)
|
|
186
|
+
text = response.choices[0].message.content.strip()
|
|
187
|
+
cost = _calculate_cost(llm, response.usage)
|
|
188
|
+
return text, cost
|
|
189
|
+
except Exception as e:
|
|
190
|
+
if "API key" in str(e) or "authentication" in str(e).lower():
|
|
191
|
+
raise LLMConfigurationError(
|
|
192
|
+
f"❌ Azure OpenAI authentication failed!\n\n"
|
|
193
|
+
f"Error: {str(e)}\n\n"
|
|
194
|
+
f"Please check your Azure OpenAI configuration:\n"
|
|
195
|
+
f" - AZURE_OPENAI_API_KEY\n"
|
|
196
|
+
f" - AZURE_OPENAI_ENDPOINT\n"
|
|
197
|
+
f" - AZURE_OPENAI_DEPLOYMENT"
|
|
198
|
+
)
|
|
199
|
+
raise
|
|
113
200
|
|
|
114
201
|
|
|
115
202
|
async def _google_chat_complete(
|
|
@@ -118,27 +205,35 @@ async def _google_chat_complete(
|
|
|
118
205
|
messages: list[dict[str, str]],
|
|
119
206
|
temperature: float,
|
|
120
207
|
):
|
|
121
|
-
"""
|
|
122
|
-
Google GenAI / Gemini 2.x
|
|
123
|
-
"""
|
|
208
|
+
"""Google GenAI / Gemini chat completion."""
|
|
124
209
|
prompt = "\n".join(m["content"] for m in messages)
|
|
125
210
|
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
211
|
+
try:
|
|
212
|
+
response = await client.aio.models.generate_content(
|
|
213
|
+
model=llm.model,
|
|
214
|
+
contents=prompt,
|
|
215
|
+
config=GenerateContentConfig(temperature=temperature),
|
|
216
|
+
)
|
|
131
217
|
|
|
132
|
-
|
|
218
|
+
text = response.text.strip()
|
|
133
219
|
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
220
|
+
um = response.usage_metadata
|
|
221
|
+
usage = SimpleNamespace(
|
|
222
|
+
prompt_tokens=um.prompt_token_count,
|
|
223
|
+
completion_tokens=um.candidates_token_count,
|
|
224
|
+
)
|
|
139
225
|
|
|
140
|
-
|
|
141
|
-
|
|
226
|
+
cost = _calculate_cost(llm, usage)
|
|
227
|
+
return text, cost
|
|
228
|
+
except Exception as e:
|
|
229
|
+
if "API key" in str(e) or "authentication" in str(e).lower() or "credentials" in str(e).lower():
|
|
230
|
+
raise LLMConfigurationError(
|
|
231
|
+
f"❌ Google Gemini API authentication failed!\n\n"
|
|
232
|
+
f"Error: {str(e)}\n\n"
|
|
233
|
+
f"Please check that your GOOGLE_API_KEY is valid.\n"
|
|
234
|
+
f"Get your API key at: https://aistudio.google.com/apikey"
|
|
235
|
+
)
|
|
236
|
+
raise
|
|
142
237
|
|
|
143
238
|
|
|
144
239
|
async def _ollama_chat_complete(
|
|
@@ -147,14 +242,29 @@ async def _ollama_chat_complete(
|
|
|
147
242
|
messages: list[dict[str, str]],
|
|
148
243
|
temperature: float,
|
|
149
244
|
):
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
245
|
+
"""Ollama (local) chat completion."""
|
|
246
|
+
try:
|
|
247
|
+
response = await client.chat.completions.create(
|
|
248
|
+
model=llm.model,
|
|
249
|
+
messages=messages,
|
|
250
|
+
temperature=temperature,
|
|
251
|
+
)
|
|
252
|
+
text = response.choices[0].message.content.strip()
|
|
253
|
+
cost = _calculate_cost(llm, response.usage)
|
|
254
|
+
return text, cost
|
|
255
|
+
except Exception as e:
|
|
256
|
+
if "Connection" in str(e) or "refused" in str(e).lower():
|
|
257
|
+
raise LLMConfigurationError(
|
|
258
|
+
f"❌ Cannot connect to Ollama server!\n\n"
|
|
259
|
+
f"Error: {str(e)}\n\n"
|
|
260
|
+
f"Make sure Ollama is running:\n"
|
|
261
|
+
f" 1. Install Ollama: https://ollama.ai/download\n"
|
|
262
|
+
f" 2. Start Ollama: ollama serve\n"
|
|
263
|
+
f" 3. Pull model: ollama pull {llm.model}\n\n"
|
|
264
|
+
f"Or set OLLAMA_API_BASE_URL to your Ollama server:\n"
|
|
265
|
+
f" export OLLAMA_API_BASE_URL='http://localhost:11434/v1'"
|
|
266
|
+
)
|
|
267
|
+
raise
|
|
158
268
|
|
|
159
269
|
|
|
160
270
|
async def _anthropic_chat_complete(
|
|
@@ -163,23 +273,31 @@ async def _anthropic_chat_complete(
|
|
|
163
273
|
messages: list[dict[str, str]],
|
|
164
274
|
temperature: float,
|
|
165
275
|
):
|
|
166
|
-
"""
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
276
|
+
"""Anthropic Claude chat completion."""
|
|
277
|
+
try:
|
|
278
|
+
response = await client.messages.create(
|
|
279
|
+
model=llm.model,
|
|
280
|
+
messages=messages,
|
|
281
|
+
temperature=temperature,
|
|
282
|
+
max_tokens=4096,
|
|
283
|
+
)
|
|
284
|
+
if isinstance(response.content, list):
|
|
285
|
+
text = "".join(
|
|
286
|
+
block.text for block in response.content if block.type == "text").strip()
|
|
287
|
+
else:
|
|
288
|
+
text = response.content.strip()
|
|
289
|
+
|
|
290
|
+
cost = _calculate_cost(llm, response.usage)
|
|
291
|
+
return text, cost
|
|
292
|
+
except Exception as e:
|
|
293
|
+
if "API key" in str(e) or "authentication" in str(e).lower():
|
|
294
|
+
raise LLMConfigurationError(
|
|
295
|
+
f"❌ Anthropic Claude API authentication failed!\n\n"
|
|
296
|
+
f"Error: {str(e)}\n\n"
|
|
297
|
+
f"Please check that your ANTHROPIC_API_KEY is valid.\n"
|
|
298
|
+
f"Get your API key at: https://console.anthropic.com/settings/keys"
|
|
299
|
+
)
|
|
300
|
+
raise
|
|
183
301
|
|
|
184
302
|
|
|
185
303
|
_HELPERS = {
|
|
@@ -196,20 +314,33 @@ async def chat_complete(
|
|
|
196
314
|
messages: list[dict[str, str]],
|
|
197
315
|
temperature: float = 0.0,
|
|
198
316
|
):
|
|
317
|
+
"""
|
|
318
|
+
Complete a chat conversation using the specified LLM.
|
|
319
|
+
|
|
320
|
+
Args:
|
|
321
|
+
llm: LLM specification (e.g., "gpt-4o-mini", "openai:gpt-4o", or LLMDescriptor)
|
|
322
|
+
messages: List of message dicts with "role" and "content"
|
|
323
|
+
temperature: Sampling temperature (0.0-2.0)
|
|
324
|
+
|
|
325
|
+
Returns:
|
|
326
|
+
Tuple of (response_text, cost_in_usd)
|
|
327
|
+
|
|
328
|
+
Raises:
|
|
329
|
+
LLMConfigurationError: If required API keys or configuration are missing
|
|
330
|
+
ValueError: If provider is not supported
|
|
331
|
+
"""
|
|
199
332
|
llm = LLMDescriptor.parse(llm)
|
|
200
333
|
helper = _HELPERS.get(llm.provider)
|
|
201
334
|
|
|
202
335
|
if helper is None:
|
|
203
|
-
raise ValueError(f"Unsupported provider {llm.provider}")
|
|
336
|
+
raise ValueError(f"Unsupported provider: {llm.provider}")
|
|
204
337
|
|
|
205
338
|
client = _get_client(llm.provider)
|
|
206
339
|
return await helper(client, llm, messages, temperature)
|
|
207
340
|
|
|
208
341
|
|
|
209
342
|
def _calculate_cost(llm: LLMDescriptor, usage) -> Optional[float]:
|
|
210
|
-
"""
|
|
211
|
-
Calculate the cost of the LLM usage based on the model and usage data.
|
|
212
|
-
"""
|
|
343
|
+
"""Calculate the cost of the LLM usage based on the model and usage data."""
|
|
213
344
|
if llm.provider == Provider.OLLAMA:
|
|
214
345
|
return 0.0
|
|
215
346
|
if not usage:
|
|
@@ -219,7 +350,7 @@ def _calculate_cost(llm: LLMDescriptor, usage) -> Optional[float]:
|
|
|
219
350
|
if not price:
|
|
220
351
|
return None
|
|
221
352
|
|
|
222
|
-
prompt = getattr(usage, "prompt_tokens",
|
|
353
|
+
prompt = getattr(usage, "prompt_tokens", 0)
|
|
223
354
|
completion = getattr(usage, "completion_tokens", 0)
|
|
224
355
|
|
|
225
356
|
return round(
|
|
@@ -242,6 +373,10 @@ async def get_embeddings(
|
|
|
242
373
|
|
|
243
374
|
Returns:
|
|
244
375
|
Tuple of (embeddings_list, total_cost)
|
|
376
|
+
|
|
377
|
+
Raises:
|
|
378
|
+
LLMConfigurationError: If required API keys are missing
|
|
379
|
+
ValueError: If non-OpenAI provider is specified
|
|
245
380
|
"""
|
|
246
381
|
llm = LLMDescriptor.parse(model)
|
|
247
382
|
|
|
@@ -259,16 +394,26 @@ async def _openai_get_embeddings(
|
|
|
259
394
|
texts: list[str],
|
|
260
395
|
) -> tuple[list[list[float]], Optional[float]]:
|
|
261
396
|
"""OpenAI embeddings implementation."""
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
embeddings = [data.embedding for data in response.data]
|
|
269
|
-
cost = _calculate_embedding_cost(llm, response.usage)
|
|
397
|
+
try:
|
|
398
|
+
response = await client.embeddings.create(
|
|
399
|
+
model=llm.model,
|
|
400
|
+
input=texts,
|
|
401
|
+
encoding_format="float"
|
|
402
|
+
)
|
|
270
403
|
|
|
271
|
-
|
|
404
|
+
embeddings = [data.embedding for data in response.data]
|
|
405
|
+
cost = _calculate_embedding_cost(llm, response.usage)
|
|
406
|
+
|
|
407
|
+
return embeddings, cost
|
|
408
|
+
except Exception as e:
|
|
409
|
+
if "API key" in str(e) or "authentication" in str(e).lower():
|
|
410
|
+
raise LLMConfigurationError(
|
|
411
|
+
f"❌ OpenAI API authentication failed for embeddings!\n\n"
|
|
412
|
+
f"Error: {str(e)}\n\n"
|
|
413
|
+
f"Please check that your OPENAI_API_KEY is valid.\n"
|
|
414
|
+
f"Get your API key at: https://platform.openai.com/api-keys"
|
|
415
|
+
)
|
|
416
|
+
raise
|
|
272
417
|
|
|
273
418
|
|
|
274
419
|
def _calculate_embedding_cost(llm: LLMDescriptor, usage) -> Optional[float]:
|