chibi-bot 1.6.0b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chibi/__init__.py +0 -0
- chibi/__main__.py +343 -0
- chibi/cli.py +90 -0
- chibi/config/__init__.py +6 -0
- chibi/config/app.py +123 -0
- chibi/config/gpt.py +108 -0
- chibi/config/logging.py +15 -0
- chibi/config/telegram.py +43 -0
- chibi/config_generator.py +233 -0
- chibi/constants.py +362 -0
- chibi/exceptions.py +58 -0
- chibi/models.py +496 -0
- chibi/schemas/__init__.py +0 -0
- chibi/schemas/anthropic.py +20 -0
- chibi/schemas/app.py +54 -0
- chibi/schemas/cloudflare.py +65 -0
- chibi/schemas/mistralai.py +56 -0
- chibi/schemas/suno.py +83 -0
- chibi/service.py +135 -0
- chibi/services/bot.py +276 -0
- chibi/services/lock_manager.py +20 -0
- chibi/services/mcp/manager.py +242 -0
- chibi/services/metrics.py +54 -0
- chibi/services/providers/__init__.py +16 -0
- chibi/services/providers/alibaba.py +79 -0
- chibi/services/providers/anthropic.py +40 -0
- chibi/services/providers/cloudflare.py +98 -0
- chibi/services/providers/constants/suno.py +2 -0
- chibi/services/providers/customopenai.py +11 -0
- chibi/services/providers/deepseek.py +15 -0
- chibi/services/providers/eleven_labs.py +85 -0
- chibi/services/providers/gemini_native.py +489 -0
- chibi/services/providers/grok.py +40 -0
- chibi/services/providers/minimax.py +96 -0
- chibi/services/providers/mistralai_native.py +312 -0
- chibi/services/providers/moonshotai.py +20 -0
- chibi/services/providers/openai.py +74 -0
- chibi/services/providers/provider.py +892 -0
- chibi/services/providers/suno.py +130 -0
- chibi/services/providers/tools/__init__.py +23 -0
- chibi/services/providers/tools/cmd.py +132 -0
- chibi/services/providers/tools/common.py +127 -0
- chibi/services/providers/tools/constants.py +78 -0
- chibi/services/providers/tools/exceptions.py +1 -0
- chibi/services/providers/tools/file_editor.py +875 -0
- chibi/services/providers/tools/mcp_management.py +274 -0
- chibi/services/providers/tools/mcp_simple.py +72 -0
- chibi/services/providers/tools/media.py +451 -0
- chibi/services/providers/tools/memory.py +252 -0
- chibi/services/providers/tools/schemas.py +10 -0
- chibi/services/providers/tools/send.py +435 -0
- chibi/services/providers/tools/tool.py +163 -0
- chibi/services/providers/tools/utils.py +146 -0
- chibi/services/providers/tools/web.py +261 -0
- chibi/services/providers/utils.py +182 -0
- chibi/services/task_manager.py +93 -0
- chibi/services/user.py +269 -0
- chibi/storage/abstract.py +54 -0
- chibi/storage/database.py +86 -0
- chibi/storage/dynamodb.py +257 -0
- chibi/storage/local.py +70 -0
- chibi/storage/redis.py +91 -0
- chibi/utils/__init__.py +0 -0
- chibi/utils/app.py +249 -0
- chibi/utils/telegram.py +521 -0
- chibi_bot-1.6.0b0.dist-info/LICENSE +21 -0
- chibi_bot-1.6.0b0.dist-info/METADATA +340 -0
- chibi_bot-1.6.0b0.dist-info/RECORD +70 -0
- chibi_bot-1.6.0b0.dist-info/WHEEL +4 -0
- chibi_bot-1.6.0b0.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,242 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from contextlib import AsyncExitStack
|
|
3
|
+
from typing import Any, Callable
|
|
4
|
+
|
|
5
|
+
from loguru import logger
|
|
6
|
+
from mcp import ClientSession, StdioServerParameters
|
|
7
|
+
from mcp.client.sse import sse_client
|
|
8
|
+
from mcp.client.stdio import stdio_client
|
|
9
|
+
from mcp.types import CallToolResult
|
|
10
|
+
|
|
11
|
+
from chibi.services.task_manager import task_manager
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class MCPManager:
    """Manages the lifecycle of MCP server connections and sessions.

    Each connection runs inside its own background task (``_session_lifecycle``),
    so the transport and ``ClientSession`` context managers are entered and exited
    within the same task.

    Attributes:
        _sessions: Active MCP sessions, keyed by server name.
        _server_tasks: Background tasks keeping each connection open, keyed by server name.
        _lock: asyncio lock serializing connect/disconnect calls between coroutines
            (coroutine-level mutual exclusion, not a thread lock).
        _session_tools_map: Tool names registered for each server name.
    """

    _sessions: dict[str, ClientSession] = {}
    _server_tasks: dict[str, asyncio.Task] = {}
    _lock: asyncio.Lock = asyncio.Lock()
    _session_tools_map: dict[str, list[str]] = {}

    @classmethod
    async def _session_lifecycle(
        cls,
        name: str,
        transport_factory: Callable[[AsyncExitStack], Any],
        ready_event: asyncio.Event,
        init_timeout: float,
    ) -> None:
        """Run one MCP connection inside a dedicated background task.

        Opens the transport via ``transport_factory``, initializes a ``ClientSession``,
        registers it in ``_sessions`` and sets ``ready_event``. It then blocks until the
        task is cancelled, at which point the exit stack closes the session and transport.
        ``ready_event`` is also set when setup fails, so the waiting caller never hangs.

        Args:
            name: Server name the session is registered under.
            transport_factory: Coroutine function that enters the transport context on
                the given exit stack and returns a ``(read, write)`` stream pair.
            ready_event: Set once the session is registered, or when setup fails.
            init_timeout: Timeout in seconds for ``session.initialize()``.
        """
        session: ClientSession | None = None
        async with AsyncExitStack() as stack:
            try:
                # Initialize transport
                read, write = await transport_factory(stack)

                # Initialize session
                session = await stack.enter_async_context(ClientSession(read, write))
                await asyncio.wait_for(session.initialize(), timeout=init_timeout)

                # Register session
                cls._sessions[name] = session
                logger.log("TOOL", f"MCP server '{name}' connected and initialized.")
                ready_event.set()

                # Keep the session alive until cancelled: a Future that never
                # completes suspends this task without polling.
                await asyncio.Future()

            except asyncio.CancelledError:
                logger.log("TOOL", f"MCP server '{name}' session cancelled, cleaning up...")
                raise
            except Exception as e:
                logger.error(f"Error in MCP server '{name}' lifecycle: {e}")
                # If we failed during init, signal ready_event so the waiter doesn't hang forever.
                if not ready_event.is_set():
                    ready_event.set()
                raise
            finally:
                # Only unregister the session this task created. A newer connection
                # under the same name may already have replaced it (e.g. after a
                # timed-out connect cancelled this task without awaiting it).
                if session is not None and cls._sessions.get(name) is session:
                    cls._sessions.pop(name, None)
                    logger.log("TOOL", f"MCP server '{name}' disconnected.")

    @classmethod
    async def _connect(
        cls,
        name: str,
        transport_factory: Callable[[AsyncExitStack], Any],
        timeout: float,
    ) -> ClientSession:
        """Start a lifecycle task for ``name`` and wait until its session is ready.

        The caller must hold ``cls._lock``.

        Raises:
            RuntimeError: If the task cannot be scheduled or the session fails to initialize.
            asyncio.TimeoutError: If the server is not ready within ``timeout + 5`` seconds.
        """
        ready_event = asyncio.Event()

        task = task_manager.run_task(cls._session_lifecycle(name, transport_factory, ready_event, timeout))
        if not task:
            raise RuntimeError("Failed to schedule MCP connection task")

        cls._server_tasks[name] = task

        try:
            # The extra 5 seconds let the inner initialize() timeout fire first. This
            # outer wait also bounds transport setup, which has no timeout of its own.
            await asyncio.wait_for(ready_event.wait(), timeout=timeout + 5.0)
        except Exception as e:
            logger.error(f"An exception occurred while initializing MCP server '{name}': {e!r}")
            task.cancel()
            cls._server_tasks.pop(name, None)
            raise

        session = cls._sessions.get(name)
        if session is None:
            # ready_event was set by the failure path of _session_lifecycle.
            cls._server_tasks.pop(name, None)
            raise RuntimeError(f"Failed to initialize MCP server '{name}'")

        return session

    @classmethod
    async def connect_stdio(
        cls, name: str, command: str, args: list[str], env: dict[str, str] | None = None, timeout: float = 20.0
    ) -> ClientSession:
        """Connect to an MCP server via stdio transport.

        If a session with this name already exists, it is returned unchanged.

        Args:
            name: Unique name for the server session.
            command: Command to execute.
            args: Arguments for the command.
            env: Environment variables.
            timeout: Timeout in seconds for session initialization.

        Returns:
            The connected ClientSession.

        Raises:
            RuntimeError: If the connection task cannot be scheduled or initialization fails.
            asyncio.TimeoutError: If the server is not ready within ``timeout + 5`` seconds.
        """
        async with cls._lock:
            if name in cls._sessions:
                return cls._sessions[name]

            logger.log("TOOL", f"Connecting to MCP server '{name}' via stdio: {command} {' '.join(args)}")

            async def stdio_factory(stack: AsyncExitStack):
                server_params = StdioServerParameters(command=command, args=args, env=env)
                return await stack.enter_async_context(stdio_client(server_params))

            return await cls._connect(name, stdio_factory, timeout)

    @classmethod
    async def connect_sse(cls, name: str, url: str, timeout: float = 20.0) -> ClientSession:
        """Connect to an MCP server via SSE transport.

        If a session with this name already exists, it is returned unchanged.

        Args:
            name: Unique name for the server session.
            url: SSE endpoint URL.
            timeout: Timeout in seconds for session initialization.

        Returns:
            The connected ClientSession.

        Raises:
            RuntimeError: If the connection task cannot be scheduled or initialization fails.
            asyncio.TimeoutError: If the server is not ready within ``timeout + 5`` seconds.
        """
        async with cls._lock:
            if name in cls._sessions:
                return cls._sessions[name]

            logger.log("TOOL", f"Connecting to MCP server '{name}' via SSE: {url}")

            async def sse_factory(stack: AsyncExitStack):
                return await stack.enter_async_context(sse_client(url))

            return await cls._connect(name, sse_factory, timeout)

    @classmethod
    def associate_tools_with_server(cls, server_name: str, tool_names: list[str]) -> None:
        """Register the tool names associated with a server, replacing any previous list."""
        cls._session_tools_map[server_name] = tool_names

    @classmethod
    def pop_server_tools(cls, server_name: str) -> list[str]:
        """Remove and return the tool names associated with a server.

        Logs a warning and returns an empty list if none were registered.
        """
        if server_name not in cls._session_tools_map:
            logger.warning(f"No Tools registered for server {server_name}. Nothing to deregister.")
            return []
        return cls._session_tools_map.pop(server_name)

    @classmethod
    async def disconnect(cls, name: str) -> list[str]:
        """Disconnect an MCP server and clean up its session.

        Cancels the server's background task and waits for it to finish, then removes
        the session and the tools associated with the server.

        Returns:
            Tool names that were associated with the server. The list is empty if the
            server is unknown or has no registered tools.
        """
        if name not in cls._sessions and name not in cls._server_tasks:
            return []

        logger.log("TOOL", f"Disconnecting from MCP server '{name}'")
        async with cls._lock:
            # Cancel the background task; its finally block unregisters the session.
            task = cls._server_tasks.pop(name, None)
            if task:
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass
                except Exception as e:
                    logger.error(f"Error during MCP server '{name}' disconnect: {e}")

            deregistered_tools = cls.pop_server_tools(name)

            # Ensure the session is gone, even if the task died before we cancelled it.
            cls._sessions.pop(name, None)

            return deregistered_tools

    @classmethod
    def get_session(cls, name: str) -> ClientSession | None:
        """Return the active session registered under ``name``, or None."""
        return cls._sessions.get(name)

    @classmethod
    async def list_tools(cls, name: str):
        """List the tools available on the specified server.

        Raises:
            ValueError: If no active session exists for ``name``.
        """
        session = cls.get_session(name)
        if not session:
            raise ValueError(f"No active session for MCP server: {name}")
        return await session.list_tools()

    @classmethod
    async def call_tool(
        cls, server_name: str, tool_name: str, arguments: dict[str, Any], timeout: float = 600.0
    ) -> CallToolResult:
        """Call a tool on the specified server.

        Args:
            server_name: The server session name.
            tool_name: The tool to call.
            arguments: Tool arguments.
            timeout: Execution timeout in seconds.

        Returns:
            The tool execution result.

        Raises:
            ValueError: If no active session exists for ``server_name``.
            asyncio.TimeoutError: If the call does not finish within ``timeout``.
        """
        session = cls.get_session(server_name)
        if not session:
            raise ValueError(f"No active session for MCP server: {server_name}")
        return await asyncio.wait_for(session.call_tool(tool_name, arguments), timeout=timeout)
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
from influxdb_client import Point
|
|
2
|
+
from influxdb_client.client.influxdb_client_async import InfluxDBClientAsync
|
|
3
|
+
from loguru import logger
|
|
4
|
+
|
|
5
|
+
from chibi.config import application_settings
|
|
6
|
+
from chibi.models import User
|
|
7
|
+
from chibi.schemas.app import MetricTagsSchema, UsageSchema
|
|
8
|
+
from chibi.services.task_manager import task_manager
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class MetricsService:
    """Writes per-request token usage metrics to InfluxDB in the background."""

    # Usage fields that are not written as InfluxDB fields. These are presumably
    # nested detail objects rather than scalars; confirm against UsageSchema.
    _EXCLUDED_USAGE_FIELDS = {"completion_tokens_details", "prompt_tokens_details"}

    @classmethod
    def _prepare_point(cls, metric: UsageSchema, tags: MetricTagsSchema) -> Point:
        """Build a ``usage`` measurement point from a usage record and its tags.

        Tags with a ``None`` value are skipped, and every tag value is stringified.
        """
        point = Point("usage")

        for tag_name, value in tags.model_dump(exclude_none=True).items():
            point.tag(tag_name, str(value))

        for field_name, value in metric.model_dump(exclude=cls._EXCLUDED_USAGE_FIELDS).items():
            point.field(field_name, value)
        return point

    @classmethod
    async def _send_to_influx(cls, metric: UsageSchema, tags: MetricTagsSchema) -> None:
        """Write one usage point to InfluxDB.

        This runs as a fire-and-forget task, so every failure is logged and none is raised.
        """
        url = application_settings.influxdb_url
        token = application_settings.influxdb_token
        org = application_settings.influxdb_org
        bucket = application_settings.influxdb_bucket

        # Explicit check instead of `assert`: asserts are stripped under `python -O`.
        if not (url and token and org and bucket):
            logger.error("Failed to send metrics to InfluxDB: InfluxDB connection settings are incomplete.")
            return

        try:
            # Point preparation is inside the try too, so a bad value cannot escape the task.
            point = cls._prepare_point(metric=metric, tags=tags)
            async with InfluxDBClientAsync(url=url, token=token, org=org) as client:
                write_api = client.write_api()
                await write_api.write(bucket=bucket, record=point)
        except Exception as e:
            logger.error(f"Failed to send metrics to InfluxDB due to exception: {e}")

    @classmethod
    def send_usage_metrics(cls, metric: UsageSchema, model: str, provider: str, user: User | None = None) -> None:
        """Schedule a background write of usage metrics for one model call.

        Does nothing when InfluxDB is not configured. The ``user_id`` tag is 0 when no
        user is given.
        """
        if not application_settings.is_influx_configured:
            return None

        tags = MetricTagsSchema(
            user_id=user.id if user else 0,
            provider=provider,
            model=model,
        )
        task_manager.run_task(cls._send_to_influx(metric=metric, tags=tags))
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# flake8: noqa: F401
|
|
2
|
+
|
|
3
|
+
from chibi.services.providers.alibaba import Alibaba
|
|
4
|
+
from chibi.services.providers.anthropic import Anthropic
|
|
5
|
+
from chibi.services.providers.cloudflare import Cloudflare
|
|
6
|
+
from chibi.services.providers.customopenai import CustomOpenAI
|
|
7
|
+
from chibi.services.providers.deepseek import DeepSeek
|
|
8
|
+
from chibi.services.providers.eleven_labs import ElevenLabs
|
|
9
|
+
from chibi.services.providers.gemini_native import Gemini
|
|
10
|
+
from chibi.services.providers.grok import Grok
|
|
11
|
+
from chibi.services.providers.minimax import Minimax
|
|
12
|
+
from chibi.services.providers.mistralai_native import MistralAI
|
|
13
|
+
from chibi.services.providers.moonshotai import MoonshotAI
|
|
14
|
+
from chibi.services.providers.openai import OpenAI
|
|
15
|
+
from chibi.services.providers.provider import RegisteredProviders
|
|
16
|
+
from chibi.services.providers.suno import Suno
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
from http import HTTPStatus
|
|
2
|
+
|
|
3
|
+
import dashscope
|
|
4
|
+
from dashscope.aigc.image_generation import AioImageGeneration
|
|
5
|
+
from dashscope.api_entities.dashscope_response import Choice, ImageGenerationResponse, Message
|
|
6
|
+
|
|
7
|
+
from chibi.config import gpt_settings
|
|
8
|
+
from chibi.exceptions import ServiceResponseError
|
|
9
|
+
from chibi.schemas.app import ModelChangeSchema
|
|
10
|
+
from chibi.services.providers.provider import OpenAIFriendlyProvider
|
|
11
|
+
|
|
12
|
+
# Point the DashScope SDK (used below for native image generation) at the
# international endpoint. This mutates SDK module state at import time, so it
# presumably affects every dashscope call made in this process.
dashscope.base_http_api_url = "https://dashscope-intl.aliyuncs.com/api/v1"
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class Alibaba(OpenAIFriendlyProvider):
    """Alibaba Cloud (DashScope) provider.

    Chat goes through the OpenAI-compatible endpoint; image generation uses the
    native DashScope SDK.
    """

    api_key = gpt_settings.alibaba_key
    chat_ready = True
    image_generation_ready = True
    moderation_ready = True

    base_url = "https://dashscope-intl.aliyuncs.com/compatible-mode/v1"
    name = "Alibaba"
    model_name_keywords = ["qwen"]
    model_name_keywords_exclude = ["tts", "stt"]
    default_model = "qwen-plus"
    default_moderation_model = "qwen-turbo"
    max_tokens: int = 8192
    default_image_model = "qwen-image-plus"

    @staticmethod
    def _get_image_url(choice: Choice) -> str | None:
        """Extract the generated image URL from a response choice.

        The message content is either the URL string itself or a list of content
        items. For a list, the first item with a non-empty ``"image"`` key wins;
        None is returned if there is no such item (including an empty list).
        """
        content = choice.message.content
        if isinstance(content, str):
            return content
        for item in content or []:
            if isinstance(item, dict) and (url := item.get("image")):
                return url
        return None

    async def get_images(self, prompt: str, model: str | None = None) -> list[str]:
        """Generate images for ``prompt`` and return their URLs.

        Args:
            prompt: Text prompt for the image.
            model: Image model name. Defaults to ``default_image_model``.

        Raises:
            ServiceResponseError: If DashScope responds with a non-200 status.
        """
        # Fall back to the *image* model. The chat default ("qwen-plus") cannot generate images.
        model = model or self.default_image_model
        message = Message(role="user", content=[{"text": prompt}])
        # Qwen and z-image models are requested with a single image per call.
        number_of_images = 1 if "qwen" in model or "z-image" in model else gpt_settings.image_n_choices
        response: ImageGenerationResponse = await AioImageGeneration.call(
            api_key=self.token,
            model=model,
            messages=[message],
            n=number_of_images,
            size=gpt_settings.image_size_alibaba,
            prompt_extend=True,
            watermark=False,
        )

        if response.status_code != HTTPStatus.OK:
            raise ServiceResponseError(
                provider=self.name,
                model=model,
                detail=(
                    f"Unexpected response status code: {response.status_code}. "
                    f"Response code: {response.code}. Message: {response.message}"
                ),
            )
        return [url for choice in response.output.choices if (url := self._get_image_url(choice))]

    async def get_available_models(self, image_generation: bool = False) -> list[ModelChangeSchema]:
        """Return the models reported by the base provider.

        When ``image_generation`` is set, the Wan 2.6 text-to-image model is also appended.
        """
        models = await super().get_available_models(image_generation=image_generation)

        if image_generation:
            wan_models = [
                ModelChangeSchema(
                    provider=self.name,
                    name="wan2.6-t2i",
                    display_name="Wan 2.6",
                    image_generation=True,
                ),
            ]

            models += wan_models
        return models
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from anthropic import AsyncClient
|
|
2
|
+
|
|
3
|
+
from chibi.config import gpt_settings
|
|
4
|
+
from chibi.exceptions import NoApiKeyProvidedError
|
|
5
|
+
from chibi.services.providers.provider import AnthropicFriendlyProvider
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Anthropic(AnthropicFriendlyProvider):
    """Anthropic (Claude) provider built on the official async SDK client."""

    # Configuration and capabilities
    api_key = gpt_settings.anthropic_key
    chat_ready = True
    moderation_ready = True

    # Identity and model defaults
    name = "Anthropic"
    model_name_keywords = ["claude"]
    default_model = "claude-sonnet-4-5-20250929"
    default_moderation_model = "claude-haiku-4-5-20251001"

    def __init__(self, token: str) -> None:
        # The SDK client is built lazily by the ``client`` property.
        self._client: AsyncClient | None = None
        super().__init__(token=token)

    @property
    def client(self) -> AsyncClient:
        """Return the cached SDK client, creating it on first access.

        Raises:
            NoApiKeyProvidedError: If no API token is set.
        """
        if not self._client:
            if not self.token:
                raise NoApiKeyProvidedError(provider=self.name)
            self._client = AsyncClient(api_key=self.token)
        return self._client

    @property
    def _headers(self) -> dict[str, str]:
        """Headers for raw JSON requests to the Anthropic REST API."""
        json_mime = "application/json"
        return {
            "Accept": json_mime,
            "Content-Type": json_mime,
            "x-api-key": self.token,
            "anthropic-version": "2023-06-01",
        }
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
from loguru import logger
|
|
2
|
+
from openai.types.chat import ChatCompletionMessageParam
|
|
3
|
+
|
|
4
|
+
from chibi.config import gpt_settings
|
|
5
|
+
from chibi.exceptions import NoAccountIDSetError
|
|
6
|
+
from chibi.schemas.app import ChatResponseSchema, ModelChangeSchema, UsageSchema
|
|
7
|
+
from chibi.schemas.cloudflare import (
|
|
8
|
+
ChatCompletionResponseSchema,
|
|
9
|
+
ModelsSearchResponseSchema,
|
|
10
|
+
)
|
|
11
|
+
from chibi.services.providers.provider import RestApiFriendlyProvider
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class Cloudflare(RestApiFriendlyProvider):
    """Cloudflare Workers AI provider, accessed through the REST API."""

    api_key = gpt_settings.cloudflare_key
    chat_ready = True

    name = "Cloudflare"
    model_name_keywords = ["@cf", "@hf"]
    default_model = "@cf/meta/llama-3.2-3b-instruct"
    # Built at import time. If the account ID is unset, the URL contains "None",
    # which is why both request methods check the setting first.
    base_url = f"https://api.cloudflare.com/client/v4/accounts/{gpt_settings.cloudflare_account_id}/ai"

    @property
    def _headers(self) -> dict[str, str]:
        """Bearer-token authorization header for the Cloudflare API."""
        return {
            "Authorization": f"Bearer {self.token}",
        }

    async def _get_chat_completion_response(
        self,
        messages: list[ChatCompletionMessageParam],
        model: str,
        system_prompt: str | None = None,
    ) -> ChatResponseSchema:
        """Run a chat completion against ``model`` via the Workers AI ``/run`` endpoint.

        If a ``system_prompt`` is given and not already among ``messages``, it is
        prepended as a system message. An unsuccessful API response yields an empty
        answer with no usage.

        Raises:
            NoAccountIDSetError: If the Cloudflare account ID is not configured.
        """
        if not gpt_settings.cloudflare_account_id:
            raise NoAccountIDSetError

        url = f"{self.base_url}/run/{model}"

        dialog: list[dict] = [dict(m) for m in messages]
        # Without a prompt, don't send a system message with null content.
        if system_prompt:
            system_message = {"role": "system", "content": system_prompt}
            if system_message not in messages:
                dialog.insert(0, system_message)

        response = await self._request(method="POST", url=url, data={"messages": dialog})

        response_data = ChatCompletionResponseSchema(**response.json())
        if not response_data.success:
            logger.warning(f"Provider {self.name} returned an unsuccessful response for model {model}")
            return ChatResponseSchema(answer="", provider=self.name, model=model, usage=None)

        answer_data = response_data.result
        usage = UsageSchema(
            completion_tokens=answer_data.usage.completion_tokens,
            prompt_tokens=answer_data.usage.prompt_tokens,
            total_tokens=answer_data.usage.total_tokens,
        )
        return ChatResponseSchema(answer=answer_data.response, provider=self.name, model=model, usage=usage)

    async def get_available_models(self, image_generation: bool = False) -> list[ModelChangeSchema]:
        """Return the Workers AI text-generation models, sorted by name.

        If a models whitelist is configured, the list is filtered by it. An empty
        list is returned for image generation, when the account ID is missing, or
        when the request or response parsing fails.
        """
        if image_generation:
            return []

        if not gpt_settings.cloudflare_account_id:
            logger.error("No Cloudflare account ID set. Please, check the CLOUDFLARE_ACCOUNT_ID env value.")
            return []

        url = f"{self.base_url}/models/search"
        params = {"task": "Text Generation"}

        try:
            response = await self._request(method="GET", url=url, params=params)
            # Parsing is inside the try, so a malformed payload degrades to [] instead of raising.
            response_data = ModelsSearchResponseSchema(**response.json())
        except Exception as e:
            logger.error(f"Failed to get available models for provider {self.name} due to exception: {e}")
            return []

        all_models = [
            ModelChangeSchema(
                provider=self.name,
                name=model.name,
                image_generation=False,
            )
            for model in response_data.result
        ]
        all_models.sort(key=lambda schema: schema.name)

        if gpt_settings.models_whitelist:
            return [model for model in all_models if model.name in gpt_settings.models_whitelist]

        return all_models
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
from chibi.config import gpt_settings
|
|
2
|
+
from chibi.services.providers.provider import OpenAIFriendlyProvider
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class CustomOpenAI(OpenAIFriendlyProvider):
    """Provider for a user-configured OpenAI-compatible endpoint.

    Both the endpoint URL and the API key come from settings. ``default_model``
    is left as an empty string.
    """

    name = "CustomOpenAI"

    # Endpoint and credentials from configuration
    base_url = gpt_settings.customopenai_url
    api_key = gpt_settings.customopenai_key

    default_model = ""

    chat_ready = True
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from chibi.config import gpt_settings
|
|
2
|
+
from chibi.services.providers.provider import OpenAIFriendlyProvider
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class DeepSeek(OpenAIFriendlyProvider):
    """DeepSeek provider using its OpenAI-compatible API."""

    name = "DeepSeek"
    base_url = "https://api.deepseek.com"
    api_key = gpt_settings.deepseek_key

    # Supported features
    chat_ready = True
    moderation_ready = True

    # Model selection: the same model serves both chat and moderation.
    model_name_keywords = ["deepseek"]
    default_model = "deepseek-chat"
    default_moderation_model = "deepseek-chat"
    max_tokens: int = 8192
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
from io import BytesIO
|
|
2
|
+
|
|
3
|
+
from elevenlabs.client import AsyncElevenLabs
|
|
4
|
+
from loguru import logger
|
|
5
|
+
|
|
6
|
+
from chibi.config import application_settings, gpt_settings
|
|
7
|
+
from chibi.exceptions import NoApiKeyProvidedError
|
|
8
|
+
from chibi.schemas.app import ModelChangeSchema
|
|
9
|
+
from chibi.services.providers.provider import Provider
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ElevenLabs(Provider):
    """ElevenLabs provider for text-to-speech, speech-to-text and music generation."""

    api_key = gpt_settings.elevenlabs_api_key
    chat_ready = False
    tts_ready = True
    stt_ready = True

    name = "ElevenLabs"
    default_stt_model = "scribe_v1"
    tag_audio_events = False
    # None is passed through to the STT call (presumably lets the API auto-detect the language).
    language_code = None
    default_tts_voice = "cgSgspJ2msm6clMCkdW9"
    default_tts_model = "eleven_multilingual_v2"
    output_format = "mp3_44100_128"
    music_length_ms = 10000

    def __init__(self, token: str) -> None:
        # The SDK client is created lazily by the ``client`` property.
        self._client: AsyncElevenLabs | None = None
        super().__init__(token=token)

    @property
    def client(self) -> AsyncElevenLabs:
        """Return the cached SDK client, creating it on first access.

        Raises:
            NoApiKeyProvidedError: If no API token is set.
        """
        if self._client:
            return self._client

        if not self.token:
            raise NoApiKeyProvidedError(provider=self.name)

        self._client = AsyncElevenLabs(
            api_key=self.token,
        )

        return self._client

    async def speech(self, text: str, voice: str | None = None, model: str | None = None) -> bytes:
        """Synthesize ``text`` to audio bytes in ``output_format``.

        ``voice`` and ``model`` default to ``default_tts_voice`` and ``default_tts_model``.
        """
        voice = voice or self.default_tts_voice
        model = model or self.default_tts_model
        logger.info(f"Recording a voice message with model {model}...")
        # Not awaited: the async SDK returns an async iterator of audio chunks.
        response = self.client.text_to_speech.convert(
            text=text,
            voice_id=voice,
            model_id=model,
            output_format=self.output_format,
        )

        buf = bytearray()
        async for chunk in response:
            buf.extend(chunk)

        return bytes(buf)

    async def transcribe(self, audio: BytesIO, model: str | None = None) -> str:
        """Transcribe ``audio`` to text. ``model`` defaults to ``default_stt_model``.

        Raises:
            ValueError: If the API returns no response.
        """
        model = model or self.default_stt_model
        logger.info(f"Transcribing audio with model {model}...")
        response = await self.client.speech_to_text.convert(
            file=audio, model_id=model, tag_audio_events=self.tag_audio_events, language_code=self.language_code
        )
        if response:
            if application_settings.log_prompt_data:
                logger.info(f"Transcribed text: {response.text}")
            return response.text
        raise ValueError("Could not transcribe audio message")

    async def generate_music(self, prompt: str, music_length_ms: int | None = None) -> bytes:
        """Compose music for ``prompt`` and return the audio bytes.

        Args:
            prompt: Description of the music to generate.
            music_length_ms: Track length in milliseconds. Defaults to
                ``self.music_length_ms``. It is resolved at call time, so instance or
                subclass overrides take effect. (A default bound at definition time would
                freeze the class value.)
        """
        length_ms = self.music_length_ms if music_length_ms is None else music_length_ms
        logger.info(f"Generating music for {length_ms} ms...")
        response = self.client.music.compose(prompt=prompt, music_length_ms=length_ms)

        buf = bytearray()
        async for chunk in response:
            buf.extend(chunk)

        return bytes(buf)

    async def get_available_models(self, image_generation: bool = False) -> list[ModelChangeSchema]:
        """Return no models: this provider is not used for chat or image model selection."""
        return []
|