eval-ai-library 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of eval-ai-library might be problematic.

Files changed (34)
  1. eval_ai_library-0.1.0.dist-info/METADATA +753 -0
  2. eval_ai_library-0.1.0.dist-info/RECORD +34 -0
  3. eval_ai_library-0.1.0.dist-info/WHEEL +5 -0
  4. eval_ai_library-0.1.0.dist-info/licenses/LICENSE +21 -0
  5. eval_ai_library-0.1.0.dist-info/top_level.txt +1 -0
  6. eval_lib/__init__.py +122 -0
  7. eval_lib/agent_metrics/__init__.py +12 -0
  8. eval_lib/agent_metrics/knowledge_retention_metric/knowledge_retention.py +231 -0
  9. eval_lib/agent_metrics/role_adherence_metric/role_adherence.py +251 -0
  10. eval_lib/agent_metrics/task_success_metric/task_success_rate.py +347 -0
  11. eval_lib/agent_metrics/tools_correctness_metric/tool_correctness.py +106 -0
  12. eval_lib/datagenerator/datagenerator.py +230 -0
  13. eval_lib/datagenerator/document_loader.py +510 -0
  14. eval_lib/datagenerator/prompts.py +192 -0
  15. eval_lib/evaluate.py +335 -0
  16. eval_lib/evaluation_schema.py +63 -0
  17. eval_lib/llm_client.py +286 -0
  18. eval_lib/metric_pattern.py +229 -0
  19. eval_lib/metrics/__init__.py +25 -0
  20. eval_lib/metrics/answer_precision_metric/answer_precision.py +405 -0
  21. eval_lib/metrics/answer_relevancy_metric/answer_relevancy.py +195 -0
  22. eval_lib/metrics/bias_metric/bias.py +114 -0
  23. eval_lib/metrics/contextual_precision_metric/contextual_precision.py +102 -0
  24. eval_lib/metrics/contextual_recall_metric/contextual_recall.py +91 -0
  25. eval_lib/metrics/contextual_relevancy_metric/contextual_relevancy.py +169 -0
  26. eval_lib/metrics/custom_metric/custom_eval.py +303 -0
  27. eval_lib/metrics/faithfulness_metric/faithfulness.py +140 -0
  28. eval_lib/metrics/geval/geval.py +326 -0
  29. eval_lib/metrics/restricted_refusal_metric/restricted_refusal.py +102 -0
  30. eval_lib/metrics/toxicity_metric/toxicity.py +113 -0
  31. eval_lib/price.py +37 -0
  32. eval_lib/py.typed +1 -0
  33. eval_lib/testcases_schema.py +27 -0
  34. eval_lib/utils.py +99 -0
eval_lib/evaluate.py ADDED
@@ -0,0 +1,335 @@
+ # evaluate.py
+ """
+ Main evaluation functions with beautiful console progress tracking.
+ """
+ from dataclasses import asdict
+ import json
+ import time
+ from typing import List, Tuple, Dict, Any
+ from eval_lib.testcases_schema import EvalTestCase, ConversationalEvalTestCase
+ from eval_lib.metric_pattern import MetricPattern, ConversationalMetricPattern
+ from eval_lib.evaluation_schema import TestCaseResult, MetricResult, ConversationalTestCaseResult
+
+
+ # ANSI color codes
+ class Colors:
+     HEADER = '\033[95m'
+     BLUE = '\033[94m'
+     CYAN = '\033[96m'
+     GREEN = '\033[92m'
+     YELLOW = '\033[93m'
+     RED = '\033[91m'
+     ENDC = '\033[0m'
+     BOLD = '\033[1m'
+     UNDERLINE = '\033[4m'
+     DIM = '\033[2m'
+
+
+ def _print_header(title: str):
+     """Print formatted header"""
+     print(f"\n{Colors.BOLD}{Colors.HEADER}{'='*70}{Colors.ENDC}")
+     print(f"{Colors.BOLD}{Colors.HEADER}{title.center(70)}{Colors.ENDC}")
+     print(f"{Colors.BOLD}{Colors.HEADER}{'='*70}{Colors.ENDC}\n")
+
+
+ def _print_progress(current: int, total: int, item_name: str):
+     """Print progress bar"""
+     percentage = (current / total) * 100
+     bar_length = 40
+     filled = int(bar_length * current / total)
+     bar = '█' * filled + '░' * (bar_length - filled)
+
+     print(f"\r{Colors.CYAN}Progress: [{bar}] {percentage:.0f}% ({current}/{total}) - {item_name}{Colors.ENDC}", end='', flush=True)
+
+
+ def _print_summary(results: List, total_cost: float, total_time: float, passed: int, total: int):
+     """Print evaluation summary"""
+     print(f"\n\n{Colors.BOLD}{Colors.GREEN}{'='*70}{Colors.ENDC}")
+     print(f"{Colors.BOLD}{Colors.GREEN}📋 EVALUATION SUMMARY{Colors.ENDC}")
+     print(f"{Colors.BOLD}{Colors.GREEN}{'='*70}{Colors.ENDC}")
+
+     success_rate = (passed / total * 100) if total > 0 else 0
+     status_color = Colors.GREEN if success_rate >= 80 else Colors.YELLOW if success_rate >= 50 else Colors.RED
+
+     print(f"\n{Colors.BOLD}Overall Results:{Colors.ENDC}")
+     print(f" ✅ Passed: {Colors.GREEN}{passed}{Colors.ENDC} / {total}")
+     print(f" ❌ Failed: {Colors.RED}{total - passed}{Colors.ENDC} / {total}")
+     print(f" 📊 Success Rate: {status_color}{success_rate:.1f}%{Colors.ENDC}")
+     print(f"\n{Colors.BOLD}Resource Usage:{Colors.ENDC}")
+     print(f" 💰 Total Cost: {Colors.YELLOW}${total_cost:.6f}{Colors.ENDC}")
+     print(f" ⏱️ Total Time: {Colors.CYAN}{total_time:.2f}s{Colors.ENDC}")
+     print(f" 📈 Avg Time per Test: {Colors.DIM}{(total_time/total if total > 0 else 0):.2f}s{Colors.ENDC}")
+
+     print(f"\n{Colors.BOLD}{Colors.GREEN}{'='*70}{Colors.ENDC}\n")
+
+
+ async def evaluate(
+     test_cases: List[EvalTestCase],
+     metrics: List[MetricPattern],
+     verbose: bool = True
+ ) -> List[Tuple[None, List[TestCaseResult]]]:
+     """
+     Evaluate test cases with multiple metrics.
+
+     Args:
+         test_cases: List of test cases to evaluate
+         metrics: List of metrics to apply
+         verbose: Enable detailed logging (default: True)
+
+     Returns:
+         List of evaluation results
+     """
+     start_time = time.time()
+     results: List[Tuple[None, List[TestCaseResult]]] = []
+
+     total_cost = 0.0
+     total_passed = 0
+     total_tests = len(test_cases)
+
+     if verbose:
+         _print_header("🚀 STARTING EVALUATION")
+         print(f"{Colors.BOLD}Configuration:{Colors.ENDC}")
+         print(f" 📝 Test Cases: {Colors.CYAN}{total_tests}{Colors.ENDC}")
+         print(f" 📊 Metrics: {Colors.CYAN}{len(metrics)}{Colors.ENDC}")
+         print(f" 🎯 Total Evaluations: {Colors.CYAN}{total_tests * len(metrics)}{Colors.ENDC}")
+
+         print(f"\n{Colors.BOLD}Metrics:{Colors.ENDC}")
+         for i, m in enumerate(metrics, 1):
+             print(f" {i}. {Colors.BLUE}{m.name}{Colors.ENDC} (threshold: {m.threshold})")
+
+     # Process each test case
+     for tc_idx, tc in enumerate(test_cases, 1):
+         if verbose:
+             print(f"\n{Colors.BOLD}{Colors.CYAN}{'─'*70}{Colors.ENDC}")
+             print(f"{Colors.BOLD}{Colors.CYAN}📝 Test Case {tc_idx}/{total_tests}{Colors.ENDC}")
+             print(f"{Colors.BOLD}{Colors.CYAN}{'─'*70}{Colors.ENDC}")
+             print(f"{Colors.DIM}Input: {tc.input[:80]}{'...' if len(tc.input) > 80 else ''}{Colors.ENDC}")
+
+         mdata = []
+         test_cost = 0.0
+
+         # Evaluate with each metric
+         for m_idx, m in enumerate(metrics, 1):
+             if verbose:
+                 _print_progress(m_idx, len(metrics), m.name)
+
+             # Set verbose flag for metrics
+             original_verbose = getattr(m, 'verbose', True)
+             m.verbose = verbose
+
+             res = await m.evaluate(tc)
+
+             # Restore original verbose setting
+             m.verbose = original_verbose
+
+             # Gather results
+             cost = res.get("evaluation_cost", 0) or 0
+             test_cost += cost
+             total_cost += cost
+
+             mdata.append(MetricResult(
+                 name=m.name,
+                 score=res["score"],
+                 threshold=m.threshold,
+                 success=res["success"],
+                 evaluation_cost=cost,
+                 reason=res["reason"],
+                 evaluation_model=m.model,
+                 evaluation_log=res.get("evaluation_log", None)
+             ))
+
+         overall = all(d.success for d in mdata)
+         if overall:
+             total_passed += 1
+
+         if verbose:
+             print(f"\n{Colors.BOLD}Test Case Summary:{Colors.ENDC}")
+             tc_status_color = Colors.GREEN if overall else Colors.RED
+             tc_status_icon = "✅" if overall else "❌"
+             print(f" {tc_status_icon} Overall: {tc_status_color}{Colors.BOLD}{'PASSED' if overall else 'FAILED'}{Colors.ENDC}")
+             print(f" 💰 Cost: {Colors.YELLOW}${test_cost:.6f}{Colors.ENDC}")
+
+             # Show metric breakdown
+             print(f"\n {Colors.BOLD}Metrics Breakdown:{Colors.ENDC}")
+             for md in mdata:
+                 status = "✅" if md.success else "❌"
+                 color = Colors.GREEN if md.success else Colors.RED
+                 print(f" {status} {md.name}: {color}{md.score:.2f}{Colors.ENDC}")
+
+         results.append((None, [TestCaseResult(
+             input=tc.input,
+             actual_output=tc.actual_output,
+             expected_output=tc.expected_output,
+             retrieval_context=tc.retrieval_context,
+             tools_called=tc.tools_called,
+             expected_tools=tc.expected_tools,
+             success=overall,
+             metrics_data=mdata
+         )]))
+
+     # Calculate total time
+     total_time = time.time() - start_time
+
+     # Print summary
+     if verbose:
+         _print_summary(results, total_cost, total_time, total_passed, total_tests)
+
+     # Print detailed results if requested
+     if verbose:
+         print(f"{Colors.BOLD}Detailed Results:{Colors.ENDC}")
+         for idx, (meta, tc_list) in enumerate(results, 1):
+             print(f"\n{Colors.DIM}{'─'*70}{Colors.ENDC}")
+             print(f"{Colors.BOLD}Test Case {idx}:{Colors.ENDC}")
+             for tc in tc_list:
+                 tc_dict = asdict(tc)
+                 # Pretty print with indentation
+                 print(json.dumps(tc_dict, indent=2, ensure_ascii=False))
+             print(f"{Colors.DIM}{'─'*70}{Colors.ENDC}\n")
+
+     return results
+
+
+ async def evaluate_conversations(
+     conv_cases: List[ConversationalEvalTestCase],
+     metrics: List[ConversationalMetricPattern],
+     verbose: bool = True
+ ) -> List[Tuple[None, List[ConversationalTestCaseResult]]]:
+     """
+     Evaluate conversational test cases with multiple metrics.
+
+     Args:
+         conv_cases: List of conversational test cases
+         metrics: List of conversational metrics
+         verbose: Enable detailed logging (default: True)
+
+     Returns:
+         List of evaluation results
+     """
+     start_time = time.time()
+     results: List[Tuple[None, List[ConversationalTestCaseResult]]] = []
+
+     total_cost = 0.0
+     total_passed = 0
+     total_conversations = len(conv_cases)
+
+     if verbose:
+         _print_header("🚀 STARTING CONVERSATIONAL EVALUATION")
+         print(f"{Colors.BOLD}Configuration:{Colors.ENDC}")
+         print(f" 💬 Conversations: {Colors.CYAN}{total_conversations}{Colors.ENDC}")
+         print(f" 📊 Metrics: {Colors.CYAN}{len(metrics)}{Colors.ENDC}")
+         print(f" 🎯 Total Evaluations: {Colors.CYAN}{total_conversations * len(metrics)}{Colors.ENDC}")
+
+         print(f"\n{Colors.BOLD}Metrics:{Colors.ENDC}")
+         for i, m in enumerate(metrics, 1):
+             print(f" {i}. {Colors.BLUE}{m.name}{Colors.ENDC} (threshold: {m.threshold})")
+
+     # Process each conversation
+     for conv_idx, dlg in enumerate(conv_cases, 1):
+         if verbose:
+             print(f"\n{Colors.BOLD}{Colors.CYAN}{'─'*70}{Colors.ENDC}")
+             print(f"{Colors.BOLD}{Colors.CYAN}💬 Conversation {conv_idx}/{total_conversations}{Colors.ENDC}")
+             print(f"{Colors.BOLD}{Colors.CYAN}{'─'*70}{Colors.ENDC}")
+             print(f"{Colors.DIM}Turns: {len(dlg.turns)}{Colors.ENDC}")
+             if dlg.chatbot_role:
+                 print(f"{Colors.DIM}Role: {dlg.chatbot_role[:60]}{'...' if len(dlg.chatbot_role) > 60 else ''}{Colors.ENDC}")
+
+         metric_rows: List[MetricResult] = []
+         conv_cost = 0.0
+
+         # Evaluate with each metric
+         for m_idx, m in enumerate(metrics, 1):
+             if verbose:
+                 _print_progress(m_idx, len(metrics), m.name)
+
+             # Set verbose flag for metrics
+             original_verbose = getattr(m, 'verbose', True)
+             m.verbose = verbose
+
+             res: Dict[str, Any] = await m.evaluate(dlg)
+
+             # Restore original verbose setting
+             m.verbose = original_verbose
+
+             cost = res.get("evaluation_cost", 0) or 0
+             conv_cost += cost
+             total_cost += cost
+
+             metric_rows.append(
+                 MetricResult(
+                     name=m.name,
+                     score=res["score"],
+                     threshold=m.threshold,
+                     success=res["success"],
+                     evaluation_cost=cost,
+                     reason=res.get("reason"),
+                     evaluation_model=m.model,
+                     evaluation_log=res.get("evaluation_log"),
+                 )
+             )
+
+         overall_ok = all(r.success for r in metric_rows)
+         if overall_ok:
+             total_passed += 1
+
+         if verbose:
+             print(f"\n{Colors.BOLD}Conversation Summary:{Colors.ENDC}")
+             conv_status_color = Colors.GREEN if overall_ok else Colors.RED
+             conv_status_icon = "✅" if overall_ok else "❌"
+             print(f" {conv_status_icon} Overall: {conv_status_color}{Colors.BOLD}{'PASSED' if overall_ok else 'FAILED'}{Colors.ENDC}")
+             print(f" 💰 Cost: {Colors.YELLOW}${conv_cost:.6f}{Colors.ENDC}")
+
+             # Show metric breakdown
+             print(f"\n {Colors.BOLD}Metrics Breakdown:{Colors.ENDC}")
+             for mr in metric_rows:
+                 status = "✅" if mr.success else "❌"
+                 color = Colors.GREEN if mr.success else Colors.RED
+                 print(f" {status} {mr.name}: {color}{mr.score:.2f}{Colors.ENDC}")
+
+         dialogue_raw = []
+         for turn in dlg.turns:
+             dialogue_raw.append({"role": "user", "content": turn.input})
+             dialogue_raw.append({"role": "assistant", "content": turn.actual_output})
+
+         conv_res = ConversationalTestCaseResult(
+             dialogue=dialogue_raw,
+             success=overall_ok,
+             metrics_data=metric_rows,
+         )
+         results.append((None, [conv_res]))
+
+     # Calculate total time
+     total_time = time.time() - start_time
+
+     # Print summary
+     if verbose:
+         _print_summary(results, total_cost, total_time, total_passed, total_conversations)
+
+     # Print detailed results if requested
+     if verbose:
+         print(f"{Colors.BOLD}Detailed Results:{Colors.ENDC}")
+         for idx, (_, conv_list) in enumerate(results, 1):
+             print(f"\n{Colors.DIM}{'─'*70}{Colors.ENDC}")
+             print(f"{Colors.BOLD}Conversation {idx}:{Colors.ENDC}")
+             for conv in conv_list:
+                 print(json.dumps(asdict(conv), indent=2, ensure_ascii=False))
+             print(f"{Colors.DIM}{'─'*70}{Colors.ENDC}\n")
+
+     return results
eval_lib/evaluation_schema.py ADDED
@@ -0,0 +1,63 @@
+ # evaluation_schema.py
+
+ from dataclasses import dataclass
+ from typing import List, Optional, Any, Dict
+
+
+ @dataclass
+ class MetricResult:
+     """
+     Result of one metric evaluation.
+     Paired with TestCaseResult in save_results, in the form (MetricResult, None).
+     - name: str
+     - score: float
+     - threshold: float
+     - success: bool
+     - evaluation_cost: Optional[float]
+     - reason: Optional[str]
+     - evaluation_model: str
+     - evaluation_log: Optional[Any]
+     """
+     name: str
+     score: float
+     threshold: float
+     success: bool
+     evaluation_cost: Optional[float]
+     reason: Optional[str]
+     evaluation_model: str
+     # can be a dict or any other type depending on the metric
+     evaluation_log: Optional[Any] = None
+
+
+ @dataclass
+ class TestCaseResult:
+     """
+     Final result of one test case evaluation.
+     Paired with MetricResult in save_results, in the form (None, [TestCaseResult]).
+     - input: str
+     - actual_output: str
+     - expected_output: Optional[str]
+     - retrieval_context: Optional[List[str]]
+     - success: bool, general success of the test case
+     - metrics_data: List[MetricResult], list of metric results
+     - tools_called: Optional[List[str]]
+     - expected_tools: Optional[List[str]]
+     """
+     input: str
+     actual_output: str
+     expected_output: Optional[str]
+     retrieval_context: Optional[List[str]]
+     success: bool
+     metrics_data: List[MetricResult]
+     tools_called: Optional[List[str]] = None
+     expected_tools: Optional[List[str]] = None
+
+
+ @dataclass
+ class ConversationalTestCaseResult:
+     """
+     Result of a conversational test case evaluation.
+     - dialogue: List[Dict[str, str]], the conversation as alternating user/assistant messages
+     - success: bool, overall success of the conversation
+     - metrics_data: List[MetricResult], list of metric results
+     """
+     dialogue: List[Dict[str, str]]
+     success: bool
+     metrics_data: List[MetricResult]
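
The result types above are plain dataclasses, so they can be built and serialized with dataclasses.asdict exactly the way evaluate() prints them. A short sketch with illustrative values (the metric name mirrors the answer_relevancy metric shipped in eval_lib):

import json
from dataclasses import asdict

from eval_lib.evaluation_schema import MetricResult, TestCaseResult

metric = MetricResult(
    name="answer_relevancy",
    score=0.92,
    threshold=0.7,
    success=True,
    evaluation_cost=0.000135,
    reason="The answer directly addresses the question.",
    evaluation_model="openai:gpt-4o-mini",
)

case = TestCaseResult(
    input="What is the capital of France?",
    actual_output="Paris is the capital of France.",
    expected_output="Paris",
    retrieval_context=["Paris is the capital of France."],
    success=metric.success,
    metrics_data=[metric],
)

# Dataclass -> dict -> JSON, the same path evaluate() uses for its detailed output.
print(json.dumps(asdict(case), indent=2, ensure_ascii=False))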
eval_lib/llm_client.py ADDED
@@ -0,0 +1,286 @@
+ # llm_client.py
+ import openai
+ import functools
+ import anthropic
+ from openai import AsyncAzureOpenAI
+ from google import genai
+ from google.genai.types import GenerateContentConfig
+ import os
+ from enum import Enum
+ from dataclasses import dataclass
+ from typing import Dict, Tuple, Optional
+ from types import SimpleNamespace
+ from .price import model_pricing
+
+
+ class Provider(str, Enum):
+     OPENAI = "openai"
+     AZURE = "azure"
+     GOOGLE = "google"
+     OLLAMA = "ollama"
+     ANTHROPIC = "anthropic"
+
+
+ @dataclass(frozen=True, slots=True)
+ class LLMDescriptor:
+     """'openai:gpt-4o' → provider=openai, model='gpt-4o'"""
+     provider: Provider
+     model: str
+
+     @classmethod
+     def parse(cls, spec: str | Tuple[str, str] | "LLMDescriptor") -> "LLMDescriptor":
+         if isinstance(spec, LLMDescriptor):
+             return spec
+         if isinstance(spec, tuple):
+             provider, model = spec
+             return cls(Provider(provider), model)
+         try:
+             provider, model = spec.split(":", 1)
+         except ValueError:
+             return cls(Provider.OPENAI, spec)
+         return cls(Provider(provider), model)
+
+     def key(self) -> str:
+         """Return a unique key for the LLM descriptor."""
+         return f"{self.provider}:{self.model}"
+
+
+ @functools.cache
+ def _get_client(provider: Provider):
+     if provider == Provider.OPENAI:
+         return openai.AsyncOpenAI()
+
+     if provider == Provider.AZURE:
+         return AsyncAzureOpenAI(
+             api_key=os.getenv("AZURE_OPENAI_API_KEY"),
+             azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
+             api_version=os.getenv("AZURE_OPENAI_API_VERSION", "2024-02-01"),
+         )
+
+     if provider == Provider.GOOGLE:
+         return genai.Client(api_key=os.getenv("GOOGLE_API_KEY"))
+
+     if provider == Provider.OLLAMA:
+         return openai.AsyncOpenAI(
+             api_key=os.getenv("OLLAMA_API_KEY"),
+             base_url=os.getenv("OLLAMA_API_BASE_URL")
+         )
+
+     if provider == Provider.ANTHROPIC:
+         return anthropic.AsyncAnthropic(
+             api_key=os.getenv("ANTHROPIC_API_KEY"),
+         )
+
+     raise ValueError(f"Unsupported provider {provider}")
+
+
+ async def _openai_chat_complete(
+     client,
+     llm: LLMDescriptor,
+     messages: list[dict[str, str]],
+     temperature: float,
+ ):
+     """
+     Plain OpenAI chat completion.
+     """
+     response = await client.chat.completions.create(
+         model=llm.model,
+         messages=messages,
+         temperature=temperature,
+     )
+     text = response.choices[0].message.content.strip()
+     cost = _calculate_cost(llm, response.usage)
+     return text, cost
+
+
+ async def _azure_chat_complete(
+     client,
+     llm: LLMDescriptor,
+     messages: list[dict[str, str]],
+     temperature: float,
+ ):
+
+     deployment_name = os.getenv("AZURE_OPENAI_DEPLOYMENT") or llm.model
+
+     response = await client.chat.completions.create(
+         model=deployment_name,
+         messages=messages,
+         temperature=temperature,
+     )
+     text = response.choices[0].message.content.strip()
+     cost = _calculate_cost(llm, response.usage)
+     return text, cost
+
+
+ async def _google_chat_complete(
+     client,
+     llm: LLMDescriptor,
+     messages: list[dict[str, str]],
+     temperature: float,
+ ):
+     """
+     Google GenAI / Gemini 2.x
+     """
+     prompt = "\n".join(m["content"] for m in messages)
+
+     response = await client.aio.models.generate_content(
+         model=llm.model,
+         contents=prompt,
+         config=GenerateContentConfig(temperature=temperature),
+     )
+
+     text = response.text.strip()
+
+     um = response.usage_metadata
+     usage = SimpleNamespace(
+         prompt_tokens=um.prompt_token_count,
+         completion_tokens=um.candidates_token_count,
+     )
+
+     cost = _calculate_cost(llm, usage)
+     return text, cost
+
+
+ async def _ollama_chat_complete(
+     client,
+     llm: LLMDescriptor,
+     messages: list[dict[str, str]],
+     temperature: float,
+ ):
+     response = await client.chat.completions.create(
+         model=llm.model,
+         messages=messages,
+         temperature=temperature,
+     )
+     text = response.choices[0].message.content.strip()
+     cost = _calculate_cost(llm, response.usage)
+     return text, cost
+
+
+ async def _anthropic_chat_complete(
+     client,
+     llm: LLMDescriptor,
+     messages: list[dict[str, str]],
+     temperature: float,
+ ):
+     """
+     Anthropic Claude chat completion.
+     """
+     response = await client.messages.create(
+         model=llm.model,
+         messages=messages,
+         temperature=temperature,
+         max_tokens=4096,  # Default max tokens for Claude
+     )
+     if isinstance(response.content, list):
+         text = "".join(block.text for block in response.content if block.type == "text").strip()
+     else:
+         text = response.content.strip()
+
+     cost = _calculate_cost(llm, response.usage)
+     return text, cost
+
+
+ _HELPERS = {
+     Provider.OPENAI: _openai_chat_complete,
+     Provider.AZURE: _azure_chat_complete,
+     Provider.GOOGLE: _google_chat_complete,
+     Provider.OLLAMA: _ollama_chat_complete,
+     Provider.ANTHROPIC: _anthropic_chat_complete,
+ }
+
+
+ async def chat_complete(
+     llm: str | tuple[str, str] | LLMDescriptor,
+     messages: list[dict[str, str]],
+     temperature: float = 0.0,
+ ):
+     llm = LLMDescriptor.parse(llm)
+     helper = _HELPERS.get(llm.provider)
+
+     if helper is None:
+         raise ValueError(f"Unsupported provider {llm.provider}")
+
+     client = _get_client(llm.provider)
+     return await helper(client, llm, messages, temperature)
+
+
+ def _calculate_cost(llm: LLMDescriptor, usage) -> Optional[float]:
+     """
+     Calculate the cost of the LLM usage based on the model and usage data.
+     """
+     if llm.provider == Provider.OLLAMA:
+         return 0.0
+     if not usage:
+         return None
+
+     price = model_pricing.get(llm.model)
+     if not price:
+         return None
+
+     prompt = getattr(usage, "prompt_tokens", 0)
+     completion = getattr(usage, "completion_tokens", 0)
+
+     return round(
+         prompt * price["input"] / 1_000_000 +
+         completion * price["output"] / 1_000_000,
+         6
+     )
+
+
+ async def get_embeddings(
+     model: str | tuple[str, str] | LLMDescriptor,
+     texts: list[str],
+ ) -> tuple[list[list[float]], Optional[float]]:
+     """
+     Get embeddings for a list of texts using OpenAI models.
+
+     Args:
+         model: Model specification (e.g., "openai:text-embedding-3-small")
+         texts: List of texts to embed
+
+     Returns:
+         Tuple of (embeddings_list, total_cost)
+     """
+     llm = LLMDescriptor.parse(model)
+
+     if llm.provider != Provider.OPENAI:
+         raise ValueError(f"Only OpenAI embedding models are supported, got {llm.provider}")
+
+     client = _get_client(llm.provider)
+     return await _openai_get_embeddings(client, llm, texts)
+
+
+ async def _openai_get_embeddings(
+     client,
+     llm: LLMDescriptor,
+     texts: list[str],
+ ) -> tuple[list[list[float]], Optional[float]]:
+     """OpenAI embeddings implementation."""
+     response = await client.embeddings.create(
+         model=llm.model,
+         input=texts,
+         encoding_format="float"
+     )
+
+     embeddings = [data.embedding for data in response.data]
+     cost = _calculate_embedding_cost(llm, response.usage)
+
+     return embeddings, cost
+
+
+ def _calculate_embedding_cost(llm: LLMDescriptor, usage) -> Optional[float]:
+     """Calculate the cost of embedding usage for OpenAI models."""
+     if not usage:
+         return None
+
+     price = model_pricing.get(llm.model)
+     if not price:
+         return None
+
+     total_tokens = getattr(usage, 'total_tokens', 0)
+     input_price = price.get("input", 0)
+
+     return round(total_tokens * input_price / 1_000_000, 6)
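
A minimal sketch of calling this client layer directly follows. It assumes OPENAI_API_KEY is set in the environment (openai.AsyncOpenAI() picks it up implicitly), and the returned cost is None whenever the model is missing from the pricing table in eval_lib/price.py.

import asyncio

from eval_lib.llm_client import LLMDescriptor, chat_complete, get_embeddings


async def main():
    # "provider:model" specs are parsed by LLMDescriptor.parse; a bare model
    # name falls back to the OpenAI provider.
    desc = LLMDescriptor.parse("openai:gpt-4o-mini")  # provider=OPENAI, model="gpt-4o-mini"

    text, cost = await chat_complete(
        desc,
        messages=[{"role": "user", "content": "Say hello in one word."}],
        temperature=0.0,
    )
    print(text, cost)

    # Embeddings are OpenAI-only in this version.
    vectors, emb_cost = await get_embeddings("openai:text-embedding-3-small", ["hello world"])
    print(len(vectors), len(vectors[0]), emb_cost)


asyncio.run(main())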