multillm-core 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- multillm_core-0.1.0/API_SPEC.md +59 -0
- multillm_core-0.1.0/PKG-INFO +69 -0
- multillm_core-0.1.0/README.md +45 -0
- multillm_core-0.1.0/pyproject.toml +61 -0
- multillm_core-0.1.0/src/multillm/__init__.py +13 -0
- multillm_core-0.1.0/src/multillm/exceptions.py +46 -0
- multillm_core-0.1.0/src/multillm/factory.py +39 -0
- multillm_core-0.1.0/src/multillm/providers/__init__.py +3 -0
- multillm_core-0.1.0/src/multillm/providers/anthropic.py +59 -0
- multillm_core-0.1.0/src/multillm/providers/base.py +42 -0
- multillm_core-0.1.0/src/multillm/providers/gemini.py +46 -0
- multillm_core-0.1.0/src/multillm/providers/openai.py +57 -0
- multillm_core-0.1.0/src/multillm/schemas.py +23 -0
- multillm_core-0.1.0/src/multillm/version.py +1 -0
- multillm_core-0.1.0/tests/test_core.py +98 -0
- multillm_core-0.1.0/uv.lock +1504 -0
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
# Internal API Specification — multillm
|
|
2
|
+
|
|
3
|
+
## 1. Public API Contract
|
|
4
|
+
|
|
5
|
+
### Package Name
|
|
6
|
+
- **Internal/Import Name**: `multillm`
|
|
7
|
+
- **PyPI Name**: `multillm-core` (Suggested, as `multillm` is taken)
|
|
8
|
+
|
|
9
|
+
### Entrypoint
|
|
10
|
+
```python
|
|
11
|
+
from multillm import create_client
|
|
12
|
+
|
|
13
|
+
client = create_client("openai", api_key="sk-...")
|
|
14
|
+
response = await client.generate(model="gpt-4o", prompt="Hello!")
|
|
15
|
+
```
|
|
16
|
+
|
|
17
|
+
### Core Abstractions
|
|
18
|
+
|
|
19
|
+
#### BaseLLMClient
|
|
20
|
+
Abstract base class for all providers.
|
|
21
|
+
- `async generate(model: str, prompt: str, **kwargs) -> LLMResponse`
|
|
22
|
+
- `async stream(model: str, prompt: str, **kwargs) -> AsyncGenerator[str, None]` (Future proofing)
|
|
23
|
+
|
|
24
|
+
#### LLMResponse (Pydantic)
|
|
25
|
+
Unified response object.
|
|
26
|
+
- `text`: str
|
|
27
|
+
- `provider`: str
|
|
28
|
+
- `model`: str
|
|
29
|
+
- `usage`: `TokenUsage` (input_tokens, output_tokens)
|
|
30
|
+
- `raw`: Full response from provider SDK
|
|
31
|
+
|
|
32
|
+
#### Error Hierarchy
|
|
33
|
+
- `MultiLLMError` (Base)
|
|
34
|
+
- `ProviderNotFoundError`: Factory can't find provider
|
|
35
|
+
- `SDKNotInstalledError`: Provider dependencies missing
|
|
36
|
+
- `APIError`: Base for provider-side errors
|
|
37
|
+
- `AuthenticationError`: 401
|
|
38
|
+
- `RateLimitError`: 429
|
|
39
|
+
- `InvalidRequestError`: 400
|
|
40
|
+
|
|
41
|
+
## 2. Supported Providers (v0.1 Scope)
|
|
42
|
+
- **OpenAI**: Using `openai` SDK
|
|
43
|
+
- **Anthropic**: Using `anthropic` SDK
|
|
44
|
+
- **Google Gemini**: Using `google-generativeai` SDK
|
|
45
|
+
|
|
46
|
+
## 3. Project Structure
|
|
47
|
+
```text
|
|
48
|
+
src/multillm/
|
|
49
|
+
├── providers/
|
|
50
|
+
│ ├── base.py
|
|
51
|
+
│ ├── openai.py
|
|
52
|
+
│ ├── anthropic.py
|
|
53
|
+
│ └── gemini.py
|
|
54
|
+
├── factory.py
|
|
55
|
+
├── schemas.py
|
|
56
|
+
├── exceptions.py
|
|
57
|
+
├── version.py
|
|
58
|
+
└── __init__.py
|
|
59
|
+
```
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: multillm-core
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: A unified interface for multiple LLM providers
|
|
5
|
+
Author: Divas Rajan
|
|
6
|
+
License: MIT
|
|
7
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
8
|
+
Classifier: Operating System :: OS Independent
|
|
9
|
+
Classifier: Programming Language :: Python :: 3
|
|
10
|
+
Requires-Python: >=3.9
|
|
11
|
+
Requires-Dist: google-genai>=1.47.0
|
|
12
|
+
Requires-Dist: pydantic>=2.0.0
|
|
13
|
+
Provides-Extra: all
|
|
14
|
+
Requires-Dist: anthropic>=0.10.0; extra == 'all'
|
|
15
|
+
Requires-Dist: google-genai>=0.1.0; extra == 'all'
|
|
16
|
+
Requires-Dist: openai>=1.0.0; extra == 'all'
|
|
17
|
+
Provides-Extra: anthropic
|
|
18
|
+
Requires-Dist: anthropic>=0.10.0; extra == 'anthropic'
|
|
19
|
+
Provides-Extra: gemini
|
|
20
|
+
Requires-Dist: google-genai>=0.1.0; extra == 'gemini'
|
|
21
|
+
Provides-Extra: openai
|
|
22
|
+
Requires-Dist: openai>=1.0.0; extra == 'openai'
|
|
23
|
+
Description-Content-Type: text/markdown
|
|
24
|
+
|
|
25
|
+
# multillm-core
|
|
26
|
+
|
|
27
|
+
A unified interface for multiple LLM providers (OpenAI, Anthropic, Gemini).
|
|
28
|
+
|
|
29
|
+
## Installation
|
|
30
|
+
|
|
31
|
+
```bash
|
|
32
|
+
pip install multillm-core
|
|
33
|
+
```
|
|
34
|
+
|
|
35
|
+
### Install with specific providers:
|
|
36
|
+
|
|
37
|
+
```bash
|
|
38
|
+
pip install "multillm-core[openai]"
|
|
39
|
+
pip install "multillm-core[anthropic]"
|
|
40
|
+
pip install "multillm-core[gemini]"
|
|
41
|
+
pip install "multillm-core[all]"
|
|
42
|
+
```
|
|
43
|
+
|
|
44
|
+
## Example Usage
|
|
45
|
+
|
|
46
|
+
```python
|
|
47
|
+
import asyncio
|
|
48
|
+
from multillm import create_client
|
|
49
|
+
|
|
50
|
+
async def main():
|
|
51
|
+
client = create_client("openai", api_key="your-key")
|
|
52
|
+
response = await client.generate(
|
|
53
|
+
model="gpt-4o",
|
|
54
|
+
prompt="Explain quantum entanglement in one sentence."
|
|
55
|
+
)
|
|
56
|
+
print(f"[{response.provider}] {response.text}")
|
|
57
|
+
print(f"Usage: {response.usage}")
|
|
58
|
+
|
|
59
|
+
if __name__ == "__main__":
|
|
60
|
+
asyncio.run(main())
|
|
61
|
+
```
|
|
62
|
+
|
|
63
|
+
## Supported Providers
|
|
64
|
+
- OpenAI
|
|
65
|
+
- Anthropic
|
|
66
|
+
- Google Gemini
|
|
67
|
+
|
|
68
|
+
## License
|
|
69
|
+
MIT
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
# multillm-core
|
|
2
|
+
|
|
3
|
+
A unified interface for multiple LLM providers (OpenAI, Anthropic, Gemini).
|
|
4
|
+
|
|
5
|
+
## Installation
|
|
6
|
+
|
|
7
|
+
```bash
|
|
8
|
+
pip install multillm-core
|
|
9
|
+
```
|
|
10
|
+
|
|
11
|
+
### Install with specific providers:
|
|
12
|
+
|
|
13
|
+
```bash
|
|
14
|
+
pip install "multillm-core[openai]"
|
|
15
|
+
pip install "multillm-core[anthropic]"
|
|
16
|
+
pip install "multillm-core[gemini]"
|
|
17
|
+
pip install "multillm-core[all]"
|
|
18
|
+
```
|
|
19
|
+
|
|
20
|
+
## Example Usage
|
|
21
|
+
|
|
22
|
+
```python
|
|
23
|
+
import asyncio
|
|
24
|
+
from multillm import create_client
|
|
25
|
+
|
|
26
|
+
async def main():
|
|
27
|
+
client = create_client("openai", api_key="your-key")
|
|
28
|
+
response = await client.generate(
|
|
29
|
+
model="gpt-4o",
|
|
30
|
+
prompt="Explain quantum entanglement in one sentence."
|
|
31
|
+
)
|
|
32
|
+
print(f"[{response.provider}] {response.text}")
|
|
33
|
+
print(f"Usage: {response.usage}")
|
|
34
|
+
|
|
35
|
+
if __name__ == "__main__":
|
|
36
|
+
asyncio.run(main())
|
|
37
|
+
```
|
|
38
|
+
|
|
39
|
+
## Supported Providers
|
|
40
|
+
- OpenAI
|
|
41
|
+
- Anthropic
|
|
42
|
+
- Google Gemini
|
|
43
|
+
|
|
44
|
+
## License
|
|
45
|
+
MIT
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "multillm-core"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
description = "A unified interface for multiple LLM providers"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.9"
|
|
11
|
+
license = { text = "MIT" }
|
|
12
|
+
authors = [
|
|
13
|
+
{ name = "Divas Rajan" }
|
|
14
|
+
]
|
|
15
|
+
classifiers = [
|
|
16
|
+
"Programming Language :: Python :: 3",
|
|
17
|
+
"License :: OSI Approved :: MIT License",
|
|
18
|
+
"Operating System :: OS Independent",
|
|
19
|
+
]
|
|
20
|
+
dependencies = [
|
|
21
|
+
"google-genai>=1.47.0",
|
|
22
|
+
"pydantic>=2.0.0",
|
|
23
|
+
]
|
|
24
|
+
|
|
25
|
+
[project.optional-dependencies]
|
|
26
|
+
openai = ["openai>=1.0.0"]
|
|
27
|
+
anthropic = ["anthropic>=0.10.0"]
|
|
28
|
+
gemini = ["google-genai>=1.47.0"]
|
|
29
|
+
all = [
|
|
30
|
+
"openai>=1.0.0",
|
|
31
|
+
"anthropic>=0.10.0",
|
|
32
|
+
"google-genai>=1.47.0",
|
|
33
|
+
]
|
|
34
|
+
|
|
35
|
+
[tool.hatch.build.targets.wheel]
|
|
36
|
+
packages = ["src/multillm"]
|
|
37
|
+
|
|
38
|
+
[tool.pytest.ini_options]
|
|
39
|
+
asyncio_mode = "auto"
|
|
40
|
+
testpaths = ["tests"]
|
|
41
|
+
|
|
42
|
+
[tool.ruff]
|
|
43
|
+
line-length = 88
|
|
44
|
+
target-version = "py39"
|
|
45
|
+
|
|
46
|
+
[tool.ruff.lint]
|
|
47
|
+
select = ["E", "F", "I", "W", "UP"]
|
|
48
|
+
ignore = []
|
|
49
|
+
|
|
50
|
+
[tool.mypy]
|
|
51
|
+
python_version = "3.9"
|
|
52
|
+
strict = true
|
|
53
|
+
ignore_missing_imports = true
|
|
54
|
+
|
|
55
|
+
[dependency-groups]
|
|
56
|
+
dev = [
|
|
57
|
+
"mypy>=1.19.1",
|
|
58
|
+
"pytest>=8.4.2",
|
|
59
|
+
"pytest-asyncio>=1.2.0",
|
|
60
|
+
"ruff>=0.15.1",
|
|
61
|
+
]
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
"""Public API surface for the ``multillm`` package.

Re-exports the factory entrypoint, the provider base class, the unified
response model, the root exception type, and the package version.
"""

from .exceptions import MultiLLMError
from .factory import create_client
from .providers.base import BaseLLMClient
from .schemas import LLMResponse
from .version import __version__

# Explicit public API: only these names are exported by ``from multillm import *``.
__all__ = [
    "create_client",
    "BaseLLMClient",
    "LLMResponse",
    "MultiLLMError",
    "__version__",
]
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
class MultiLLMError(Exception):
    """Root of the multillm exception hierarchy; catch this for any library error."""


class ProviderNotFoundError(MultiLLMError):
    """The factory was asked for a provider name it does not know about."""


class SDKNotInstalledError(MultiLLMError):
    """The optional SDK backing a provider is missing from the environment."""

    def __init__(self, provider: str, package_name: str):
        self.provider = provider
        self.package_name = package_name
        message = (
            f"SDK for '{provider}' is not installed. Please install it using: "
            f"pip install 'multillm-core[{package_name}]'"
        )
        super().__init__(message)


class APIError(MultiLLMError):
    """A provider's API reported a failure."""


class AuthenticationError(APIError):
    """The provider rejected the supplied credentials (HTTP 401)."""


class RateLimitError(APIError):
    """The provider throttled the request (HTTP 429)."""


class InvalidRequestError(APIError):
    """The provider rejected the request as malformed (HTTP 400)."""
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
from typing import Any

from .exceptions import ProviderNotFoundError
from .providers.anthropic import AnthropicClient
from .providers.base import BaseLLMClient
from .providers.gemini import GeminiClient
from .providers.openai import OpenAIClient

# Registry of lowercase provider names to client implementations.
# "google" is accepted as an alias for the Gemini client.
PROVIDER_MAP: dict[str, type[BaseLLMClient]] = {
    "openai": OpenAIClient,
    "anthropic": AnthropicClient,
    "gemini": GeminiClient,
    "google": GeminiClient,
}


def create_client(provider: str, **kwargs: Any) -> BaseLLMClient:
    """
    Factory function to create an LLM client.

    Args:
        provider: Name of the provider ('openai', 'anthropic', 'gemini').
        **kwargs: Arguments passed to the provider client constructor (e.g., api_key).

    Returns:
        An instance of BaseLLMClient.

    Raises:
        ProviderNotFoundError: If the provider is not supported.
    """
    try:
        client_cls = PROVIDER_MAP[provider.lower()]
    except KeyError:
        raise ProviderNotFoundError(
            f"Provider '{provider}' is not supported. "
            f"Available providers: {', '.join(PROVIDER_MAP.keys())}"
        ) from None

    return client_cls(**kwargs)
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
from typing import Any, NoReturn, Optional

from ..exceptions import (
    APIError,
    AuthenticationError,
    InvalidRequestError,
    RateLimitError,
    SDKNotInstalledError,
)
from ..schemas import LLMResponse, TokenUsage
from .base import BaseLLMClient


class AnthropicClient(BaseLLMClient):
    """Anthropic provider client.

    Wraps the official ``anthropic`` SDK's async client and normalizes
    responses and SDK errors into the multillm types.
    """

    def __init__(self, api_key: Optional[str] = None, **kwargs: Any):
        """Create the underlying ``anthropic.AsyncAnthropic`` client.

        Args:
            api_key: Anthropic API key; when ``None`` the SDK falls back to
                its own environment-based resolution.
            **kwargs: Passed through to ``anthropic.AsyncAnthropic``.

        Raises:
            SDKNotInstalledError: If the ``anthropic`` package is missing.
        """
        try:
            import anthropic

            self._client = anthropic.AsyncAnthropic(api_key=api_key, **kwargs)
        except ImportError as exc:
            # Chain the ImportError so the real cause stays in the traceback.
            raise SDKNotInstalledError("anthropic", "anthropic") from exc

    async def generate(self, model: str, prompt: str, **kwargs: Any) -> LLMResponse:
        """Send ``prompt`` as a single user message and return a unified response.

        ``max_tokens`` defaults to 4096 (the Messages API requires it); all
        other keyword arguments are forwarded to ``messages.create``.
        """
        # Pop max_tokens once instead of reading it and then re-filtering kwargs.
        extra = dict(kwargs)
        max_tokens = extra.pop("max_tokens", 4096)
        try:
            response = await self._client.messages.create(
                model=model,
                max_tokens=max_tokens,
                messages=[{"role": "user", "content": prompt}],
                **extra,
            )

            # Assumes the first content block is a text block — TODO confirm
            # for tool-use responses.
            text = response.content[0].text

            usage = TokenUsage(
                input_tokens=response.usage.input_tokens,
                output_tokens=response.usage.output_tokens,
                total_tokens=response.usage.input_tokens + response.usage.output_tokens,
            )

            return LLMResponse(
                text=text, provider="anthropic", model=model, usage=usage, raw=response
            )
        except Exception as e:
            return self._handle_error(e)

    def _handle_error(self, e: Exception) -> NoReturn:
        """Translate ``anthropic`` SDK exceptions into the multillm hierarchy."""
        import anthropic

        # Order matters: specific subclasses must be checked before the
        # catch-all anthropic.APIError. ``from e`` preserves the original
        # exception for debugging.
        if isinstance(e, anthropic.AuthenticationError):
            raise AuthenticationError(str(e)) from e
        if isinstance(e, anthropic.RateLimitError):
            raise RateLimitError(str(e)) from e
        if isinstance(e, anthropic.BadRequestError):
            raise InvalidRequestError(str(e)) from e
        if isinstance(e, anthropic.APIError):
            raise APIError(str(e)) from e
        raise APIError(f"Unexpected Anthropic error: {str(e)}") from e
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from typing import Any

from ..schemas import LLMResponse


class BaseLLMClient(ABC):
    """Base class for all LLM provider clients."""

    @abstractmethod
    async def generate(self, model: str, prompt: str, **kwargs: Any) -> LLMResponse:
        """
        Generate a response from the LLM.

        Args:
            model: The model name to use.
            prompt: The input text prompt.
            **kwargs: Provider-specific arguments.

        Returns:
            LLMResponse object.
        """

    async def stream(
        self, model: str, prompt: str, **kwargs: Any
    ) -> AsyncGenerator[str, None]:
        """
        Stream back the response from the LLM.

        Args:
            model: The model name to use.
            prompt: The input text prompt.
            **kwargs: Provider-specific arguments.

        Yields:
            Response chunks text.

        Raises:
            NotImplementedError: On first iteration, for providers that have
                not overridden this method.
        """
        raise NotImplementedError(
            f"Streaming is not implemented for {self.__class__.__name__}"
        )
        # The unreachable ``yield`` makes this a real async *generator*.
        # Without it, calling ``stream(...)`` returns a plain coroutine, so
        # ``async for chunk in client.stream(...)`` fails with a TypeError
        # instead of raising NotImplementedError on first iteration.
        yield ""  # pragma: no cover
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
from typing import Any, Optional

from ..exceptions import APIError, SDKNotInstalledError
from ..schemas import LLMResponse, TokenUsage
from .base import BaseLLMClient


class GeminiClient(BaseLLMClient):
    """Google Gemini provider client (built on the ``google-genai`` SDK)."""

    def __init__(self, api_key: Optional[str] = None, **kwargs: Any):
        """Create the underlying ``google.genai`` client.

        Args:
            api_key: Gemini API key; when ``None`` the SDK falls back to its
                own environment-based resolution.
            **kwargs: Passed through to ``genai.Client``.

        Raises:
            SDKNotInstalledError: If the ``google-genai`` package is missing.
        """
        try:
            from google import genai

            self._client = genai.Client(api_key=api_key, **kwargs)
        except ImportError as exc:
            # Chain the ImportError so the real cause stays in the traceback.
            raise SDKNotInstalledError("google-genai", "gemini") from exc

    async def generate(self, model: str, prompt: str, **kwargs: Any) -> LLMResponse:
        """Send ``prompt`` and return a unified response.

        Extra keyword arguments are forwarded to ``generate_content``.
        """
        try:
            # The new SDK exposes async calls under client.aio.
            response = await self._client.aio.models.generate_content(
                model=model, contents=prompt, **kwargs
            )
        except Exception as e:
            # Keep the try narrow: only SDK/API failures become APIError;
            # bugs in our own parsing below propagate unchanged.
            raise APIError(f"Gemini error: {str(e)}") from e

        text = response.text or ""

        usage = TokenUsage()
        metadata = getattr(response, "usage_metadata", None)
        if metadata:
            # The SDK may report token counts as None; coerce to 0 so the
            # TokenUsage model always validates.
            usage = TokenUsage(
                input_tokens=getattr(metadata, "prompt_token_count", 0) or 0,
                output_tokens=getattr(metadata, "candidates_token_count", 0) or 0,
                total_tokens=getattr(metadata, "total_token_count", 0) or 0,
            )

        return LLMResponse(
            text=text, provider="gemini", model=model, usage=usage, raw=response
        )
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
from typing import Any, NoReturn, Optional

from ..exceptions import (
    APIError,
    AuthenticationError,
    InvalidRequestError,
    RateLimitError,
    SDKNotInstalledError,
)
from ..schemas import LLMResponse, TokenUsage
from .base import BaseLLMClient


class OpenAIClient(BaseLLMClient):
    """OpenAI provider client.

    Wraps the official ``openai`` SDK's async client and normalizes chat
    completions and SDK errors into the multillm types.
    """

    def __init__(self, api_key: Optional[str] = None, **kwargs: Any):
        """Create the underlying ``openai.AsyncOpenAI`` client.

        Args:
            api_key: OpenAI API key; when ``None`` the SDK falls back to its
                own environment-based resolution.
            **kwargs: Passed through to ``openai.AsyncOpenAI``.

        Raises:
            SDKNotInstalledError: If the ``openai`` package is missing.
        """
        try:
            import openai

            self._client = openai.AsyncOpenAI(api_key=api_key, **kwargs)
        except ImportError as exc:
            # Chain the ImportError so the real cause stays in the traceback.
            raise SDKNotInstalledError("openai", "openai") from exc

    async def generate(self, model: str, prompt: str, **kwargs: Any) -> LLMResponse:
        """Send ``prompt`` as a single user message and return a unified response.

        Extra keyword arguments are forwarded to ``chat.completions.create``.
        """
        try:
            response = await self._client.chat.completions.create(
                model=model, messages=[{"role": "user", "content": prompt}], **kwargs
            )

            choice = response.choices[0]
            text = choice.message.content or ""

            # The SDK types ``usage`` as optional; fall back to zeroed counts
            # instead of crashing with AttributeError when it is absent.
            usage = TokenUsage()
            if response.usage is not None:
                usage = TokenUsage(
                    input_tokens=response.usage.prompt_tokens,
                    output_tokens=response.usage.completion_tokens,
                    total_tokens=response.usage.total_tokens,
                )

            return LLMResponse(
                text=text, provider="openai", model=model, usage=usage, raw=response
            )
        except Exception as e:
            return self._handle_error(e)

    def _handle_error(self, e: Exception) -> NoReturn:
        """Translate ``openai`` SDK exceptions into the multillm hierarchy."""
        import openai

        # Order matters: specific subclasses must be checked before the
        # catch-all openai.APIError. ``from e`` preserves the original
        # exception for debugging.
        if isinstance(e, openai.AuthenticationError):
            raise AuthenticationError(str(e)) from e
        if isinstance(e, openai.RateLimitError):
            raise RateLimitError(str(e)) from e
        if isinstance(e, openai.BadRequestError):
            raise InvalidRequestError(str(e)) from e
        if isinstance(e, openai.APIError):
            raise APIError(str(e)) from e
        raise APIError(f"Unexpected OpenAI error: {str(e)}") from e
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from typing import Any, Optional

from pydantic import BaseModel, Field


class TokenUsage(BaseModel):
    """Token usage statistics."""

    # Counts default to zero so providers that report no usage still validate.
    input_tokens: int = Field(default=0, description="Number of tokens in the input.")
    output_tokens: int = Field(default=0, description="Number of tokens in the output.")
    total_tokens: Optional[int] = Field(default=None, description="Total tokens used.")


class LLMResponse(BaseModel):
    """Unified response model for all LLM providers."""

    text: str = Field(description="The generated text response.")
    provider: str = Field(description="The name of the provider.")
    model: str = Field(description="The model name used.")
    # Defaults to an all-zero TokenUsage when a provider reports nothing.
    usage: TokenUsage = Field(
        default_factory=TokenUsage, description="Token usage details."
    )
    raw: Any = Field(default=None, description="The raw response from the provider's SDK.")
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Single source of truth for the package version; keep in sync with pyproject.toml.
__version__ = "0.1.0"
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from multillm import LLMResponse, create_client
from multillm.exceptions import ProviderNotFoundError, SDKNotInstalledError


def test_factory_invalid_provider():
    """Unknown provider names raise ProviderNotFoundError."""
    with pytest.raises(ProviderNotFoundError):
        create_client("invalid-provider")


@patch("multillm.providers.openai.OpenAIClient.__init__", return_value=None)
def test_factory_openai_creation(_mock_init):
    """The factory resolves the 'openai' key to a client instance."""
    assert create_client("openai", api_key="test") is not None


@pytest.mark.asyncio
async def test_openai_generate_mock():
    """generate() maps an OpenAI chat completion into an LLMResponse."""
    with patch("openai.AsyncOpenAI") as fake_cls:
        from multillm.providers.openai import OpenAIClient

        # Build the nested completion object the client reads.
        choice = MagicMock()
        choice.message.content = "Hehe"
        completion = MagicMock()
        completion.choices = [choice]
        completion.usage.prompt_tokens = 10
        completion.usage.completion_tokens = 5
        completion.usage.total_tokens = 15

        fake_cls.return_value.chat.completions.create = AsyncMock(
            return_value=completion
        )

        result = await OpenAIClient(api_key="test").generate(
            model="gpt-4", prompt="test"
        )

        assert isinstance(result, LLMResponse)
        assert result.text == "Hehe"
        assert result.provider == "openai"
        assert result.usage.input_tokens == 10


@pytest.mark.asyncio
async def test_anthropic_generate_mock():
    """generate() maps an Anthropic message into an LLMResponse."""
    with patch("anthropic.AsyncAnthropic") as fake_cls:
        from multillm.providers.anthropic import AnthropicClient

        block = MagicMock()
        block.text = "Hello from Claude"
        message = MagicMock()
        message.content = [block]
        message.usage.input_tokens = 20
        message.usage.output_tokens = 10

        fake_cls.return_value.messages.create = AsyncMock(return_value=message)

        result = await AnthropicClient(api_key="test").generate(
            model="claude-3", prompt="test"
        )

        assert result.text == "Hello from Claude"
        assert result.provider == "anthropic"


def test_sdk_not_installed():
    """An ImportError during client construction surfaces as SDKNotInstalledError."""
    from multillm.providers.openai import OpenAIClient

    with patch("openai.AsyncOpenAI", side_effect=ImportError):
        with pytest.raises(SDKNotInstalledError):
            OpenAIClient(api_key="test")


@pytest.mark.asyncio
async def test_gemini_generate_mock():
    """generate() maps a google-genai response into an LLMResponse."""
    with patch("google.genai.Client") as fake_cls:
        from multillm.providers.gemini import GeminiClient

        reply = MagicMock()
        reply.text = "Hello from Gemini"
        reply.usage_metadata.prompt_token_count = 10
        reply.usage_metadata.candidates_token_count = 5
        reply.usage_metadata.total_token_count = 15

        # Async calls live under client.aio in the new SDK.
        fake_cls.return_value.aio.models.generate_content = AsyncMock(
            return_value=reply
        )

        result = await GeminiClient(api_key="test").generate(
            model="gemini-pro", prompt="test"
        )

        assert result.text == "Hello from Gemini"
        assert result.provider == "gemini"
        assert result.usage.input_tokens == 10
        assert result.usage.output_tokens == 5
        assert result.usage.total_tokens == 15