kader-0.1.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/README.md +169 -0
- cli/__init__.py +5 -0
- cli/__main__.py +6 -0
- cli/app.py +707 -0
- cli/app.tcss +664 -0
- cli/utils.py +68 -0
- cli/widgets/__init__.py +13 -0
- cli/widgets/confirmation.py +309 -0
- cli/widgets/conversation.py +55 -0
- cli/widgets/loading.py +59 -0
- kader/__init__.py +22 -0
- kader/agent/__init__.py +8 -0
- kader/agent/agents.py +126 -0
- kader/agent/base.py +927 -0
- kader/agent/logger.py +170 -0
- kader/config.py +139 -0
- kader/memory/__init__.py +66 -0
- kader/memory/conversation.py +409 -0
- kader/memory/session.py +385 -0
- kader/memory/state.py +211 -0
- kader/memory/types.py +116 -0
- kader/prompts/__init__.py +9 -0
- kader/prompts/agent_prompts.py +27 -0
- kader/prompts/base.py +81 -0
- kader/prompts/templates/planning_agent.j2 +26 -0
- kader/prompts/templates/react_agent.j2 +18 -0
- kader/providers/__init__.py +9 -0
- kader/providers/base.py +581 -0
- kader/providers/mock.py +96 -0
- kader/providers/ollama.py +447 -0
- kader/tools/README.md +483 -0
- kader/tools/__init__.py +130 -0
- kader/tools/base.py +955 -0
- kader/tools/exec_commands.py +249 -0
- kader/tools/filesys.py +650 -0
- kader/tools/filesystem.py +607 -0
- kader/tools/protocol.py +456 -0
- kader/tools/rag.py +555 -0
- kader/tools/todo.py +210 -0
- kader/tools/utils.py +456 -0
- kader/tools/web.py +246 -0
- kader-0.1.5.dist-info/METADATA +321 -0
- kader-0.1.5.dist-info/RECORD +45 -0
- kader-0.1.5.dist-info/WHEEL +4 -0
- kader-0.1.5.dist-info/entry_points.txt +2 -0
kader/providers/base.py
ADDED
@@ -0,0 +1,581 @@
"""
Base class for LLM Providers.

A versatile, provider-agnostic base class for LLM interactions supporting
OpenAI, Google, Anthropic, Mistral, and other providers.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import (
    Any,
    AsyncIterator,
    Iterator,
    Literal,
    TypeAlias,
)

# Type Aliases
Role: TypeAlias = Literal["system", "user", "assistant", "tool"]
FinishReason: TypeAlias = Literal[
    "stop", "length", "tool_calls", "content_filter", "error", None
]


class MessageRole(str, Enum):
    """Enumeration of message roles for type safety."""

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"


@dataclass
class Message:
    """Represents a chat message in a conversation."""

    role: Role
    content: str
    name: str | None = None
    tool_call_id: str | None = None
    tool_calls: list[dict[str, Any]] | None = None

    def to_dict(self) -> dict[str, Any]:
        """Convert message to dictionary format for API calls."""
        data: dict[str, Any] = {
            "role": self.role,
            "content": self.content,
        }
        if self.name:
            data["name"] = self.name
        if self.tool_call_id:
            data["tool_call_id"] = self.tool_call_id
        if self.tool_calls:
            data["tool_calls"] = self.tool_calls
        return data

    @classmethod
    def system(cls, content: str) -> "Message":
        """Create a system message."""
        return cls(role="system", content=content)

    @classmethod
    def user(cls, content: str) -> "Message":
        """Create a user message."""
        return cls(role="user", content=content)

    @classmethod
    def assistant(cls, content: str) -> "Message":
        """Create an assistant message."""
        return cls(role="assistant", content=content)

    @classmethod
    def tool(cls, tool_call_id: str, content: str) -> "Message":
        """Create a tool message."""
        return cls(role="tool", tool_call_id=tool_call_id, content=content)


@dataclass
class Usage:
    """Tracks token usage for an LLM request."""

    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0

    # Additional usage details (provider-specific)
    cached_tokens: int = 0
    reasoning_tokens: int = 0

    def __post_init__(self) -> None:
        """Calculate total tokens if not provided."""
        if self.total_tokens == 0:
            self.total_tokens = self.prompt_tokens + self.completion_tokens

    def __add__(self, other: "Usage") -> "Usage":
        """Add two Usage instances together."""
        return Usage(
            prompt_tokens=self.prompt_tokens + other.prompt_tokens,
            completion_tokens=self.completion_tokens + other.completion_tokens,
            total_tokens=self.total_tokens + other.total_tokens,
            cached_tokens=self.cached_tokens + other.cached_tokens,
            reasoning_tokens=self.reasoning_tokens + other.reasoning_tokens,
        )


@dataclass
class CostInfo:
    """Cost breakdown for an LLM request."""

    input_cost: float = 0.0
    output_cost: float = 0.0
    total_cost: float = 0.0
    currency: str = "USD"

    # Additional cost details
    cached_input_cost: float = 0.0

    def __post_init__(self) -> None:
        """Calculate total cost if not provided."""
        if self.total_cost == 0.0:
            self.total_cost = self.input_cost + self.output_cost

    def __add__(self, other: "CostInfo") -> "CostInfo":
        """Add two CostInfo instances together."""
        if self.currency != other.currency:
            raise ValueError(
                f"Cannot add costs with different currencies: {self.currency} vs {other.currency}"
            )
        return CostInfo(
            input_cost=self.input_cost + other.input_cost,
            output_cost=self.output_cost + other.output_cost,
            total_cost=self.total_cost + other.total_cost,
            currency=self.currency,
            cached_input_cost=self.cached_input_cost + other.cached_input_cost,
        )

    def format(self, precision: int = 6) -> str:
        """Format cost as a readable string."""
        return f"${self.total_cost:.{precision}f} {self.currency}"


@dataclass
class ModelConfig:
    """Configuration for model inference parameters."""

    # Core parameters
    temperature: float = 1.0
    max_tokens: int | None = None
    top_p: float = 1.0

    # Sampling parameters
    top_k: int | None = None
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0

    # Stop sequences
    stop_sequences: list[str] | None = None

    # Streaming
    stream: bool = False

    # Tool/Function calling
    tools: list[dict[str, Any]] | None = None
    tool_choice: str | dict[str, Any] | None = None

    # Response format
    response_format: dict[str, Any] | None = None

    # Seed for reproducibility
    seed: int | None = None

    # Additional provider-specific parameters
    extra: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Convert config to dictionary, excluding None values."""
        data: dict[str, Any] = {}

        if self.temperature != 1.0:
            data["temperature"] = self.temperature
        if self.max_tokens is not None:
            data["max_tokens"] = self.max_tokens
        if self.top_p != 1.0:
            data["top_p"] = self.top_p
        if self.top_k is not None:
            data["top_k"] = self.top_k
        if self.frequency_penalty != 0.0:
            data["frequency_penalty"] = self.frequency_penalty
        if self.presence_penalty != 0.0:
            data["presence_penalty"] = self.presence_penalty
        if self.stop_sequences:
            data["stop"] = self.stop_sequences
        if self.tools:
            data["tools"] = self.tools
        if self.tool_choice is not None:
            data["tool_choice"] = self.tool_choice
        if self.response_format is not None:
            data["response_format"] = self.response_format
        if self.seed is not None:
            data["seed"] = self.seed

        # Merge extra parameters
        data.update(self.extra)

        return data


@dataclass
class LLMResponse:
    """Complete response from an LLM provider."""

    content: str
    model: str
    usage: Usage
    finish_reason: FinishReason = None

    # Cost information (optional, calculated if pricing is available)
    cost: CostInfo | None = None

    # Tool calls (if any)
    tool_calls: list[dict[str, Any]] | None = None

    # Raw response from provider (for debugging/extension)
    raw_response: Any = None

    # Additional metadata
    id: str | None = None
    created: int | None = None

    @property
    def has_tool_calls(self) -> bool:
        """Check if response contains tool calls."""
        return self.tool_calls is not None and len(self.tool_calls) > 0

    def to_message(self) -> Message:
        """Convert response to an assistant message."""
        return Message(
            role="assistant",
            content=self.content,
            tool_calls=self.tool_calls,
        )


@dataclass
class StreamChunk:
    """A chunk from a streaming LLM response."""

    content: str = ""
    delta: str = ""
    finish_reason: FinishReason = None

    # Partial usage (available at end of stream for some providers)
    usage: Usage | None = None

    # Tool call deltas
    tool_calls: list[dict[str, Any]] | None = None

    # Index of this chunk in the stream
    index: int = 0

    @property
    def is_final(self) -> bool:
        """Check if this is the final chunk."""
        return self.finish_reason is not None


@dataclass
class ModelPricing:
    """Pricing information for a model."""

    input_cost_per_million: float  # Cost per million input tokens
    output_cost_per_million: float  # Cost per million output tokens
    cached_input_cost_per_million: float | None = (
        None  # Cached input cost (if supported)
    )

    def calculate_cost(self, usage: Usage) -> CostInfo:
        """Calculate cost from usage."""
        input_cost = (usage.prompt_tokens / 1_000_000) * self.input_cost_per_million
        output_cost = (
            usage.completion_tokens / 1_000_000
        ) * self.output_cost_per_million

        cached_cost = 0.0
        if self.cached_input_cost_per_million and usage.cached_tokens > 0:
            cached_cost = (
                usage.cached_tokens / 1_000_000
            ) * self.cached_input_cost_per_million

        return CostInfo(
            input_cost=input_cost,
            output_cost=output_cost,
            cached_input_cost=cached_cost,
        )


@dataclass
class ModelInfo:
    """Information about an LLM model."""

    name: str
    provider: str
    context_window: int
    max_output_tokens: int | None = None
    pricing: ModelPricing | None = None
    supports_vision: bool = False
    supports_tools: bool = False
    supports_json_mode: bool = False
    supports_streaming: bool = True

    # Additional capabilities
    capabilities: dict[str, Any] = field(default_factory=dict)


class BaseLLMProvider(ABC):
    """
    Abstract base class for LLM providers.

    Provides a unified interface for interacting with various LLM providers
    including OpenAI, Google, Anthropic, Mistral, and others.

    Subclasses must implement:
    - invoke: Synchronous single completion
    - ainvoke: Asynchronous single completion
    - stream: Synchronous streaming completion
    - astream: Asynchronous streaming completion
    - count_tokens: Count tokens in text/messages
    - estimate_cost: Estimate cost from usage
    """

    def __init__(
        self,
        model: str,
        default_config: ModelConfig | None = None,
    ) -> None:
        """
        Initialize the LLM provider.

        Args:
            model: The model identifier to use
            default_config: Default configuration for all requests
        """
        self._model = model
        self._default_config = default_config or ModelConfig()
        self._total_usage = Usage()
        self._total_cost = CostInfo()

    @property
    def model(self) -> str:
        """Get the current model identifier."""
        return self._model

    @property
    def total_usage(self) -> Usage:
        """Get total token usage across all requests."""
        return self._total_usage

    @property
    def total_cost(self) -> CostInfo:
        """Get total cost across all requests."""
        return self._total_cost

    def reset_tracking(self) -> None:
        """Reset usage and cost tracking."""
        self._total_usage = Usage()
        self._total_cost = CostInfo()

    def _merge_config(self, config: ModelConfig | None) -> ModelConfig:
        """Merge provided config with defaults."""
        if config is None:
            return self._default_config

        # Create a new config with merged values
        return ModelConfig(
            temperature=config.temperature
            if config.temperature != 1.0
            else self._default_config.temperature,
            max_tokens=config.max_tokens or self._default_config.max_tokens,
            top_p=config.top_p if config.top_p != 1.0 else self._default_config.top_p,
            top_k=config.top_k or self._default_config.top_k,
            frequency_penalty=config.frequency_penalty
            if config.frequency_penalty != 0.0
            else self._default_config.frequency_penalty,
            presence_penalty=config.presence_penalty
            if config.presence_penalty != 0.0
            else self._default_config.presence_penalty,
            stop_sequences=config.stop_sequences or self._default_config.stop_sequences,
            stream=config.stream,
            tools=config.tools or self._default_config.tools,
            tool_choice=config.tool_choice or self._default_config.tool_choice,
            response_format=config.response_format
            or self._default_config.response_format,
            seed=config.seed or self._default_config.seed,
            extra={**self._default_config.extra, **config.extra},
        )

    def _update_tracking(self, response: LLMResponse) -> None:
        """Update usage and cost tracking from a response."""
        self._total_usage = self._total_usage + response.usage
        self._total_usage.__post_init__()  # Recalculate total_tokens

        if response.cost:
            self._total_cost = self._total_cost + response.cost
            self._total_cost.__post_init__()  # Recalculate total_cost

    # -------------------------------------------------------------------------
    # Abstract Methods - Must be implemented by subclasses
    # -------------------------------------------------------------------------

    @abstractmethod
    def invoke(
        self,
        messages: list[Message],
        config: ModelConfig | None = None,
    ) -> LLMResponse:
        """
        Synchronously invoke the LLM with the given messages.

        Args:
            messages: List of messages in the conversation
            config: Optional configuration overrides

        Returns:
            LLMResponse with the model's response
        """
        ...

    @abstractmethod
    async def ainvoke(
        self,
        messages: list[Message],
        config: ModelConfig | None = None,
    ) -> LLMResponse:
        """
        Asynchronously invoke the LLM with the given messages.

        Args:
            messages: List of messages in the conversation
            config: Optional configuration overrides

        Returns:
            LLMResponse with the model's response
        """
        ...

    @abstractmethod
    def stream(
        self,
        messages: list[Message],
        config: ModelConfig | None = None,
    ) -> Iterator[StreamChunk]:
        """
        Synchronously stream the LLM response.

        Args:
            messages: List of messages in the conversation
            config: Optional configuration overrides

        Yields:
            StreamChunk objects as they arrive
        """
        ...

    @abstractmethod
    async def astream(
        self,
        messages: list[Message],
        config: ModelConfig | None = None,
    ) -> AsyncIterator[StreamChunk]:
        """
        Asynchronously stream the LLM response.

        Args:
            messages: List of messages in the conversation
            config: Optional configuration overrides

        Yields:
            StreamChunk objects as they arrive
        """
        ...

    @abstractmethod
    def count_tokens(
        self,
        text: str | list[Message],
    ) -> int:
        """
        Count the number of tokens in the given text or messages.

        Args:
            text: A string or list of messages to count tokens for

        Returns:
            Number of tokens
        """
        ...

    @abstractmethod
    def estimate_cost(
        self,
        usage: Usage,
    ) -> CostInfo:
        """
        Estimate the cost for the given token usage.

        Args:
            usage: Token usage information

        Returns:
            CostInfo with cost breakdown
        """
        ...

    # -------------------------------------------------------------------------
    # Concrete Methods - Can be overridden if needed
    # -------------------------------------------------------------------------

    def get_model_info(self) -> ModelInfo | None:
        """
        Get information about the current model.

        Returns:
            ModelInfo if available, None otherwise
        """
        return None

    @classmethod
    def get_supported_models(cls) -> list[str]:
        """
        Get list of models supported by this provider.

        Returns:
            List of model identifiers
        """
        return []

    def validate_config(self, config: ModelConfig) -> bool:
        """
        Validate the given configuration.

        Args:
            config: Configuration to validate

        Returns:
            True if valid, False otherwise
        """
        if config.temperature < 0 or config.temperature > 2:
            return False
        if config.top_p < 0 or config.top_p > 1:
            return False
        if config.max_tokens is not None and config.max_tokens < 1:
            return False
        return True

    def validate_messages(self, messages: list[Message]) -> bool:
        """
        Validate the given messages.

        Args:
            messages: Messages to validate

        Returns:
            True if valid, False otherwise
        """
        if not messages:
            return False

        valid_roles = {"system", "user", "assistant", "tool"}
        for msg in messages:
            if msg.role not in valid_roles:
                return False
            if not msg.content and not msg.tool_calls:
                return False

        return True

    def __repr__(self) -> str:
        """String representation of the provider."""
        return f"{self.__class__.__name__}(model='{self._model}')"
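Note: the Usage, ModelPricing, and CostInfo dataclasses above compose as in the following minimal sketch; the pricing figures and token counts are invented for illustration and are not kader defaults.

# Hypothetical pricing and usage figures, purely for illustration.
pricing = ModelPricing(input_cost_per_million=0.50, output_cost_per_million=1.50)
usage = Usage(prompt_tokens=1_200, completion_tokens=300)  # __post_init__ fills total_tokens
cost = pricing.calculate_cost(usage)

print(usage.total_tokens)        # 1500
print(cost.input_cost)           # 0.0006  (1_200 / 1_000_000 * 0.50)
print(cost.format(precision=5))  # $0.00105 USD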
kader/providers/mock.py
ADDED
@@ -0,0 +1,96 @@
"""
Mock LLM Provider for testing and development.
"""

from typing import AsyncIterator, Iterator, List

from .base import (
    BaseLLMProvider,
    CostInfo,
    LLMResponse,
    Message,
    ModelConfig,
    StreamChunk,
    Usage,
)


class MockLLM(BaseLLMProvider):
    """
    A mock LLM provider that echoes inputs or returns predefined responses.
    Useful for testing without incurring costs or latency.
    """

    def invoke(
        self,
        messages: List[Message],
        config: ModelConfig | None = None,
    ) -> LLMResponse:
        """Synchronous mock invocation."""
        last_msg = messages[-1] if messages else Message.user("")
        content = f"Mock response to: {last_msg.content}"

        usage = Usage(prompt_tokens=10, completion_tokens=10)

        return LLMResponse(
            content=content, model=self.model, usage=usage, finish_reason="stop"
        )

    async def ainvoke(
        self,
        messages: List[Message],
        config: ModelConfig | None = None,
    ) -> LLMResponse:
        """Asynchronous mock invocation."""
        import asyncio

        return await asyncio.to_thread(self.invoke, messages, config)

    def stream(
        self,
        messages: List[Message],
        config: ModelConfig | None = None,
    ) -> Iterator[StreamChunk]:
        """Synchronous mock streaming."""
        last_msg = messages[-1] if messages else Message.user("")
        content = f"Mock response to: {last_msg.content}"
        words = content.split()

        accumulated = ""
        for i, word in enumerate(words):
            word_with_space = word + " "
            accumulated += word_with_space
            yield StreamChunk(
                content=accumulated, delta=word_with_space, index=i, finish_reason=None
            )

        yield StreamChunk(
            content=content,
            delta="",
            index=len(words),
            finish_reason="stop",
            usage=Usage(prompt_tokens=10, completion_tokens=10),
        )

    async def astream(
        self,
        messages: List[Message],
        config: ModelConfig | None = None,
    ) -> AsyncIterator[StreamChunk]:
        """Asynchronous mock streaming."""
        for chunk in self.stream(messages, config):
            yield chunk

    def count_tokens(self, text: str | List[Message]) -> int:
        """Mock token counting (1 word = 1 token)."""
        if isinstance(text, str):
            return len(text.split())

        count = 0
        for msg in text:
            count += len(msg.content.split())
        return count

    def estimate_cost(self, usage: Usage) -> CostInfo:
        """Mock cost estimation (free)."""
        return CostInfo(total_cost=0.0)
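Note: as a quick orientation, the mock provider can be exercised roughly as sketched below; the example assumes direct module imports matching the file paths in this diff, and "mock-model" is an arbitrary placeholder name, so treat it as illustrative rather than documented API.

from kader.providers.base import Message
from kader.providers.mock import MockLLM

llm = MockLLM(model="mock-model")  # BaseLLMProvider.__init__ only requires a model id
response = llm.invoke([Message.user("Hello")])
print(response.content)             # Mock response to: Hello
print(response.usage.total_tokens)  # 20 (10 prompt + 10 completion)

# Streaming yields one chunk per word, then a final chunk with finish_reason="stop".
for chunk in llm.stream([Message.user("Hello")]):
    print(chunk.delta, end="", flush=True)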