eval-ai-library 0.2.2__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of eval-ai-library might be problematic.
- eval_ai_library-0.3.1.dist-info/METADATA +1042 -0
- eval_ai_library-0.3.1.dist-info/RECORD +34 -0
- eval_lib/__init__.py +19 -6
- eval_lib/agent_metrics/knowledge_retention_metric/knowledge_retention.py +9 -3
- eval_lib/agent_metrics/role_adherence_metric/role_adherence.py +13 -4
- eval_lib/agent_metrics/task_success_metric/task_success_rate.py +24 -23
- eval_lib/agent_metrics/tools_correctness_metric/tool_correctness.py +8 -2
- eval_lib/datagenerator/datagenerator.py +208 -12
- eval_lib/datagenerator/document_loader.py +29 -29
- eval_lib/evaluate.py +0 -22
- eval_lib/llm_client.py +221 -78
- eval_lib/metric_pattern.py +208 -152
- eval_lib/metrics/answer_precision_metric/answer_precision.py +8 -3
- eval_lib/metrics/answer_relevancy_metric/answer_relevancy.py +8 -2
- eval_lib/metrics/bias_metric/bias.py +12 -2
- eval_lib/metrics/contextual_precision_metric/contextual_precision.py +9 -4
- eval_lib/metrics/contextual_recall_metric/contextual_recall.py +7 -3
- eval_lib/metrics/contextual_relevancy_metric/contextual_relevancy.py +9 -2
- eval_lib/metrics/custom_metric/custom_eval.py +238 -204
- eval_lib/metrics/faithfulness_metric/faithfulness.py +7 -2
- eval_lib/metrics/geval/geval.py +8 -2
- eval_lib/metrics/restricted_refusal_metric/restricted_refusal.py +7 -3
- eval_lib/metrics/toxicity_metric/toxicity.py +8 -2
- eval_lib/utils.py +44 -29
- eval_ai_library-0.2.2.dist-info/METADATA +0 -779
- eval_ai_library-0.2.2.dist-info/RECORD +0 -34
- {eval_ai_library-0.2.2.dist-info → eval_ai_library-0.3.1.dist-info}/WHEEL +0 -0
- {eval_ai_library-0.2.2.dist-info → eval_ai_library-0.3.1.dist-info}/licenses/LICENSE +0 -0
- {eval_ai_library-0.2.2.dist-info → eval_ai_library-0.3.1.dist-info}/top_level.txt +0 -0
eval_lib/datagenerator/document_loader.py CHANGED

@@ -2,22 +2,41 @@
 from __future__ import annotations
 from pathlib import Path
 from typing import List
-
 from langchain_core.documents import Document
 from langchain_text_splitters import RecursiveCharacterTextSplitter
-
-# LangChain loaders (keep the existing ones)
 from langchain_community.document_loaders import PyPDFLoader
 from langchain_community.document_loaders import Docx2txtLoader
 from langchain_community.document_loaders import TextLoader

 import html2text
 import markdown
-
+from pptx import Presentation
+from striprtf.striprtf import rtf_to_text
+from pypdf import PdfReader
+import fitz
+import zipfile
+from xml.etree import ElementTree as ET
+import pandas as pd
+import yaml
+import pytesseract
+from PIL import Image
+import io as _io
+import pytesseract
+from PIL import Image
+import io as _io
+import docx  # python-docx
+import mammoth
 import io
 import json
 import zipfile

+try:
+    import textract
+    HAS_TEXTRACT = True
+except ImportError:
+    HAS_TEXTRACT = False
+    textract = None
+
 # ---------------------------
 # Helper functions
 # ---------------------------

@@ -33,7 +52,7 @@ def _read_bytes(p: Path) -> bytes:

 def _csv_tsv_to_text(p: Path) -> str:
     try:
-
+
         sep = "," if p.suffix.lower() == ".csv" else "\t"
         df = pd.read_csv(str(p), dtype=str, sep=sep,
                          encoding="utf-8", engine="python")

@@ -50,7 +69,6 @@ def _csv_tsv_to_text(p: Path) -> str:

 def _xlsx_to_text(p: Path) -> str:
     try:
-        import pandas as pd
         df = pd.read_excel(str(p), dtype=str, engine="openpyxl")
         df = df.fillna("")
         buf = io.StringIO()

@@ -62,7 +80,6 @@ def _xlsx_to_text(p: Path) -> str:

 def _pptx_to_text(p: Path) -> str:
     try:
-        from pptx import Presentation
         prs = Presentation(str(p))
         texts = []
         for slide in prs.slides:

@@ -96,7 +113,6 @@ def _json_to_text(p: Path) -> str:

 def _yaml_to_text(p: Path) -> str:
     try:
-        import yaml
         data = yaml.safe_load(_read_text(p))
         return json.dumps(data, ensure_ascii=False, indent=2)
     except Exception:

@@ -105,7 +121,6 @@ def _yaml_to_text(p: Path) -> str:

 def _xml_to_text(p: Path) -> str:
     try:
-        from xml.etree import ElementTree as ET
         tree = ET.parse(str(p))
         root = tree.getroot()
         lines = []

@@ -125,7 +140,7 @@ def _xml_to_text(p: Path) -> str:

 def _rtf_to_text(p: Path) -> str:
     try:
-
+
         return rtf_to_text(_read_text(p))
     except Exception:
         return ""

@@ -153,7 +168,6 @@ def _odt_to_text(p: Path) -> str:

 def _pdf_text_pypdf(p: Path) -> str:
     try:
-        from pypdf import PdfReader  # <- pypdf specifically
         reader = PdfReader(str(p))
         texts = []
         for page in reader.pages:

@@ -167,7 +181,6 @@ def _pdf_text_pypdf(p: Path) -> str:

 def _pdf_text_pymupdf(p: Path) -> str:
     try:
-        import fitz  # PyMuPDF
         text_parts = []
         with fitz.open(str(p)) as doc:
             for page in doc:

@@ -182,11 +195,6 @@ def _pdf_text_pymupdf(p: Path) -> str:
 def _pdf_ocr_via_pymupdf(p: Path) -> str:
     """Render pages via PyMuPDF and OCR pytesseract. Will work if pytesseract + tesseract are installed."""
     try:
-        import fitz  # PyMuPDF
-        import pytesseract
-        from PIL import Image
-        import io as _io
-
         texts = []
         zoom = 2.0
         mat = fitz.Matrix(zoom, zoom)

@@ -208,9 +216,6 @@ def _pdf_ocr_via_pymupdf(p: Path) -> str:

 def _ocr_image_bytes(img_bytes: bytes) -> str:
     try:
-        import pytesseract
-        from PIL import Image
-        import io as _io
         img = Image.open(_io.BytesIO(img_bytes))
         return pytesseract.image_to_string(img) or ""
     except Exception:

@@ -223,7 +228,6 @@ def _ocr_image_bytes(img_bytes: bytes) -> str:

 def _docx_to_text_python_docx(p: Path) -> str:
     try:
-        import docx  # python-docx
         d = docx.Document(str(p))
         parts = []
         for para in d.paragraphs:

@@ -242,7 +246,6 @@ def _docx_to_text_python_docx(p: Path) -> str:

 def _docx_to_text_mammoth(p: Path) -> str:
     try:
-        import mammoth
         with open(str(p), "rb") as f:
             result = mammoth.extract_raw_text(f)
         return (result.value or "").strip()

@@ -253,19 +256,16 @@ def _docx_to_text_mammoth(p: Path) -> str:
 def _docx_to_text_zipxml(p: Path) -> str:
     """Without dependencies: read word/document.xml and extract all w:t elements."""
     try:
-
-        from xml.etree import ElementTree as ET
+
         texts = []
         with zipfile.ZipFile(str(p)) as z:
-            # main document
             if "word/document.xml" in z.namelist():
                 with z.open("word/document.xml") as f:
                     root = ET.parse(f).getroot()
                     for el in root.iter():
-                        tag = el.tag.rsplit("}", 1)[-1]
+                        tag = el.tag.rsplit("}", 1)[-1]
                         if tag == "t" and el.text and el.text.strip():
                             texts.append(el.text.strip())
-            # headers/footers may also contain text
             for name in z.namelist():
                 if name.startswith("word/header") and name.endswith(".xml"):
                     with z.open(name) as f:

@@ -287,9 +287,9 @@ def _docx_to_text_zipxml(p: Path) -> str:


 def _doc_to_text_textract(p: Path) -> str:
-
+    if not HAS_TEXTRACT:
+        return ""
     try:
-        import textract
         return textract.process(str(p)).decode("utf-8", errors="ignore")
     except Exception:
         return ""
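The document_loader changes hoist the per-function imports to module level and gate the one optional dependency, textract, behind a feature flag. A standalone sketch of that guard pattern as the diff applies it (the shortened function name here is for illustration only):

# Optional-dependency guard as used in document_loader.py above:
# import once at module level, record availability, degrade gracefully.
try:
    import textract
    HAS_TEXTRACT = True
except ImportError:
    HAS_TEXTRACT = False
    textract = None


def doc_to_text(path: str) -> str:
    """Extract text from a legacy .doc file; empty string if textract is unavailable."""
    if not HAS_TEXTRACT:
        return ""
    try:
        return textract.process(path).decode("utf-8", errors="ignore")
    except Exception:
        return ""

Module-level imports fail fast for hard dependencies, while the flag keeps textract strictly optional instead of raising ImportError deep inside a helper.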
eval_lib/evaluate.py CHANGED

@@ -183,18 +183,6 @@ async def evaluate(
     _print_summary(results, total_cost, total_time,
                    total_passed, total_tests)

-    # Print detailed results if requested
-    if verbose:
-        print(f"{Colors.BOLD}Detailed Results:{Colors.ENDC}")
-        for idx, (meta, tc_list) in enumerate(results, 1):
-            print(f"\n{Colors.DIM}{'─'*70}{Colors.ENDC}")
-            print(f"{Colors.BOLD}Test Case {idx}:{Colors.ENDC}")
-            for tc in tc_list:
-                tc_dict = asdict(tc)
-                # Pretty print with indentation
-                print(json.dumps(tc_dict, indent=2, ensure_ascii=False))
-            print(f"{Colors.DIM}{'─'*70}{Colors.ENDC}\n")
-
     return results


@@ -322,14 +310,4 @@ async def evaluate_conversations(
     _print_summary(results, total_cost, total_time,
                    total_passed, total_conversations)

-    # Print detailed results if requested
-    if verbose:
-        print(f"{Colors.BOLD}Detailed Results:{Colors.ENDC}")
-        for idx, (_, conv_list) in enumerate(results, 1):
-            print(f"\n{Colors.DIM}{'─'*70}{Colors.ENDC}")
-            print(f"{Colors.BOLD}Conversation {idx}:{Colors.ENDC}")
-            for conv in conv_list:
-                print(json.dumps(asdict(conv), indent=2, ensure_ascii=False))
-            print(f"{Colors.DIM}{'─'*70}{Colors.ENDC}\n")
-
     return results
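Both evaluate() and evaluate_conversations() drop their built-in verbose dumps, so detailed printing now falls to the caller. A minimal caller-side replacement, assuming (as the removed code implies) that results arrive as (metadata, dataclass-list) tuples:

import json
from dataclasses import asdict


def print_detailed_results(results) -> None:
    # Sketch of the removed verbose block, reproduced outside the library.
    for idx, (meta, tc_list) in enumerate(results, 1):
        print("─" * 70)
        print(f"Test Case {idx}:")
        for tc in tc_list:
            print(json.dumps(asdict(tc), indent=2, ensure_ascii=False))
        print("─" * 70)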
eval_lib/llm_client.py CHANGED

@@ -13,6 +13,11 @@ from types import SimpleNamespace
 from .price import model_pricing


+class LLMConfigurationError(Exception):
+    """Raised when LLM client configuration is missing or invalid."""
+    pass
+
+
 class Provider(str, Enum):
     OPENAI = "openai"
     AZURE = "azure"

@@ -45,12 +50,58 @@ class LLMDescriptor:
         return f"{self.provider}:{self.model}"


+def _check_env_var(var_name: str, provider: str, required: bool = True) -> Optional[str]:
+    """
+    Check if environment variable is set and return its value.
+
+    Args:
+        var_name: Name of the environment variable
+        provider: Provider name for error message
+        required: Whether this variable is required
+
+    Returns:
+        Value of the environment variable or None if not required
+
+    Raises:
+        LLMConfigurationError: If required variable is missing
+    """
+    value = os.getenv(var_name)
+    if required and not value:
+        raise LLMConfigurationError(
+            f"❌ Missing {provider} configuration!\n\n"
+            f"Environment variable '{var_name}' is not set.\n\n"
+            f"To fix this, set the environment variable:\n"
+            f"  export {var_name}='your-api-key-here'\n\n"
+            f"Or add it to your .env file:\n"
+            f"  {var_name}=your-api-key-here\n\n"
+            f"📖 Documentation: https://github.com/meshkovQA/Eval-ai-library#environment-variables"
+        )
+    return value
+
+
 @functools.cache
 def _get_client(provider: Provider):
+    """
+    Get or create LLM client for the specified provider.
+
+    Args:
+        provider: LLM provider enum
+
+    Returns:
+        Configured client instance
+
+    Raises:
+        LLMConfigurationError: If required configuration is missing
+        ValueError: If provider is not supported
+    """
     if provider == Provider.OPENAI:
+        _check_env_var("OPENAI_API_KEY", "OpenAI")
         return openai.AsyncOpenAI()

     if provider == Provider.AZURE:
+        _check_env_var("AZURE_OPENAI_API_KEY", "Azure OpenAI")
+        _check_env_var("AZURE_OPENAI_ENDPOINT", "Azure OpenAI")
+
         return AsyncAzureOpenAI(
             api_key=os.getenv("AZURE_OPENAI_API_KEY"),
             azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
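To see what the new fail-fast check does, here is a hedged sketch of exercising _check_env_var directly; the import path and the use of these private names are assumptions for illustration:

import os
from eval_lib.llm_client import LLMConfigurationError, _check_env_var  # assumed path

os.environ.pop("OPENAI_API_KEY", None)
try:
    _check_env_var("OPENAI_API_KEY", "OpenAI")  # required and unset -> raises
except LLMConfigurationError as err:
    print(err)  # multi-line message with export/.env instructions and a docs link

os.environ.pop("OLLAMA_API_KEY", None)
print(_check_env_var("OLLAMA_API_KEY", "Ollama", required=False))  # optional -> None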
@@ -58,20 +109,27 @@ def _get_client(provider: Provider):
         )

     if provider == Provider.GOOGLE:
+        _check_env_var("GOOGLE_API_KEY", "Google Gemini")
         return genai.Client(api_key=os.getenv("GOOGLE_API_KEY"))

     if provider == Provider.OLLAMA:
+        api_key = _check_env_var(
+            "OLLAMA_API_KEY", "Ollama", required=False) or "ollama"
+        base_url = _check_env_var(
+            "OLLAMA_API_BASE_URL", "Ollama", required=False) or "http://localhost:11434/v1"
+
         return openai.AsyncOpenAI(
-            api_key=
-            base_url=
+            api_key=api_key,
+            base_url=base_url
         )

     if provider == Provider.ANTHROPIC:
+        _check_env_var("ANTHROPIC_API_KEY", "Anthropic Claude")
         return anthropic.AsyncAnthropic(
             api_key=os.getenv("ANTHROPIC_API_KEY"),
         )

-    raise ValueError(f"Unsupported provider {provider}")
+    raise ValueError(f"Unsupported provider: {provider}")


 async def _openai_chat_complete(

@@ -80,17 +138,25 @@ async def _openai_chat_complete(
     messages: list[dict[str, str]],
     temperature: float,
 ):
-    """
-
-
-
-
-
-
-
-
-
-
+    """OpenAI chat completion."""
+    try:
+        response = await client.chat.completions.create(
+            model=llm.model,
+            messages=messages,
+            temperature=temperature,
+        )
+        text = response.choices[0].message.content.strip()
+        cost = _calculate_cost(llm, response.usage)
+        return text, cost
+    except Exception as e:
+        if "API key" in str(e) or "authentication" in str(e).lower():
+            raise LLMConfigurationError(
+                f"❌ OpenAI API authentication failed!\n\n"
+                f"Error: {str(e)}\n\n"
+                f"Please check that your OPENAI_API_KEY is valid.\n"
+                f"Get your API key at: https://platform.openai.com/api-keys"
+            )
+        raise


 async def _azure_chat_complete(

@@ -99,17 +165,36 @@ async def _azure_chat_complete(
     messages: list[dict[str, str]],
     temperature: float,
 ):
-
+    """Azure OpenAI chat completion."""
     deployment_name = os.getenv("AZURE_OPENAI_DEPLOYMENT") or llm.model

-
-
-
-
-
-
-
-
+    if not deployment_name:
+        raise LLMConfigurationError(
+            f"❌ Missing Azure OpenAI deployment name!\n\n"
+            f"Please set AZURE_OPENAI_DEPLOYMENT environment variable.\n"
+            f"Example: export AZURE_OPENAI_DEPLOYMENT='gpt-4o'"
+        )
+
+    try:
+        response = await client.chat.completions.create(
+            model=deployment_name,
+            messages=messages,
+            temperature=temperature,
+        )
+        text = response.choices[0].message.content.strip()
+        cost = _calculate_cost(llm, response.usage)
+        return text, cost
+    except Exception as e:
+        if "API key" in str(e) or "authentication" in str(e).lower():
+            raise LLMConfigurationError(
+                f"❌ Azure OpenAI authentication failed!\n\n"
+                f"Error: {str(e)}\n\n"
+                f"Please check your Azure OpenAI configuration:\n"
+                f"  - AZURE_OPENAI_API_KEY\n"
+                f"  - AZURE_OPENAI_ENDPOINT\n"
+                f"  - AZURE_OPENAI_DEPLOYMENT"
+            )
+        raise


 async def _google_chat_complete(

@@ -118,27 +203,35 @@ async def _google_chat_complete(
     messages: list[dict[str, str]],
     temperature: float,
 ):
-    """
-    Google GenAI / Gemini 2.x
-    """
+    """Google GenAI / Gemini chat completion."""
     prompt = "\n".join(m["content"] for m in messages)

-
-
-
-
-
+    try:
+        response = await client.aio.models.generate_content(
+            model=llm.model,
+            contents=prompt,
+            config=GenerateContentConfig(temperature=temperature),
+        )

-
+        text = response.text.strip()

-
-
-
-
-
+        um = response.usage_metadata
+        usage = SimpleNamespace(
+            prompt_tokens=um.prompt_token_count,
+            completion_tokens=um.candidates_token_count,
+        )

-
-
+        cost = _calculate_cost(llm, usage)
+        return text, cost
+    except Exception as e:
+        if "API key" in str(e) or "authentication" in str(e).lower() or "credentials" in str(e).lower():
+            raise LLMConfigurationError(
+                f"❌ Google Gemini API authentication failed!\n\n"
+                f"Error: {str(e)}\n\n"
+                f"Please check that your GOOGLE_API_KEY is valid.\n"
+                f"Get your API key at: https://aistudio.google.com/apikey"
+            )
+        raise


 async def _ollama_chat_complete(

@@ -147,14 +240,29 @@ async def _ollama_chat_complete(
     messages: list[dict[str, str]],
     temperature: float,
 ):
-
-
-
-
-
-
-
-
+    """Ollama (local) chat completion."""
+    try:
+        response = await client.chat.completions.create(
+            model=llm.model,
+            messages=messages,
+            temperature=temperature,
+        )
+        text = response.choices[0].message.content.strip()
+        cost = _calculate_cost(llm, response.usage)
+        return text, cost
+    except Exception as e:
+        if "Connection" in str(e) or "refused" in str(e).lower():
+            raise LLMConfigurationError(
+                f"❌ Cannot connect to Ollama server!\n\n"
+                f"Error: {str(e)}\n\n"
+                f"Make sure Ollama is running:\n"
+                f"  1. Install Ollama: https://ollama.ai/download\n"
+                f"  2. Start Ollama: ollama serve\n"
+                f"  3. Pull model: ollama pull {llm.model}\n\n"
+                f"Or set OLLAMA_API_BASE_URL to your Ollama server:\n"
+                f"  export OLLAMA_API_BASE_URL='http://localhost:11434/v1'"
+            )
+        raise


 async def _anthropic_chat_complete(

@@ -163,23 +271,31 @@ async def _anthropic_chat_complete(
     messages: list[dict[str, str]],
     temperature: float,
 ):
-    """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    """Anthropic Claude chat completion."""
+    try:
+        response = await client.messages.create(
+            model=llm.model,
+            messages=messages,
+            temperature=temperature,
+            max_tokens=4096,
+        )
+        if isinstance(response.content, list):
+            text = "".join(
+                block.text for block in response.content if block.type == "text").strip()
+        else:
+            text = response.content.strip()
+
+        cost = _calculate_cost(llm, response.usage)
+        return text, cost
+    except Exception as e:
+        if "API key" in str(e) or "authentication" in str(e).lower():
+            raise LLMConfigurationError(
+                f"❌ Anthropic Claude API authentication failed!\n\n"
+                f"Error: {str(e)}\n\n"
+                f"Please check that your ANTHROPIC_API_KEY is valid.\n"
+                f"Get your API key at: https://console.anthropic.com/settings/keys"
+            )
+        raise


 _HELPERS = {

@@ -196,20 +312,33 @@ async def chat_complete(
     messages: list[dict[str, str]],
     temperature: float = 0.0,
 ):
+    """
+    Complete a chat conversation using the specified LLM.
+
+    Args:
+        llm: LLM specification (e.g., "gpt-4o-mini", "openai:gpt-4o", or LLMDescriptor)
+        messages: List of message dicts with "role" and "content"
+        temperature: Sampling temperature (0.0-2.0)
+
+    Returns:
+        Tuple of (response_text, cost_in_usd)
+
+    Raises:
+        LLMConfigurationError: If required API keys or configuration are missing
+        ValueError: If provider is not supported
+    """
     llm = LLMDescriptor.parse(llm)
     helper = _HELPERS.get(llm.provider)

     if helper is None:
-        raise ValueError(f"Unsupported provider {llm.provider}")
+        raise ValueError(f"Unsupported provider: {llm.provider}")

     client = _get_client(llm.provider)
     return await helper(client, llm, messages, temperature)


 def _calculate_cost(llm: LLMDescriptor, usage) -> Optional[float]:
-    """
-    Calculate the cost of the LLM usage based on the model and usage data.
-    """
+    """Calculate the cost of the LLM usage based on the model and usage data."""
     if llm.provider == Provider.OLLAMA:
         return 0.0
     if not usage:

@@ -219,7 +348,7 @@ def _calculate_cost(llm: LLMDescriptor, usage) -> Optional[float]:
     if not price:
         return None

-    prompt = getattr(usage, "prompt_tokens",
+    prompt = getattr(usage, "prompt_tokens", 0)
     completion = getattr(usage, "completion_tokens", 0)

     return round(

@@ -242,6 +371,10 @@ async def get_embeddings(

     Returns:
         Tuple of (embeddings_list, total_cost)
+
+    Raises:
+        LLMConfigurationError: If required API keys are missing
+        ValueError: If non-OpenAI provider is specified
     """
     llm = LLMDescriptor.parse(model)


@@ -259,16 +392,26 @@ async def _openai_get_embeddings(
     texts: list[str],
 ) -> tuple[list[list[float]], Optional[float]]:
     """OpenAI embeddings implementation."""
-
-
-
-
-
-
-    embeddings = [data.embedding for data in response.data]
-    cost = _calculate_embedding_cost(llm, response.usage)
+    try:
+        response = await client.embeddings.create(
+            model=llm.model,
+            input=texts,
+            encoding_format="float"
+        )

-
+        embeddings = [data.embedding for data in response.data]
+        cost = _calculate_embedding_cost(llm, response.usage)
+
+        return embeddings, cost
+    except Exception as e:
+        if "API key" in str(e) or "authentication" in str(e).lower():
+            raise LLMConfigurationError(
+                f"❌ OpenAI API authentication failed for embeddings!\n\n"
+                f"Error: {str(e)}\n\n"
+                f"Please check that your OPENAI_API_KEY is valid.\n"
+                f"Get your API key at: https://platform.openai.com/api-keys"
+            )
+        raise


 def _calculate_embedding_cost(llm: LLMDescriptor, usage) -> Optional[float]:
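Taken together, the configuration checks mean a missing or invalid key now surfaces as LLMConfigurationError instead of a raw provider error. A usage sketch based on the chat_complete docstring above (the import path, model string, and messages are illustrative assumptions):

import asyncio
from eval_lib.llm_client import LLMConfigurationError, chat_complete  # assumed path

async def main() -> None:
    try:
        # "provider:model" form taken from the docstring, e.g. "openai:gpt-4o"
        text, cost = await chat_complete(
            "openai:gpt-4o-mini",
            [{"role": "user", "content": "Say hi."}],
            temperature=0.0,
        )
        print(text, f"(cost: {cost} USD)")
    except LLMConfigurationError as err:
        print(err)  # e.g. missing OPENAI_API_KEY, with setup instructions

asyncio.run(main())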