sigma-terminal 2.0.1__py3-none-any.whl → 3.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sigma/core/llm.py DELETED
@@ -1,794 +0,0 @@
-"""LLM provider implementations."""
-
-import json
-from abc import ABC, abstractmethod
-from typing import Any, AsyncIterator, Optional
-
-import httpx
-
-from sigma.core.config import LLMProvider, get_settings
-from sigma.core.models import Message, MessageRole, ToolCall
-
-
-class BaseLLM(ABC):
-    """Base LLM provider."""
-
-    def __init__(self, model: Optional[str] = None):
-        self.settings = get_settings()
-        self.model = model or self.settings.get_model(self.provider)
-
-    @property
-    @abstractmethod
-    def provider(self) -> LLMProvider:
-        """Provider type."""
-        pass
-
-    @abstractmethod
-    async def generate(
-        self,
-        messages: list[Message],
-        tools: Optional[list[dict]] = None,
-        temperature: Optional[float] = None,
-    ) -> tuple[str, list[ToolCall]]:
-        """Generate response."""
-        pass
-
-    @abstractmethod
-    async def stream(
-        self,
-        messages: list[Message],
-        tools: Optional[list[dict]] = None,
-        temperature: Optional[float] = None,
-    ) -> AsyncIterator[str]:
-        """Stream response."""
-        pass
-
-
-class OpenAILLM(BaseLLM):
-    """OpenAI provider."""
-
-    @property
-    def provider(self) -> LLMProvider:
-        return LLMProvider.OPENAI
-
-    def _convert_messages(self, messages: list[Message]) -> list[dict]:
-        """Convert to OpenAI format."""
-        result = []
-        for msg in messages:
-            m: dict[str, Any] = {"role": msg.role.value, "content": msg.content}
-            if msg.tool_calls:
-                m["tool_calls"] = [
-                    {
-                        "id": tc.id,
-                        "type": "function",
-                        "function": {"name": tc.name, "arguments": json.dumps(tc.arguments)}
-                    }
-                    for tc in msg.tool_calls
-                ]
-            if msg.tool_call_id:
-                m["tool_call_id"] = msg.tool_call_id
-            if msg.name:
-                m["name"] = msg.name
-            result.append(m)
-        return result
-
-    async def generate(
-        self,
-        messages: list[Message],
-        tools: Optional[list[dict]] = None,
-        temperature: Optional[float] = None,
-    ) -> tuple[str, list[ToolCall]]:
-        api_key = self.settings.get_api_key(LLMProvider.OPENAI)
-        if not api_key:
-            raise ValueError("OpenAI API key not set")
-
-        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
-        data: dict[str, Any] = {
-            "model": self.model,
-            "messages": self._convert_messages(messages),
-            "temperature": temperature or self.settings.temperature,
-            "max_tokens": self.settings.max_tokens,
-        }
-
-        if tools:
-            data["tools"] = [{"type": "function", "function": t} for t in tools]
-            data["tool_choice"] = "auto"
-
-        async with httpx.AsyncClient(timeout=120.0) as client:
-            resp = await client.post(
-                "https://api.openai.com/v1/chat/completions",
-                headers=headers,
-                json=data,
-            )
-            resp.raise_for_status()
-            result = resp.json()
-
-        choice = result["choices"][0]
-        msg = choice["message"]
-        content = msg.get("content", "") or ""
-
-        tool_calls = []
-        if "tool_calls" in msg and msg["tool_calls"]:
-            for tc in msg["tool_calls"]:
-                tool_calls.append(ToolCall(
-                    id=tc["id"],
-                    name=tc["function"]["name"],
-                    arguments=json.loads(tc["function"]["arguments"]),
-                ))
-
-        return content, tool_calls
-
-    async def stream(
-        self,
-        messages: list[Message],
-        tools: Optional[list[dict]] = None,
-        temperature: Optional[float] = None,
-    ) -> AsyncIterator[str]:
-        api_key = self.settings.get_api_key(LLMProvider.OPENAI)
-        if not api_key:
-            raise ValueError("OpenAI API key not set")
-
-        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
-        data: dict[str, Any] = {
-            "model": self.model,
-            "messages": self._convert_messages(messages),
-            "temperature": temperature or self.settings.temperature,
-            "max_tokens": self.settings.max_tokens,
-            "stream": True,
-        }
-
-        async with httpx.AsyncClient(timeout=120.0) as client:
-            async with client.stream(
-                "POST",
-                "https://api.openai.com/v1/chat/completions",
-                headers=headers,
-                json=data,
-            ) as resp:
-                async for line in resp.aiter_lines():
-                    if line.startswith("data: "):
-                        payload = line[6:]
-                        if payload == "[DONE]":
-                            break
-                        chunk = json.loads(payload)
-                        delta = chunk["choices"][0].get("delta", {})
-                        if "content" in delta and delta["content"]:
-                            yield delta["content"]
-
-
-class AnthropicLLM(BaseLLM):
-    """Anthropic provider."""
-
-    @property
-    def provider(self) -> LLMProvider:
-        return LLMProvider.ANTHROPIC
-
-    def _convert_messages(self, messages: list[Message]) -> tuple[Optional[str], list[dict]]:
-        """Convert to Anthropic format."""
-        system = None
-        result = []
-
-        for msg in messages:
-            if msg.role == MessageRole.SYSTEM:
-                system = msg.content
-                continue
-
-            if msg.role == MessageRole.TOOL:
-                result.append({
-                    "role": "user",
-                    "content": [{
-                        "type": "tool_result",
-                        "tool_use_id": msg.tool_call_id,
-                        "content": msg.content,
-                    }]
-                })
-            elif msg.tool_calls:
-                content: list[dict[str, Any]] = []
-                if msg.content:
-                    content.append({"type": "text", "text": msg.content})
-                for tc in msg.tool_calls:
-                    content.append({
-                        "type": "tool_use",
-                        "id": tc.id,
-                        "name": tc.name,
-                        "input": tc.arguments,
-                    })
-                result.append({"role": "assistant", "content": content})
-            else:
-                result.append({"role": msg.role.value, "content": msg.content})
-
-        return system, result
-
-    async def generate(
-        self,
-        messages: list[Message],
-        tools: Optional[list[dict]] = None,
-        temperature: Optional[float] = None,
-    ) -> tuple[str, list[ToolCall]]:
-        api_key = self.settings.get_api_key(LLMProvider.ANTHROPIC)
-        if not api_key:
-            raise ValueError("Anthropic API key not set")
-
-        headers = {
-            "x-api-key": api_key,
-            "Content-Type": "application/json",
-            "anthropic-version": "2023-06-01",
-        }
-
-        system, msgs = self._convert_messages(messages)
-        data: dict[str, Any] = {
-            "model": self.model,
-            "messages": msgs,
-            "max_tokens": self.settings.max_tokens,
-            "temperature": temperature or self.settings.temperature,
-        }
-
-        if system:
-            data["system"] = system
-
-        if tools:
-            data["tools"] = [
-                {
-                    "name": t["name"],
-                    "description": t.get("description", ""),
-                    "input_schema": t.get("parameters", {}),
-                }
-                for t in tools
-            ]
-
-        async with httpx.AsyncClient(timeout=120.0) as client:
-            resp = await client.post(
-                "https://api.anthropic.com/v1/messages",
-                headers=headers,
-                json=data,
-            )
-            resp.raise_for_status()
-            result = resp.json()
-
-        content = ""
-        tool_calls = []
-
-        for block in result.get("content", []):
-            if block["type"] == "text":
-                content += block["text"]
-            elif block["type"] == "tool_use":
-                tool_calls.append(ToolCall(
-                    id=block["id"],
-                    name=block["name"],
-                    arguments=block["input"],
-                ))
-
-        return content, tool_calls
-
-    async def stream(
-        self,
-        messages: list[Message],
-        tools: Optional[list[dict]] = None,
-        temperature: Optional[float] = None,
-    ) -> AsyncIterator[str]:
-        api_key = self.settings.get_api_key(LLMProvider.ANTHROPIC)
-        if not api_key:
-            raise ValueError("Anthropic API key not set")
-
-        headers = {
-            "x-api-key": api_key,
-            "Content-Type": "application/json",
-            "anthropic-version": "2023-06-01",
-        }
-
-        system, msgs = self._convert_messages(messages)
-        data: dict[str, Any] = {
-            "model": self.model,
-            "messages": msgs,
-            "max_tokens": self.settings.max_tokens,
-            "temperature": temperature or self.settings.temperature,
-            "stream": True,
-        }
-
-        if system:
-            data["system"] = system
-
-        async with httpx.AsyncClient(timeout=120.0) as client:
-            async with client.stream(
-                "POST",
-                "https://api.anthropic.com/v1/messages",
-                headers=headers,
-                json=data,
-            ) as resp:
-                async for line in resp.aiter_lines():
-                    if line.startswith("data: "):
-                        event = json.loads(line[6:])
-                        if event["type"] == "content_block_delta":
-                            delta = event.get("delta", {})
-                            if delta.get("type") == "text_delta":
-                                yield delta.get("text", "")
-
-
-class GoogleLLM(BaseLLM):
-    """Google Gemini provider using REST API."""
-
-    @property
-    def provider(self) -> LLMProvider:
-        return LLMProvider.GOOGLE
-
-    def _convert_messages(self, messages: list[Message]) -> tuple[Optional[str], list[dict]]:
-        """Convert to Gemini format."""
-        system = None
-        contents = []
-
-        for msg in messages:
-            if msg.role == MessageRole.SYSTEM:
-                system = msg.content
-                continue
-
-            role = "model" if msg.role == MessageRole.ASSISTANT else "user"
-
-            if msg.role == MessageRole.TOOL:
-                contents.append({
-                    "role": "user",
-                    "parts": [{
-                        "functionResponse": {
-                            "name": msg.name or "tool",
-                            "response": {"result": msg.content}
-                        }
-                    }]
-                })
-            elif msg.tool_calls:
-                parts: list[dict[str, Any]] = []
-                if msg.content:
-                    parts.append({"text": msg.content})
-                for tc in msg.tool_calls:
-                    parts.append({
-                        "functionCall": {
-                            "name": tc.name,
-                            "args": tc.arguments
-                        }
-                    })
-                contents.append({"role": role, "parts": parts})
-            else:
-                contents.append({
-                    "role": role,
-                    "parts": [{"text": msg.content}]
-                })
-
-        return system, contents
-
-    def _convert_tools(self, tools: list[dict]) -> list[dict]:
-        """Convert tools to Gemini format."""
-        declarations = []
-        for t in tools:
-            decl: dict[str, Any] = {
-                "name": t["name"],
-                "description": t.get("description", ""),
-            }
-            if "parameters" in t and t["parameters"]:
-                params = t["parameters"].copy()
-                # Gemini doesn't want 'additionalProperties'
-                params.pop("additionalProperties", None)
-                decl["parameters"] = params
-            declarations.append(decl)
-        return [{"functionDeclarations": declarations}]
-
-    async def generate(
-        self,
-        messages: list[Message],
-        tools: Optional[list[dict]] = None,
-        temperature: Optional[float] = None,
-    ) -> tuple[str, list[ToolCall]]:
-        api_key = self.settings.get_api_key(LLMProvider.GOOGLE)
-        if not api_key:
-            raise ValueError("Google API key not set")
-
-        system, contents = self._convert_messages(messages)
-
-        data: dict[str, Any] = {
-            "contents": contents,
-            "generationConfig": {
-                "temperature": temperature or self.settings.temperature,
-                "maxOutputTokens": self.settings.max_tokens,
-            }
-        }
-
-        if system:
-            data["systemInstruction"] = {"parts": [{"text": system}]}
-
-        if tools:
-            data["tools"] = self._convert_tools(tools)
-
-        url = f"https://generativelanguage.googleapis.com/v1beta/models/{self.model}:generateContent?key={api_key}"
-
-        async with httpx.AsyncClient(timeout=120.0) as client:
-            resp = await client.post(url, json=data)
-            resp.raise_for_status()
-            result = resp.json()
-
-        content = ""
-        tool_calls = []
-
-        candidates = result.get("candidates", [])
-        if candidates:
-            parts = candidates[0].get("content", {}).get("parts", [])
-            for i, part in enumerate(parts):
-                if "text" in part:
-                    content += part["text"]
-                elif "functionCall" in part:
-                    fc = part["functionCall"]
-                    tool_calls.append(ToolCall(
-                        id=f"call_{i}",
-                        name=fc["name"],
-                        arguments=fc.get("args", {}),
-                    ))
-
-        return content, tool_calls
-
-    async def stream(
-        self,
-        messages: list[Message],
-        tools: Optional[list[dict]] = None,
-        temperature: Optional[float] = None,
-    ) -> AsyncIterator[str]:
-        api_key = self.settings.get_api_key(LLMProvider.GOOGLE)
-        if not api_key:
-            raise ValueError("Google API key not set")
-
-        system, contents = self._convert_messages(messages)
-
-        data: dict[str, Any] = {
-            "contents": contents,
-            "generationConfig": {
-                "temperature": temperature or self.settings.temperature,
-                "maxOutputTokens": self.settings.max_tokens,
-            }
-        }
-
-        if system:
-            data["systemInstruction"] = {"parts": [{"text": system}]}
-
-        url = f"https://generativelanguage.googleapis.com/v1beta/models/{self.model}:streamGenerateContent?key={api_key}&alt=sse"
-
-        async with httpx.AsyncClient(timeout=120.0) as client:
-            async with client.stream("POST", url, json=data) as resp:
-                async for line in resp.aiter_lines():
-                    if line.startswith("data: "):
-                        chunk = json.loads(line[6:])
-                        candidates = chunk.get("candidates", [])
-                        if candidates:
-                            parts = candidates[0].get("content", {}).get("parts", [])
-                            for part in parts:
-                                if "text" in part:
-                                    yield part["text"]
-
-
-class OllamaLLM(BaseLLM):
-    """Ollama provider."""
-
-    @property
-    def provider(self) -> LLMProvider:
-        return LLMProvider.OLLAMA
-
-    def _convert_messages(self, messages: list[Message]) -> list[dict]:
-        """Convert to Ollama format."""
-        result = []
-        for msg in messages:
-            m: dict[str, Any] = {"role": msg.role.value, "content": msg.content}
-            if msg.tool_calls:
-                m["tool_calls"] = [
-                    {
-                        "id": tc.id,
-                        "type": "function",
-                        "function": {"name": tc.name, "arguments": tc.arguments}
-                    }
-                    for tc in msg.tool_calls
-                ]
-            result.append(m)
-        return result
-
-    async def generate(
-        self,
-        messages: list[Message],
-        tools: Optional[list[dict]] = None,
-        temperature: Optional[float] = None,
-    ) -> tuple[str, list[ToolCall]]:
-        data: dict[str, Any] = {
-            "model": self.model,
-            "messages": self._convert_messages(messages),
-            "stream": False,
-            "options": {
-                "temperature": temperature or self.settings.temperature,
-                "num_predict": self.settings.max_tokens,
-            }
-        }
-
-        if tools:
-            data["tools"] = [{"type": "function", "function": t} for t in tools]
-
-        async with httpx.AsyncClient(timeout=300.0) as client:
-            resp = await client.post(
-                f"{self.settings.ollama_base_url}/api/chat",
-                json=data,
-            )
-            resp.raise_for_status()
-            result = resp.json()
-
-        msg = result.get("message", {})
-        content = msg.get("content", "") or ""
-
-        tool_calls = []
-        if "tool_calls" in msg and msg["tool_calls"]:
-            for i, tc in enumerate(msg["tool_calls"]):
-                fn = tc.get("function", {})
-                tool_calls.append(ToolCall(
-                    id=f"call_{i}",
-                    name=fn.get("name", ""),
-                    arguments=fn.get("arguments", {}),
-                ))
-
-        return content, tool_calls
-
-    async def stream(
-        self,
-        messages: list[Message],
-        tools: Optional[list[dict]] = None,
-        temperature: Optional[float] = None,
-    ) -> AsyncIterator[str]:
-        data: dict[str, Any] = {
-            "model": self.model,
-            "messages": self._convert_messages(messages),
-            "stream": True,
-            "options": {
-                "temperature": temperature or self.settings.temperature,
-                "num_predict": self.settings.max_tokens,
-            }
-        }
-
-        async with httpx.AsyncClient(timeout=300.0) as client:
-            async with client.stream(
-                "POST",
-                f"{self.settings.ollama_base_url}/api/chat",
-                json=data,
-            ) as resp:
-                async for line in resp.aiter_lines():
-                    if line:
-                        chunk = json.loads(line)
-                        msg = chunk.get("message", {})
-                        if "content" in msg:
-                            yield msg["content"]
-
-
-class GroqLLM(BaseLLM):
-    """Groq provider."""
-
-    @property
-    def provider(self) -> LLMProvider:
-        return LLMProvider.GROQ
-
-    def _convert_messages(self, messages: list[Message]) -> list[dict]:
-        """Convert to Groq format (OpenAI compatible)."""
-        result = []
-        for msg in messages:
-            m: dict[str, Any] = {"role": msg.role.value, "content": msg.content}
-            if msg.tool_calls:
-                m["tool_calls"] = [
-                    {
-                        "id": tc.id,
-                        "type": "function",
-                        "function": {"name": tc.name, "arguments": json.dumps(tc.arguments)}
-                    }
-                    for tc in msg.tool_calls
-                ]
-            if msg.tool_call_id:
-                m["tool_call_id"] = msg.tool_call_id
-            if msg.name:
-                m["name"] = msg.name
-            result.append(m)
-        return result
-
-    async def generate(
-        self,
-        messages: list[Message],
-        tools: Optional[list[dict]] = None,
-        temperature: Optional[float] = None,
-    ) -> tuple[str, list[ToolCall]]:
-        api_key = self.settings.get_api_key(LLMProvider.GROQ)
-        if not api_key:
-            raise ValueError("Groq API key not set")
-
-        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
-        data: dict[str, Any] = {
-            "model": self.model,
-            "messages": self._convert_messages(messages),
-            "temperature": temperature or self.settings.temperature,
-            "max_tokens": self.settings.max_tokens,
-        }
-
-        if tools:
-            data["tools"] = [{"type": "function", "function": t} for t in tools]
-            data["tool_choice"] = "auto"
-
-        async with httpx.AsyncClient(timeout=120.0) as client:
-            resp = await client.post(
-                "https://api.groq.com/openai/v1/chat/completions",
-                headers=headers,
-                json=data,
-            )
-            resp.raise_for_status()
-            result = resp.json()
-
-        choice = result["choices"][0]
-        msg = choice["message"]
-        content = msg.get("content", "") or ""
-
-        tool_calls = []
-        if "tool_calls" in msg and msg["tool_calls"]:
-            for tc in msg["tool_calls"]:
-                tool_calls.append(ToolCall(
-                    id=tc["id"],
-                    name=tc["function"]["name"],
-                    arguments=json.loads(tc["function"]["arguments"]),
-                ))
-
-        return content, tool_calls
-
-    async def stream(
-        self,
-        messages: list[Message],
-        tools: Optional[list[dict]] = None,
-        temperature: Optional[float] = None,
-    ) -> AsyncIterator[str]:
-        api_key = self.settings.get_api_key(LLMProvider.GROQ)
-        if not api_key:
-            raise ValueError("Groq API key not set")
-
-        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
-        data: dict[str, Any] = {
-            "model": self.model,
-            "messages": self._convert_messages(messages),
-            "temperature": temperature or self.settings.temperature,
-            "max_tokens": self.settings.max_tokens,
-            "stream": True,
-        }
-
-        async with httpx.AsyncClient(timeout=120.0) as client:
-            async with client.stream(
-                "POST",
-                "https://api.groq.com/openai/v1/chat/completions",
-                headers=headers,
-                json=data,
-            ) as resp:
-                async for line in resp.aiter_lines():
-                    if line.startswith("data: "):
-                        payload = line[6:]
-                        if payload == "[DONE]":
-                            break
-                        chunk = json.loads(payload)
-                        delta = chunk["choices"][0].get("delta", {})
-                        if "content" in delta and delta["content"]:
-                            yield delta["content"]
-
-
-class XaiLLM(BaseLLM):
-    """xAI Grok provider."""
-
-    @property
-    def provider(self) -> LLMProvider:
-        return LLMProvider.XAI
-
-    def _convert_messages(self, messages: list[Message]) -> list[dict]:
-        """Convert to xAI format (OpenAI compatible)."""
-        result = []
-        for msg in messages:
-            m: dict[str, Any] = {"role": msg.role.value, "content": msg.content}
-            if msg.tool_calls:
-                m["tool_calls"] = [
-                    {
-                        "id": tc.id,
-                        "type": "function",
-                        "function": {"name": tc.name, "arguments": json.dumps(tc.arguments)}
-                    }
-                    for tc in msg.tool_calls
-                ]
-            if msg.tool_call_id:
-                m["tool_call_id"] = msg.tool_call_id
-            result.append(m)
-        return result
-
-    async def generate(
-        self,
-        messages: list[Message],
-        tools: Optional[list[dict]] = None,
-        temperature: Optional[float] = None,
-    ) -> tuple[str, list[ToolCall]]:
-        api_key = self.settings.get_api_key(LLMProvider.XAI)
-        if not api_key:
-            raise ValueError("xAI API key not set")
-
-        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
-        data: dict[str, Any] = {
-            "model": self.model,
-            "messages": self._convert_messages(messages),
-            "temperature": temperature or self.settings.temperature,
-            "max_tokens": self.settings.max_tokens,
-        }
-
-        if tools:
-            data["tools"] = [{"type": "function", "function": t} for t in tools]
-
-        async with httpx.AsyncClient(timeout=120.0) as client:
-            resp = await client.post(
-                "https://api.x.ai/v1/chat/completions",
-                headers=headers,
-                json=data,
-            )
-            resp.raise_for_status()
-            result = resp.json()
-
-        choice = result["choices"][0]
-        msg = choice["message"]
-        content = msg.get("content", "") or ""
-
-        tool_calls = []
-        if "tool_calls" in msg and msg["tool_calls"]:
-            for tc in msg["tool_calls"]:
-                tool_calls.append(ToolCall(
-                    id=tc["id"],
-                    name=tc["function"]["name"],
-                    arguments=json.loads(tc["function"]["arguments"]),
-                ))
-
-        return content, tool_calls
-
-    async def stream(
-        self,
-        messages: list[Message],
-        tools: Optional[list[dict]] = None,
-        temperature: Optional[float] = None,
-    ) -> AsyncIterator[str]:
-        api_key = self.settings.get_api_key(LLMProvider.XAI)
-        if not api_key:
-            raise ValueError("xAI API key not set")
-
-        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
-        data: dict[str, Any] = {
-            "model": self.model,
-            "messages": self._convert_messages(messages),
-            "temperature": temperature or self.settings.temperature,
-            "max_tokens": self.settings.max_tokens,
-            "stream": True,
-        }
-
-        async with httpx.AsyncClient(timeout=120.0) as client:
-            async with client.stream(
-                "POST",
-                "https://api.x.ai/v1/chat/completions",
-                headers=headers,
-                json=data,
-            ) as resp:
-                async for line in resp.aiter_lines():
-                    if line.startswith("data: "):
-                        payload = line[6:]
-                        if payload == "[DONE]":
-                            break
-                        chunk = json.loads(payload)
-                        delta = chunk["choices"][0].get("delta", {})
-                        if "content" in delta and delta["content"]:
-                            yield delta["content"]
-
-
-def get_llm(provider: Optional[LLMProvider] = None, model: Optional[str] = None) -> BaseLLM:
-    """Get LLM instance."""
-    settings = get_settings()
-    provider = provider or settings.default_provider
-
-    providers = {
-        LLMProvider.OPENAI: OpenAILLM,
-        LLMProvider.ANTHROPIC: AnthropicLLM,
-        LLMProvider.GOOGLE: GoogleLLM,
-        LLMProvider.OLLAMA: OllamaLLM,
-        LLMProvider.GROQ: GroqLLM,
-        LLMProvider.XAI: XaiLLM,
-    }
-
-    cls = providers.get(provider)
-    if not cls:
-        raise ValueError(f"Unknown provider: {provider}")
-
-    return cls(model=model)