kader-0.1.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,581 @@
"""
Base class for LLM Providers.

A versatile, provider-agnostic base class for LLM interactions supporting
OpenAI, Google, Anthropic, Mistral, and other providers.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import (
    Any,
    AsyncIterator,
    Iterator,
    Literal,
    TypeAlias,
)

# Type Aliases
Role: TypeAlias = Literal["system", "user", "assistant", "tool"]
FinishReason: TypeAlias = Literal[
    "stop", "length", "tool_calls", "content_filter", "error", None
]


class MessageRole(str, Enum):
    """Enumeration of message roles for type safety."""

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"


@dataclass
class Message:
    """Represents a chat message in a conversation."""

    role: Role
    content: str
    name: str | None = None
    tool_call_id: str | None = None
    tool_calls: list[dict[str, Any]] | None = None

    def to_dict(self) -> dict[str, Any]:
        """Convert message to dictionary format for API calls."""
        data: dict[str, Any] = {
            "role": self.role,
            "content": self.content,
        }
        if self.name:
            data["name"] = self.name
        if self.tool_call_id:
            data["tool_call_id"] = self.tool_call_id
        if self.tool_calls:
            data["tool_calls"] = self.tool_calls
        return data

    @classmethod
    def system(cls, content: str) -> "Message":
        """Create a system message."""
        return cls(role="system", content=content)

    @classmethod
    def user(cls, content: str) -> "Message":
        """Create a user message."""
        return cls(role="user", content=content)

    @classmethod
    def assistant(cls, content: str) -> "Message":
        """Create an assistant message."""
        return cls(role="assistant", content=content)

    @classmethod
    def tool(cls, tool_call_id: str, content: str) -> "Message":
        """Create a tool message."""
        return cls(role="tool", tool_call_id=tool_call_id, content=content)


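# Usage sketch (illustrative, not exercised by the package itself): the factory
# helpers above build a conversation, and to_dict() produces the wire format used
# by chat-completion style APIs.
#
#     messages = [
#         Message.system("You are a helpful assistant."),
#         Message.user("What is the capital of France?"),
#     ]
#     payload = [m.to_dict() for m in messages]
#     # [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}]

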
@dataclass
class Usage:
    """Tracks token usage for an LLM request."""

    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0

    # Additional usage details (provider-specific)
    cached_tokens: int = 0
    reasoning_tokens: int = 0

    def __post_init__(self) -> None:
        """Calculate total tokens if not provided."""
        if self.total_tokens == 0:
            self.total_tokens = self.prompt_tokens + self.completion_tokens

    def __add__(self, other: "Usage") -> "Usage":
        """Add two Usage instances together."""
        return Usage(
            prompt_tokens=self.prompt_tokens + other.prompt_tokens,
            completion_tokens=self.completion_tokens + other.completion_tokens,
            total_tokens=self.total_tokens + other.total_tokens,
            cached_tokens=self.cached_tokens + other.cached_tokens,
            reasoning_tokens=self.reasoning_tokens + other.reasoning_tokens,
        )


@dataclass
class CostInfo:
    """Cost breakdown for an LLM request."""

    input_cost: float = 0.0
    output_cost: float = 0.0
    total_cost: float = 0.0
    currency: str = "USD"

    # Additional cost details
    cached_input_cost: float = 0.0

    def __post_init__(self) -> None:
        """Calculate total cost if not provided."""
        if self.total_cost == 0.0:
            self.total_cost = self.input_cost + self.output_cost

    def __add__(self, other: "CostInfo") -> "CostInfo":
        """Add two CostInfo instances together."""
        if self.currency != other.currency:
            raise ValueError(
                f"Cannot add costs with different currencies: {self.currency} vs {other.currency}"
            )
        return CostInfo(
            input_cost=self.input_cost + other.input_cost,
            output_cost=self.output_cost + other.output_cost,
            total_cost=self.total_cost + other.total_cost,
            currency=self.currency,
            cached_input_cost=self.cached_input_cost + other.cached_input_cost,
        )

    def format(self, precision: int = 6) -> str:
        """Format cost as a readable string."""
        return f"${self.total_cost:.{precision}f} {self.currency}"


@dataclass
class ModelConfig:
    """Configuration for model inference parameters."""

    # Core parameters
    temperature: float = 1.0
    max_tokens: int | None = None
    top_p: float = 1.0

    # Sampling parameters
    top_k: int | None = None
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0

    # Stop sequences
    stop_sequences: list[str] | None = None

    # Streaming
    stream: bool = False

    # Tool/Function calling
    tools: list[dict[str, Any]] | None = None
    tool_choice: str | dict[str, Any] | None = None

    # Response format
    response_format: dict[str, Any] | None = None

    # Seed for reproducibility
    seed: int | None = None

    # Additional provider-specific parameters
    extra: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Convert config to dictionary, excluding None values."""
        data: dict[str, Any] = {}

        if self.temperature != 1.0:
            data["temperature"] = self.temperature
        if self.max_tokens is not None:
            data["max_tokens"] = self.max_tokens
        if self.top_p != 1.0:
            data["top_p"] = self.top_p
        if self.top_k is not None:
            data["top_k"] = self.top_k
        if self.frequency_penalty != 0.0:
            data["frequency_penalty"] = self.frequency_penalty
        if self.presence_penalty != 0.0:
            data["presence_penalty"] = self.presence_penalty
        if self.stop_sequences:
            data["stop"] = self.stop_sequences
        if self.tools:
            data["tools"] = self.tools
        if self.tool_choice is not None:
            data["tool_choice"] = self.tool_choice
        if self.response_format is not None:
            data["response_format"] = self.response_format
        if self.seed is not None:
            data["seed"] = self.seed

        # Merge extra parameters
        data.update(self.extra)

        return data


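# Usage sketch (illustrative): to_dict() emits only the parameters that differ from
# their defaults and merges in `extra`, so a config such as
#
#     ModelConfig(temperature=0.2, max_tokens=256, extra={"logprobs": True}).to_dict()
#
# would yield {"temperature": 0.2, "max_tokens": 256, "logprobs": True}, leaving
# everything else to the provider's own defaults. ("logprobs" is just an example
# of a provider-specific extra key.)

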
@dataclass
class LLMResponse:
    """Complete response from an LLM provider."""

    content: str
    model: str
    usage: Usage
    finish_reason: FinishReason = None

    # Cost information (optional, calculated if pricing is available)
    cost: CostInfo | None = None

    # Tool calls (if any)
    tool_calls: list[dict[str, Any]] | None = None

    # Raw response from provider (for debugging/extension)
    raw_response: Any = None

    # Additional metadata
    id: str | None = None
    created: int | None = None

    @property
    def has_tool_calls(self) -> bool:
        """Check if response contains tool calls."""
        return self.tool_calls is not None and len(self.tool_calls) > 0

    def to_message(self) -> Message:
        """Convert response to an assistant message."""
        return Message(
            role="assistant",
            content=self.content,
            tool_calls=self.tool_calls,
        )


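# Usage sketch (illustrative): in a tool-calling loop, to_message() lets the
# assistant turn be appended back to the conversation before the tool results.
# This assumes OpenAI-style tool-call dicts with an "id" key; `provider` and
# `run_tool` are hypothetical stand-ins.
#
#     response = provider.invoke(messages)
#     if response.has_tool_calls:
#         messages.append(response.to_message())
#         for call in response.tool_calls:
#             messages.append(Message.tool(call["id"], run_tool(call)))

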
@dataclass
class StreamChunk:
    """A chunk from a streaming LLM response."""

    content: str = ""
    delta: str = ""
    finish_reason: FinishReason = None

    # Partial usage (available at end of stream for some providers)
    usage: Usage | None = None

    # Tool call deltas
    tool_calls: list[dict[str, Any]] | None = None

    # Index of this chunk in the stream
    index: int = 0

    @property
    def is_final(self) -> bool:
        """Check if this is the final chunk."""
        return self.finish_reason is not None


@dataclass
class ModelPricing:
    """Pricing information for a model."""

    input_cost_per_million: float  # Cost per million input tokens
    output_cost_per_million: float  # Cost per million output tokens
    cached_input_cost_per_million: float | None = (
        None  # Cached input cost (if supported)
    )

    def calculate_cost(self, usage: Usage) -> CostInfo:
        """Calculate cost from usage."""
        input_cost = (usage.prompt_tokens / 1_000_000) * self.input_cost_per_million
        output_cost = (
            usage.completion_tokens / 1_000_000
        ) * self.output_cost_per_million

        cached_cost = 0.0
        if self.cached_input_cost_per_million and usage.cached_tokens > 0:
            cached_cost = (
                usage.cached_tokens / 1_000_000
            ) * self.cached_input_cost_per_million

        return CostInfo(
            input_cost=input_cost,
            output_cost=output_cost,
            cached_input_cost=cached_cost,
        )


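# Worked example (illustrative pricing values): with input at $3.00 and output at
# $15.00 per million tokens, a request with 2,000 prompt tokens and 500 completion
# tokens costs (2_000 / 1_000_000) * 3.00 + (500 / 1_000_000) * 15.00
# = 0.006 + 0.0075 = 0.0135 USD.
#
#     pricing = ModelPricing(input_cost_per_million=3.00, output_cost_per_million=15.00)
#     cost = pricing.calculate_cost(Usage(prompt_tokens=2_000, completion_tokens=500))
#     # cost.format() -> "$0.013500 USD"

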
@dataclass
class ModelInfo:
    """Information about an LLM model."""

    name: str
    provider: str
    context_window: int
    max_output_tokens: int | None = None
    pricing: ModelPricing | None = None
    supports_vision: bool = False
    supports_tools: bool = False
    supports_json_mode: bool = False
    supports_streaming: bool = True

    # Additional capabilities
    capabilities: dict[str, Any] = field(default_factory=dict)


class BaseLLMProvider(ABC):
    """
    Abstract base class for LLM providers.

    Provides a unified interface for interacting with various LLM providers
    including OpenAI, Google, Anthropic, Mistral, and others.

    Subclasses must implement:
    - invoke: Synchronous single completion
    - ainvoke: Asynchronous single completion
    - stream: Synchronous streaming completion
    - astream: Asynchronous streaming completion
    - count_tokens: Count tokens in text/messages
    - estimate_cost: Estimate cost from usage
    """

    def __init__(
        self,
        model: str,
        default_config: ModelConfig | None = None,
    ) -> None:
        """
        Initialize the LLM provider.

        Args:
            model: The model identifier to use
            default_config: Default configuration for all requests
        """
        self._model = model
        self._default_config = default_config or ModelConfig()
        self._total_usage = Usage()
        self._total_cost = CostInfo()

    @property
    def model(self) -> str:
        """Get the current model identifier."""
        return self._model

    @property
    def total_usage(self) -> Usage:
        """Get total token usage across all requests."""
        return self._total_usage

    @property
    def total_cost(self) -> CostInfo:
        """Get total cost across all requests."""
        return self._total_cost

    def reset_tracking(self) -> None:
        """Reset usage and cost tracking."""
        self._total_usage = Usage()
        self._total_cost = CostInfo()

    def _merge_config(self, config: ModelConfig | None) -> ModelConfig:
        """Merge provided config with defaults."""
        if config is None:
            return self._default_config

        # Create a new config with merged values
        return ModelConfig(
            temperature=config.temperature
            if config.temperature != 1.0
            else self._default_config.temperature,
            max_tokens=config.max_tokens or self._default_config.max_tokens,
            top_p=config.top_p if config.top_p != 1.0 else self._default_config.top_p,
            top_k=config.top_k or self._default_config.top_k,
            frequency_penalty=config.frequency_penalty
            if config.frequency_penalty != 0.0
            else self._default_config.frequency_penalty,
            presence_penalty=config.presence_penalty
            if config.presence_penalty != 0.0
            else self._default_config.presence_penalty,
            stop_sequences=config.stop_sequences or self._default_config.stop_sequences,
            stream=config.stream,
            tools=config.tools or self._default_config.tools,
            tool_choice=config.tool_choice or self._default_config.tool_choice,
            response_format=config.response_format
            or self._default_config.response_format,
            seed=config.seed or self._default_config.seed,
            extra={**self._default_config.extra, **config.extra},
        )

    def _update_tracking(self, response: LLMResponse) -> None:
        """Update usage and cost tracking from a response."""
        self._total_usage = self._total_usage + response.usage
        self._total_usage.__post_init__()  # Recalculate total_tokens

        if response.cost:
            self._total_cost = self._total_cost + response.cost
            self._total_cost.__post_init__()  # Recalculate total_cost

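    # Note (illustrative observation about the helper above): for subclasses that
    # route requests through _merge_config, a per-request ModelConfig only needs
    # the fields being changed; any field left at its dataclass default (e.g.
    # temperature=1.0 or presence_penalty=0.0) falls back to the provider-level
    # default_config. A consequence is that explicitly requesting exactly the
    # default value (say, temperature=1.0) cannot override a different default.
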
    # -------------------------------------------------------------------------
    # Abstract Methods - Must be implemented by subclasses
    # -------------------------------------------------------------------------

    @abstractmethod
    def invoke(
        self,
        messages: list[Message],
        config: ModelConfig | None = None,
    ) -> LLMResponse:
        """
        Synchronously invoke the LLM with the given messages.

        Args:
            messages: List of messages in the conversation
            config: Optional configuration overrides

        Returns:
            LLMResponse with the model's response
        """
        ...

    @abstractmethod
    async def ainvoke(
        self,
        messages: list[Message],
        config: ModelConfig | None = None,
    ) -> LLMResponse:
        """
        Asynchronously invoke the LLM with the given messages.

        Args:
            messages: List of messages in the conversation
            config: Optional configuration overrides

        Returns:
            LLMResponse with the model's response
        """
        ...

    @abstractmethod
    def stream(
        self,
        messages: list[Message],
        config: ModelConfig | None = None,
    ) -> Iterator[StreamChunk]:
        """
        Synchronously stream the LLM response.

        Args:
            messages: List of messages in the conversation
            config: Optional configuration overrides

        Yields:
            StreamChunk objects as they arrive
        """
        ...

    @abstractmethod
    async def astream(
        self,
        messages: list[Message],
        config: ModelConfig | None = None,
    ) -> AsyncIterator[StreamChunk]:
        """
        Asynchronously stream the LLM response.

        Args:
            messages: List of messages in the conversation
            config: Optional configuration overrides

        Yields:
            StreamChunk objects as they arrive
        """
        ...

    @abstractmethod
    def count_tokens(
        self,
        text: str | list[Message],
    ) -> int:
        """
        Count the number of tokens in the given text or messages.

        Args:
            text: A string or list of messages to count tokens for

        Returns:
            Number of tokens
        """
        ...

    @abstractmethod
    def estimate_cost(
        self,
        usage: Usage,
    ) -> CostInfo:
        """
        Estimate the cost for the given token usage.

        Args:
            usage: Token usage information

        Returns:
            CostInfo with cost breakdown
        """
        ...

    # -------------------------------------------------------------------------
    # Concrete Methods - Can be overridden if needed
    # -------------------------------------------------------------------------

    def get_model_info(self) -> ModelInfo | None:
        """
        Get information about the current model.

        Returns:
            ModelInfo if available, None otherwise
        """
        return None

    @classmethod
    def get_supported_models(cls) -> list[str]:
        """
        Get list of models supported by this provider.

        Returns:
            List of model identifiers
        """
        return []

    def validate_config(self, config: ModelConfig) -> bool:
        """
        Validate the given configuration.

        Args:
            config: Configuration to validate

        Returns:
            True if valid, False otherwise
        """
        if config.temperature < 0 or config.temperature > 2:
            return False
        if config.top_p < 0 or config.top_p > 1:
            return False
        if config.max_tokens is not None and config.max_tokens < 1:
            return False
        return True

    def validate_messages(self, messages: list[Message]) -> bool:
        """
        Validate the given messages.

        Args:
            messages: Messages to validate

        Returns:
            True if valid, False otherwise
        """
        if not messages:
            return False

        valid_roles = {"system", "user", "assistant", "tool"}
        for msg in messages:
            if msg.role not in valid_roles:
                return False
            if not msg.content and not msg.tool_calls:
                return False

        return True

    def __repr__(self) -> str:
        """String representation of the provider."""
        return f"{self.__class__.__name__}(model='{self._model}')"
@@ -0,0 +1,96 @@
"""
Mock LLM Provider for testing and development.
"""

from typing import AsyncIterator, Iterator, List

from .base import (
    BaseLLMProvider,
    CostInfo,
    LLMResponse,
    Message,
    ModelConfig,
    StreamChunk,
    Usage,
)


class MockLLM(BaseLLMProvider):
    """
    A mock LLM provider that echoes inputs or returns predefined responses.
    Useful for testing without incurring costs or latency.
    """

    def invoke(
        self,
        messages: List[Message],
        config: ModelConfig | None = None,
    ) -> LLMResponse:
        """Synchronous mock invocation."""
        last_msg = messages[-1] if messages else Message.user("")
        content = f"Mock response to: {last_msg.content}"

        usage = Usage(prompt_tokens=10, completion_tokens=10)

        return LLMResponse(
            content=content, model=self.model, usage=usage, finish_reason="stop"
        )

    async def ainvoke(
        self,
        messages: List[Message],
        config: ModelConfig | None = None,
    ) -> LLMResponse:
        """Asynchronous mock invocation."""
        import asyncio

        return await asyncio.to_thread(self.invoke, messages, config)

    def stream(
        self,
        messages: List[Message],
        config: ModelConfig | None = None,
    ) -> Iterator[StreamChunk]:
        """Synchronous mock streaming."""
        last_msg = messages[-1] if messages else Message.user("")
        content = f"Mock response to: {last_msg.content}"
        words = content.split()

        accumulated = ""
        for i, word in enumerate(words):
            word_with_space = word + " "
            accumulated += word_with_space
            yield StreamChunk(
                content=accumulated, delta=word_with_space, index=i, finish_reason=None
            )

        yield StreamChunk(
            content=content,
            delta="",
            index=len(words),
            finish_reason="stop",
            usage=Usage(prompt_tokens=10, completion_tokens=10),
        )

    async def astream(
        self,
        messages: List[Message],
        config: ModelConfig | None = None,
    ) -> AsyncIterator[StreamChunk]:
        """Asynchronous mock streaming."""
        for chunk in self.stream(messages, config):
            yield chunk

    def count_tokens(self, text: str | List[Message]) -> int:
        """Mock token counting (1 word = 1 token)."""
        if isinstance(text, str):
            return len(text.split())

        count = 0
        for msg in text:
            count += len(msg.content.split())
        return count

    def estimate_cost(self, usage: Usage) -> CostInfo:
        """Mock cost estimation (free)."""
        return CostInfo(total_cost=0.0)
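

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative addition, not part of the released
    # module): exercises the mock provider through the BaseLLMProvider interface.
    llm = MockLLM(model="mock-model")

    reply = llm.invoke([Message.user("Hello there")])
    print(reply.content)              # "Mock response to: Hello there"
    print(reply.usage.total_tokens)   # 20 (10 prompt + 10 completion)

    # Streaming yields word-by-word chunks; the final chunk carries usage.
    for chunk in llm.stream([Message.user("Hello there")]):
        if chunk.is_final:
            print("final usage:", chunk.usage)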