chibi-bot 1.6.0b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chibi/__init__.py +0 -0
- chibi/__main__.py +343 -0
- chibi/cli.py +90 -0
- chibi/config/__init__.py +6 -0
- chibi/config/app.py +123 -0
- chibi/config/gpt.py +108 -0
- chibi/config/logging.py +15 -0
- chibi/config/telegram.py +43 -0
- chibi/config_generator.py +233 -0
- chibi/constants.py +362 -0
- chibi/exceptions.py +58 -0
- chibi/models.py +496 -0
- chibi/schemas/__init__.py +0 -0
- chibi/schemas/anthropic.py +20 -0
- chibi/schemas/app.py +54 -0
- chibi/schemas/cloudflare.py +65 -0
- chibi/schemas/mistralai.py +56 -0
- chibi/schemas/suno.py +83 -0
- chibi/service.py +135 -0
- chibi/services/bot.py +276 -0
- chibi/services/lock_manager.py +20 -0
- chibi/services/mcp/manager.py +242 -0
- chibi/services/metrics.py +54 -0
- chibi/services/providers/__init__.py +16 -0
- chibi/services/providers/alibaba.py +79 -0
- chibi/services/providers/anthropic.py +40 -0
- chibi/services/providers/cloudflare.py +98 -0
- chibi/services/providers/constants/suno.py +2 -0
- chibi/services/providers/customopenai.py +11 -0
- chibi/services/providers/deepseek.py +15 -0
- chibi/services/providers/eleven_labs.py +85 -0
- chibi/services/providers/gemini_native.py +489 -0
- chibi/services/providers/grok.py +40 -0
- chibi/services/providers/minimax.py +96 -0
- chibi/services/providers/mistralai_native.py +312 -0
- chibi/services/providers/moonshotai.py +20 -0
- chibi/services/providers/openai.py +74 -0
- chibi/services/providers/provider.py +892 -0
- chibi/services/providers/suno.py +130 -0
- chibi/services/providers/tools/__init__.py +23 -0
- chibi/services/providers/tools/cmd.py +132 -0
- chibi/services/providers/tools/common.py +127 -0
- chibi/services/providers/tools/constants.py +78 -0
- chibi/services/providers/tools/exceptions.py +1 -0
- chibi/services/providers/tools/file_editor.py +875 -0
- chibi/services/providers/tools/mcp_management.py +274 -0
- chibi/services/providers/tools/mcp_simple.py +72 -0
- chibi/services/providers/tools/media.py +451 -0
- chibi/services/providers/tools/memory.py +252 -0
- chibi/services/providers/tools/schemas.py +10 -0
- chibi/services/providers/tools/send.py +435 -0
- chibi/services/providers/tools/tool.py +163 -0
- chibi/services/providers/tools/utils.py +146 -0
- chibi/services/providers/tools/web.py +261 -0
- chibi/services/providers/utils.py +182 -0
- chibi/services/task_manager.py +93 -0
- chibi/services/user.py +269 -0
- chibi/storage/abstract.py +54 -0
- chibi/storage/database.py +86 -0
- chibi/storage/dynamodb.py +257 -0
- chibi/storage/local.py +70 -0
- chibi/storage/redis.py +91 -0
- chibi/utils/__init__.py +0 -0
- chibi/utils/app.py +249 -0
- chibi/utils/telegram.py +521 -0
- chibi_bot-1.6.0b0.dist-info/LICENSE +21 -0
- chibi_bot-1.6.0b0.dist-info/METADATA +340 -0
- chibi_bot-1.6.0b0.dist-info/RECORD +70 -0
- chibi_bot-1.6.0b0.dist-info/WHEEL +4 -0
- chibi_bot-1.6.0b0.dist-info/entry_points.txt +3 -0
chibi/services/providers/utils.py
ADDED
@@ -0,0 +1,182 @@
import inspect
import json
from typing import Any, Callable, Coroutine, ParamSpec, Type, TypeAlias, TypeVar

from anthropic.types import (
    Message as AnthropicMessage,
)
from google.genai.types import GenerateContentResponse
from mistralai import ChatCompletionResponse
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletion
from telegram import Update, constants
from telegram.ext import ContextTypes

from chibi.config import gpt_settings
from chibi.models import User
from chibi.schemas.app import UsageSchema
from chibi.schemas.suno import SunoGetGenerationDetailsSchema
from chibi.utils.app import get_builtin_skill_names

T = TypeVar("T")
P = ParamSpec("P")
M = TypeVar("M", bound=Callable[..., Coroutine[Any, Any, Any]])
AsyncFunc: TypeAlias = Callable[P, Coroutine[Any, Any, T]]


def decorate_async_methods(decorator: Callable[[M], M]) -> Callable[[Type[T]], Type[T]]:
    def decorate(cls: Type[T]) -> Type[T]:
        for attr in cls.__dict__:
            if inspect.iscoroutinefunction(getattr(cls, attr)):
                original_func = getattr(cls, attr)
                decorated_func = decorator(original_func)
                setattr(cls, attr, decorated_func)
        return cls

    return decorate


def escape_and_truncate(message: str | dict[str, Any] | list[dict[str, Any]] | None, limit: int = 50) -> str:
    if not message:
        return "no data"

    if isinstance(message, dict):
        return json.dumps({k: escape_and_truncate(message=v, limit=limit) for k, v in message.items()})

    if isinstance(message, list):
        return json.dumps([escape_and_truncate(message=m, limit=limit) for m in message])

    escaped_message = str(message).replace("<", r"\<").replace(">", r"\>")
    if len(escaped_message) < limit + 20:
        return escaped_message
    return f"{escaped_message[:limit]}... (truncated)"


async def prepare_system_prompt(base_system_prompt: str, user: User | None = None) -> str:
    if user is None:
        return base_system_prompt

    prompt = {
        "current_working_dir": user.working_dir,
        "user_id": user.id,
        "user_info": user.info,
        "system_prompt": base_system_prompt,
        "available_builtin_skills": get_builtin_skill_names(),
        "activated_skills": user.llm_skills,
    }
    return json.dumps(prompt)


async def send_llm_thoughts(thoughts: str, update: Update | None, context: ContextTypes.DEFAULT_TYPE | None) -> None:
    if not gpt_settings.show_llm_thoughts:
        return None

    from chibi.utils.telegram import send_long_message

    if update is None or context is None:
        return None
    message = f"💡💭 {thoughts}"
    await send_long_message(
        message=message,
        update=update,
        context=context,
        parse_mode=constants.ParseMode.MARKDOWN_V2,
        reply=False,
    )


def get_usage_from_anthropic_response(response_message: AnthropicMessage) -> UsageSchema:
    return UsageSchema(
        completion_tokens=response_message.usage.output_tokens,
        prompt_tokens=response_message.usage.input_tokens,
        cache_creation_input_tokens=response_message.usage.cache_creation_input_tokens or 0,
        cache_read_input_tokens=response_message.usage.cache_read_input_tokens or 0,
        total_tokens=response_message.usage.output_tokens + response_message.usage.input_tokens,
    )


def get_usage_from_openai_response(response_message: ChatCompletion) -> UsageSchema:
    if response_message.usage is None:
        return UsageSchema()
    response_usage = response_message.usage
    usage = UsageSchema(
        completion_tokens=response_usage.completion_tokens,
        prompt_tokens=response_usage.prompt_tokens,
        total_tokens=response_usage.total_tokens,
    )
    if prompt_cache := response_usage.prompt_tokens_details:
        usage.cache_read_input_tokens = prompt_cache.cached_tokens or 0
    return usage


def get_usage_from_google_response(response_message: GenerateContentResponse) -> UsageSchema:
    if not response_message.usage_metadata:
        return UsageSchema()

    return UsageSchema(
        total_tokens=response_message.usage_metadata.total_token_count or 0,
        completion_tokens=response_message.usage_metadata.candidates_token_count or 0,
        prompt_tokens=response_message.usage_metadata.prompt_token_count or 0,
        cache_read_input_tokens=response_message.usage_metadata.cached_content_token_count or 0,
    )


def get_usage_from_mistral_response(response_message: ChatCompletionResponse) -> UsageSchema:
    return UsageSchema(
        completion_tokens=response_message.usage.completion_tokens or 0,
        prompt_tokens=response_message.usage.prompt_tokens or 0,
        cache_creation_input_tokens=0,
        cache_read_input_tokens=0,
        total_tokens=response_message.usage.total_tokens or 0,
    )


def get_usage_msg(usage: UsageSchema | CompletionUsage | None) -> str:
    if usage is None:
        return ""
    cache_read = getattr(usage, "cache_read_input_tokens", None)
    cache_create = getattr(usage, "cache_creation_input_tokens", None)
    return (
        f"Tokens used: {getattr(usage, 'total_tokens', None) or 'n/a'} "
        f"({getattr(usage, 'prompt_tokens', None)} prompt, "
        f"{getattr(usage, 'completion_tokens', None)} completion, "
        f"{cache_read or 0} cached read/prompt, "
        f"{cache_create or 0} cached creation)"
    )
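
Aside: because `get_usage_msg` relies only on `getattr`, it accepts either schema. A quick illustrative call with an OpenAI `CompletionUsage` (values invented):

from openai.types import CompletionUsage

usage = CompletionUsage(completion_tokens=42, prompt_tokens=100, total_tokens=142)
print(get_usage_msg(usage))
# Tokens used: 142 (100 prompt, 42 completion, 0 cached read/prompt, 0 cached creation)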


def suno_task_still_processing(task_data_response: SunoGetGenerationDetailsSchema) -> bool:
    return task_data_response.is_in_progress


# def limit_recursion(
#     max_depth: int = application_settings.max_consecutive_tool_calls,
# ) -> Callable[[AsyncFunc[P, T]], AsyncFunc[P, T]]:
#     def decorator(func: AsyncFunc[P, T]) -> AsyncFunc[P, T]:
#         depth_var: ContextVar[int] = ContextVar(f"{func.__name__}_depth", default=0)
#
#         @wraps(func)
#         async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
#             current_depth = depth_var.get()
#             depth_var.set(current_depth + 1)
#             if depth_var.get() > max_depth + 1:
#                 depth_var.set(current_depth)
#                 class_name = ""
#                 if args and hasattr(args[0], "__class__"):
#                     class_name = f"{args[0].__class__.__name__}."
#                 raise RecursionLimitExceeded(
#                     provider=class_name,
#                     model=cast(str, kwargs.get("model", "unknown")),
#                     detail=f"Recursion depth exceeded: {max_depth} (function: {class_name}{func.__name__})",
#                     exceeded_limit=max_depth,
#                 )
#
#             try:
#                 result = await func(*args, **kwargs)
#                 return result
#             finally:
#                 depth_var.set(current_depth)
#
#         return async_wrapper
#
#     return decorator
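
Aside: `decorate_async_methods` above rewrites every coroutine method on a class in place. A small usage sketch, assuming the definition above is in scope (`log_calls` and `Demo` are invented for illustration):

import asyncio
import functools


def log_calls(func):
    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        print(f"calling {func.__name__}")
        return await func(*args, **kwargs)

    return wrapper


@decorate_async_methods(log_calls)
class Demo:
    async def ping(self) -> str:
        return "pong"

    def sync_ping(self) -> str:  # not a coroutine function, left untouched
        return "pong"


assert asyncio.run(Demo().ping()) == "pong"  # prints "calling ping" first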
chibi/services/task_manager.py
ADDED
@@ -0,0 +1,93 @@
import asyncio
from typing import Any, Coroutine

from loguru import logger

from chibi.utils.app import SingletonMeta


class BackgroundTaskManager(metaclass=SingletonMeta):
    def __init__(self) -> None:
        """Initialize the task manager."""
        if not hasattr(self, "_tasks"):
            self._tasks: set[asyncio.Task] = set()
            self._shutting_down: bool = False

    async def _wrap_with_timeout(self, coro: Coroutine[Any, Any, Any], timeout: float) -> Any:
        """Wrap a coroutine with a timeout.

        Args:
            coro: The coroutine to wrap.
            timeout: Timeout in seconds.

        Returns:
            The wrapped coroutine result.
        """
        try:
            return await asyncio.wait_for(fut=coro, timeout=timeout)
        except asyncio.TimeoutError:
            logger.warning(f"Background task timed out after {timeout}s")
            raise

    def run_task(self, coro: Coroutine[Any, Any, Any], timeout: float | None = None) -> asyncio.Task | None:
        """Schedule a coroutine to run in the background.

        Args:
            coro: The coroutine to run.
            timeout: Optional timeout in seconds. If the task doesn't complete
                within this time, it is cancelled with a TimeoutError.

        Returns:
            The scheduled task, or None if the manager is shutting down.
        """
        if self._shutting_down:
            logger.warning("Task manager is shutting down, refusing new task")
            coro.close()
            return None

        # Wrap with timeout if specified
        if timeout is not None:
            coro = self._wrap_with_timeout(coro, timeout)

        task = asyncio.create_task(coro)
        self._tasks.add(task)
        task.add_done_callback(self._discard_task)
        return task

    def _discard_task(self, task: asyncio.Task) -> None:
        """Callback to remove a task from the set when it finishes."""
        try:
            exc = task.exception()
            if exc:
                logger.error(
                    f"Background task '{task.get_name()}' failed: {exc.__class__.__name__} ({str(exc) or 'no details'})"
                )
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.error(f"Error checking background task result: {e}")
        finally:
            self._tasks.discard(task)

    async def shutdown(self, *args: Any) -> None:
        """
        Wait for all background tasks to complete with a timeout.
        If tasks do not finish within 5 seconds, they are cancelled.
        """
        logger.info("Shutting down background tasks...")
        self._shutting_down = True
        tasks_to_wait = list(self._tasks)
        if not tasks_to_wait:
            logger.info("No background tasks to wait for, we're good.")
            return None

        logger.info(f"Waiting for {len(tasks_to_wait)} background tasks to complete...")
        try:
            await asyncio.wait_for(asyncio.gather(*tasks_to_wait, return_exceptions=True), timeout=5.0)
            logger.info("All background tasks completed.")

        except asyncio.TimeoutError:
            logger.warning("Timeout reached. Cancelling remaining background tasks...")
            remaining = [t for t in tasks_to_wait if not t.done()]
            for task in remaining:
                task.cancel()
            logger.info(f"Cancelled {len(remaining)} remaining tasks.")


task_manager = BackgroundTaskManager()
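
Aside: two details in `run_task` are easy to miss. `asyncio.create_task` keeps only a weak reference to the task, so `_tasks` holds a strong one until the done callback discards it, and the timeout wraps the coroutine rather than the task. A minimal stand-alone sketch of the same pattern (all names invented):

import asyncio


async def slow_job() -> str:
    await asyncio.sleep(0.2)
    return "done"


async def main() -> None:
    tasks: set[asyncio.Task] = set()

    # Keep a strong reference so the task can't be garbage-collected mid-flight,
    # and drop it automatically once it finishes.
    task = asyncio.create_task(asyncio.wait_for(slow_job(), timeout=1.0))
    tasks.add(task)
    task.add_done_callback(tasks.discard)

    # Drain on shutdown: wait briefly, then cancel stragglers.
    try:
        await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=5.0)
    except asyncio.TimeoutError:
        for t in tasks:
            t.cancel()


asyncio.run(main())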
chibi/services/user.py
ADDED
@@ -0,0 +1,269 @@
import datetime
import json
from copy import deepcopy
from datetime import timezone
from io import BytesIO
from typing import TYPE_CHECKING, Any, Optional

from aiocache import cached
from telegram import Update
from telegram.ext import ContextTypes

from chibi.config import gpt_settings
from chibi.models import Message, User
from chibi.schemas.app import ChatResponseSchema, ModelChangeSchema
from chibi.services.lock_manager import LockManager
from chibi.storage.abstract import Database
from chibi.storage.database import inject_database

if TYPE_CHECKING:
    from chibi.services.providers.provider import Provider
    from chibi.services.providers.tools import ToolResponse


@inject_database
async def get_chibi_user(db: Database, user_id: int) -> User:
    return await db.get_or_create_user(user_id=user_id)


@inject_database
async def set_active_model(db: Database, user_id: int, model: ModelChangeSchema) -> None:
    user = await db.get_or_create_user(user_id=user_id)
    if model.image_generation:
        user.selected_image_model_name = model.name
        user.selected_image_provider_name = model.provider
    else:
        user.selected_gpt_model_name = model.name
        user.selected_gpt_provider_name = model.provider
    await db.save_user(user)


@inject_database
async def reset_chat_history(db: Database, user_id: int) -> None:
    user = await db.get_or_create_user(user_id=user_id)
    await db.drop_messages(user=user)


@inject_database
async def emergency_summarization(db: Database, user_id: int) -> None:
    user = await db.get_or_create_user(user_id=user_id)

    chat_history = await db.get_messages(user=user)
    # List comprehension, not a bare generator: str() of a generator would yield
    # "<generator object ...>" instead of the conversation text.
    chat_history_string = str(
        [msg for msg in chat_history if not any((msg.get("tool_calls"), msg.get("tool_call_id")))]
    )
    user_messages: list[Message] = [Message(role="user", content=chat_history_string)]

    response, _ = await user.active_gpt_provider.get_chat_response(
        messages=user_messages,
        user=user,
        system_prompt="Summarize this conversation in English, keeping the most important and useful information.",
    )
    initial_message = Message(role="user", content="What were we talking about?")
    answer_message = Message(role="assistant", content=response.answer)
    await reset_chat_history(user_id=user_id)
    await db.add_message(user=user, message=initial_message, ttl=gpt_settings.messages_ttl)
    await db.add_message(user=user, message=answer_message, ttl=gpt_settings.messages_ttl)


@inject_database
async def get_llm_chat_completion_answer(
    db: Database,
    user_id: int,
    update: Update,
    context: ContextTypes.DEFAULT_TYPE,
    user_text_message: str | None = None,
    user_voice_message: BytesIO | None = None,
    tool_message: Optional["ToolResponse"] = None,
) -> ChatResponseSchema:
    user = await db.get_or_create_user(user_id=user_id)
    lock = await LockManager().get_lock(key=user_id)

    if not user_text_message and not user_voice_message and not tool_message:
        raise ValueError("No prompt data provided")

    if user_voice_message and not user.stt_provider:
        raise ValueError("Can't process voice message: no STT provider available.")

    prompt: dict[str, Any]

    if tool_message:
        prompt = {
            "type": "tool response",
            "desc": "background task is done",
            "tool_name": tool_message.tool_name,
            "tool_response": tool_message.model_dump(),
            "datetime_now": datetime.datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z%z"),
        }
    else:
        user_message = (
            await user.stt_provider.transcribe(audio=user_voice_message) if user_voice_message else user_text_message
        )
        assert user_message
        prompt = {
            "prompt": user_message,
            "datetime_now": datetime.datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z%z"),
            "type": "user message",
            "transcribed_from_voice_message": bool(user_voice_message),
        }

    async with lock:
        conversation_messages: list[Message] = await db.get_conversation_messages(user=user)
        new_message_to_llm = Message(role="user", content=json.dumps(prompt))
        conversation_messages.append(new_message_to_llm)

        chat_response, new_messages = await user.active_gpt_provider.get_chat_response(
            messages=conversation_messages,
            user=user,
            model=user.selected_gpt_model_name,
            update=update,
            context=context,
        )
        await db.add_message(user=user, message=new_message_to_llm, ttl=gpt_settings.messages_ttl)
        for message in new_messages:
            await db.add_message(user=user, message=message, ttl=gpt_settings.messages_ttl)
        return chat_response


@inject_database
async def check_history_and_summarize(db: Database, user_id: int) -> bool:
    user = await db.get_or_create_user(user_id=user_id)
    messages = await db.get_messages(user=user)
    # Rough estimate of how many tokens the current conversation history amounts to
    # (~4 characters per token). It could be counted exactly, but the tokenizer modules
    # would have to be built separately for armv7, which is hard to do right now
    # (hopefully later).
    if len(str(messages)) / 4 >= gpt_settings.max_history_tokens:
        await emergency_summarization(user_id=user_id)
        return True
    return False
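
The ~4 characters per token heuristic above, spelled out (the `max_history_tokens` value here is a hypothetical setting, not the package default):

serialized_history = "x" * 40_000                              # 40k characters of stringified messages
estimated_tokens = len(serialized_history) / 4                 # -> 10000.0
max_history_tokens = 8_192                                     # hypothetical setting value
needs_summarization = estimated_tokens >= max_history_tokens   # -> True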


@inject_database
async def generate_image(
    db: Database, user_id: int, prompt: str, model: str | None = None, provider_name: str | None = None
) -> list[str] | list[BytesIO]:
    user = await db.get_or_create_user(user_id=user_id)

    if provider_name:
        provider = user.providers.get(provider_name)
        selected_model = model
    elif user.selected_image_provider_name:
        provider = user.active_image_provider
        selected_model = model or user.selected_image_model_name
    else:
        provider = user.active_image_provider
        selected_model = None
    if not provider:
        raise ValueError(f"User {user_id}: no image provider available.")
    images = await provider.get_images(prompt=prompt, model=selected_model)
    if user_id not in gpt_settings.image_generations_whitelist:
        await db.count_image(user_id)
    return images


@cached(ttl=300)
@inject_database
async def get_user_cached_models(db: Database, user_id: int, image_generation: bool = False) -> list[ModelChangeSchema]:
    user = await db.get_or_create_user(user_id=user_id)
    return await user.get_available_models(image_generation=image_generation)


@inject_database
async def get_models_available(db: Database, user_id: int, image_generation: bool = False) -> list[ModelChangeSchema]:
    user = await db.get_or_create_user(user_id=user_id)
    user_models = await get_user_cached_models(user_id=user_id, image_generation=image_generation)

    if not user_models:
        return []

    available_models = deepcopy(user_models)
    if image_generation:
        active_model = user.selected_image_model_name or user.active_image_provider.default_image_model
    else:
        # Falls back to the provider's default *image* model name even in the chat branch.
        active_model = user.selected_gpt_model_name or user.active_gpt_provider.default_image_model

    for model in available_models:
        if model.name == active_model:
            model.display_name = f"🟢 {model.display_name}"
    return available_models


@inject_database
async def user_has_reached_images_generation_limit(db: Database, user_id: int) -> bool:
    user = await db.get_or_create_user(user_id=user_id)
    return user.has_reached_image_limits


@inject_database
async def set_api_key(db: Database, user_id: int, api_key: str, provider_name: str) -> None:
    user = await db.get_or_create_user(user_id=user_id)
    user.tokens[provider_name] = api_key
    await db.save_user(user)
    return None


@inject_database
async def get_info(db: Database, user_id: int) -> str:
    user = await db.get_or_create_user(user_id=user_id)
    return user.info


@inject_database
async def set_info(db: Database, user_id: int, new_info: str) -> None:
    user = await db.get_or_create_user(user_id=user_id)
    user.info = new_info
    await db.save_user(user)


@inject_database
async def activate_llm_skill(db: Database, user_id: int, skill_name: str, skill_payload: str) -> None:
    user = await db.get_or_create_user(user_id=user_id)
    user.llm_skills[skill_name] = skill_payload
    await db.save_user(user)


@inject_database
async def deactivate_llm_skill(db: Database, user_id: int, skill_name: str) -> None:
    user = await db.get_or_create_user(user_id=user_id)
    if skill_name not in user.llm_skills:
        raise ValueError(f"The skill {skill_name} does not seem to have ever been activated")
    user.llm_skills.pop(skill_name)
    await db.save_user(user)


@inject_database
async def set_working_dir(db: Database, user_id: int, new_wd: str) -> None:
    user = await db.get_or_create_user(user_id=user_id)
    user.working_dir = new_wd
    await db.save_user(user)


@inject_database
async def get_cwd(db: Database, user_id: int) -> str:
    user = await db.get_or_create_user(user_id=user_id)
    return user.working_dir


@inject_database
async def get_moderation_provider(db: Database, user_id: int) -> "Provider":
    user = await db.get_or_create_user(user_id=user_id)
    return user.moderation_provider


@inject_database
async def drop_tool_call_history(db: Database, user_id: int) -> None:
    user = await db.get_or_create_user(user_id=user_id)
    chat_history: list[Message] = await db.get_conversation_messages(user=user)
    await reset_chat_history(user_id=user_id)
    for message in chat_history:
        if message.role == "tool":
            continue
        message.tool_calls = None
        message.tool_call_id = None
        await db.add_message(user=user, message=message, ttl=gpt_settings.messages_ttl)


@inject_database
async def summarize_history(db: Database, user_id: int) -> None:
    user = await db.get_or_create_user(user_id=user_id)
    chat_history: list[Message] = await db.get_conversation_messages(user=user)
    await reset_chat_history(user_id=user_id)
    await db.add_message(user=user, message=chat_history[0], ttl=gpt_settings.messages_ttl)
chibi/storage/abstract.py
ADDED
@@ -0,0 +1,54 @@
import time
from abc import ABC, abstractmethod
from typing import Optional

from openai.types.chat import (
    ChatCompletionAssistantMessageParam,
    ChatCompletionFunctionMessageParam,
    ChatCompletionToolMessageParam,
    ChatCompletionUserMessageParam,
)

from chibi.models import ImageMeta, Message, User

CHAT_COMPLETION_CLASSES = {
    "assistant": ChatCompletionAssistantMessageParam,
    "function": ChatCompletionFunctionMessageParam,
    "tool": ChatCompletionToolMessageParam,
    "user": ChatCompletionUserMessageParam,
}


class Database(ABC):
    async def get_or_create_user(self, user_id: int) -> User:
        if user := await self.get_user(user_id=user_id):
            return user
        return await self.create_user(user_id=user_id)

    @abstractmethod
    async def save_user(self, user: User) -> None: ...

    @abstractmethod
    async def create_user(self, user_id: int) -> User: ...

    @abstractmethod
    async def get_user(self, user_id: int) -> User | None: ...

    @abstractmethod
    async def add_message(self, user: User, message: Message, ttl: Optional[int] = None) -> None: ...

    @abstractmethod
    async def get_messages(self, user: User) -> list[dict[str, str]]: ...

    @abstractmethod
    async def drop_messages(self, user: User) -> None: ...

    async def get_conversation_messages(self, user: User) -> list[Message]:
        messages = await self.get_messages(user=user)
        return [Message(**msg) for msg in messages]

    async def count_image(self, user_id: int) -> None:
        user = await self.get_or_create_user(user_id=user_id)
        expire_at = time.time() + 60 * 750  # 45,000 seconds, i.e. 12.5 hours
        user.images.append(ImageMeta(expire_at=expire_at))
        await self.save_user(user)
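
Aside: a concrete backend only has to implement the six abstract methods; `get_or_create_user`, `get_conversation_messages`, and `count_image` are inherited. A hypothetical in-memory backend, sketched under the assumption that `User` and `Message` are pydantic-style models with an integer `id` and a `model_dump()` method (this diff does not confirm their constructors):

from typing import Optional

from chibi.models import Message, User
from chibi.storage.abstract import Database


class InMemoryStorage(Database):
    """Toy backend for tests; nothing survives a restart."""

    def __init__(self) -> None:
        self._users: dict[int, User] = {}
        self._messages: dict[int, list[dict[str, str]]] = {}

    async def save_user(self, user: User) -> None:
        self._users[user.id] = user

    async def create_user(self, user_id: int) -> User:
        user = User(id=user_id)  # assumed constructor
        await self.save_user(user)
        return user

    async def get_user(self, user_id: int) -> Optional[User]:
        return self._users.get(user_id)

    async def add_message(self, user: User, message: Message, ttl: Optional[int] = None) -> None:
        self._messages.setdefault(user.id, []).append(message.model_dump())  # ttl ignored in memory

    async def get_messages(self, user: User) -> list[dict[str, str]]:
        return self._messages.get(user.id, [])

    async def drop_messages(self, user: User) -> None:
        self._messages.pop(user.id, None)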
chibi/storage/database.py
ADDED
@@ -0,0 +1,86 @@
import asyncio
from functools import wraps
from typing import Awaitable, Callable, Concatenate, Optional, ParamSpec, TypeVar, cast

from chibi.config.app import application_settings
from chibi.storage.abstract import Database
from chibi.storage.dynamodb import DynamoDBStorage
from chibi.storage.local import LocalStorage
from chibi.storage.redis import RedisStorage

R = TypeVar("R")
P = ParamSpec("P")


class DatabaseCache:
    """
    Caches a Database instance according to application settings.
    Supports 'local', 'redis', and 'dynamodb' backends.
    """

    def __init__(self) -> None:
        self._cache: Optional[Database] = None
        self._lock = asyncio.Lock()

    async def get_database(self) -> Database:
        """Get or create the Database instance based on the storage_backend setting.

        Returns:
            Initialized Database instance.
        """
        async with self._lock:
            if self._cache is not None:
                return self._cache

            backend = application_settings.storage_backend.lower()
            if backend == "redis":
                # RedisStorage.create expects a URL and password
                self._cache = await RedisStorage.create(
                    url=cast(str, application_settings.redis),
                    password=application_settings.redis_password,
                )
            elif backend == "dynamodb":
                # DynamoDBStorage.create expects region, access key, secret key, and table names
                self._cache = await DynamoDBStorage.create(
                    region=application_settings.aws_region or "",
                    access_key=application_settings.aws_access_key_id,
                    secret_access_key=application_settings.aws_secret_access_key,
                    users_table=application_settings.ddb_users_table or "",
                    messages_table=application_settings.ddb_messages_table or "",
                )
            else:
                # Default to local storage
                self._cache = LocalStorage(application_settings.local_data_path)

            return self._cache

    def clear_cache(self) -> None:
        """
        Clear the cached Database instance, forcing reinitialization on next use.
        """
        self._cache = None


_db_provider = DatabaseCache()


def inject_database(
    func: Callable[Concatenate[Database, P], Awaitable[R]],
) -> Callable[P, Awaitable[R]]:
    """Decorator to inject the Database instance into async functions.

    Wraps a function with signature func(db, *args, **kwargs) -> Awaitable.

    Args:
        func: The function to decorate.

    Returns:
        A wrapper that resolves the Database and passes it as the first argument.
    """

    @wraps(func)
    async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        db = await _db_provider.get_database()
        return await func(db, *args, **kwargs)

    return wrapper
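
Usage sketch for `inject_database`: the caller never supplies `db`; the decorator resolves the configured backend on every call (`touch_user` is an invented example, not part of the package):

from chibi.storage.abstract import Database
from chibi.storage.database import inject_database


@inject_database
async def touch_user(db: Database, user_id: int) -> None:
    user = await db.get_or_create_user(user_id=user_id)
    await db.save_user(user)


# await touch_user(user_id=42)  # db is injected; the caller passes only user_id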