chibi-bot 1.6.0b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70) hide show
  1. chibi/__init__.py +0 -0
  2. chibi/__main__.py +343 -0
  3. chibi/cli.py +90 -0
  4. chibi/config/__init__.py +6 -0
  5. chibi/config/app.py +123 -0
  6. chibi/config/gpt.py +108 -0
  7. chibi/config/logging.py +15 -0
  8. chibi/config/telegram.py +43 -0
  9. chibi/config_generator.py +233 -0
  10. chibi/constants.py +362 -0
  11. chibi/exceptions.py +58 -0
  12. chibi/models.py +496 -0
  13. chibi/schemas/__init__.py +0 -0
  14. chibi/schemas/anthropic.py +20 -0
  15. chibi/schemas/app.py +54 -0
  16. chibi/schemas/cloudflare.py +65 -0
  17. chibi/schemas/mistralai.py +56 -0
  18. chibi/schemas/suno.py +83 -0
  19. chibi/service.py +135 -0
  20. chibi/services/bot.py +276 -0
  21. chibi/services/lock_manager.py +20 -0
  22. chibi/services/mcp/manager.py +242 -0
  23. chibi/services/metrics.py +54 -0
  24. chibi/services/providers/__init__.py +16 -0
  25. chibi/services/providers/alibaba.py +79 -0
  26. chibi/services/providers/anthropic.py +40 -0
  27. chibi/services/providers/cloudflare.py +98 -0
  28. chibi/services/providers/constants/suno.py +2 -0
  29. chibi/services/providers/customopenai.py +11 -0
  30. chibi/services/providers/deepseek.py +15 -0
  31. chibi/services/providers/eleven_labs.py +85 -0
  32. chibi/services/providers/gemini_native.py +489 -0
  33. chibi/services/providers/grok.py +40 -0
  34. chibi/services/providers/minimax.py +96 -0
  35. chibi/services/providers/mistralai_native.py +312 -0
  36. chibi/services/providers/moonshotai.py +20 -0
  37. chibi/services/providers/openai.py +74 -0
  38. chibi/services/providers/provider.py +892 -0
  39. chibi/services/providers/suno.py +130 -0
  40. chibi/services/providers/tools/__init__.py +23 -0
  41. chibi/services/providers/tools/cmd.py +132 -0
  42. chibi/services/providers/tools/common.py +127 -0
  43. chibi/services/providers/tools/constants.py +78 -0
  44. chibi/services/providers/tools/exceptions.py +1 -0
  45. chibi/services/providers/tools/file_editor.py +875 -0
  46. chibi/services/providers/tools/mcp_management.py +274 -0
  47. chibi/services/providers/tools/mcp_simple.py +72 -0
  48. chibi/services/providers/tools/media.py +451 -0
  49. chibi/services/providers/tools/memory.py +252 -0
  50. chibi/services/providers/tools/schemas.py +10 -0
  51. chibi/services/providers/tools/send.py +435 -0
  52. chibi/services/providers/tools/tool.py +163 -0
  53. chibi/services/providers/tools/utils.py +146 -0
  54. chibi/services/providers/tools/web.py +261 -0
  55. chibi/services/providers/utils.py +182 -0
  56. chibi/services/task_manager.py +93 -0
  57. chibi/services/user.py +269 -0
  58. chibi/storage/abstract.py +54 -0
  59. chibi/storage/database.py +86 -0
  60. chibi/storage/dynamodb.py +257 -0
  61. chibi/storage/local.py +70 -0
  62. chibi/storage/redis.py +91 -0
  63. chibi/utils/__init__.py +0 -0
  64. chibi/utils/app.py +249 -0
  65. chibi/utils/telegram.py +521 -0
  66. chibi_bot-1.6.0b0.dist-info/LICENSE +21 -0
  67. chibi_bot-1.6.0b0.dist-info/METADATA +340 -0
  68. chibi_bot-1.6.0b0.dist-info/RECORD +70 -0
  69. chibi_bot-1.6.0b0.dist-info/WHEEL +4 -0
  70. chibi_bot-1.6.0b0.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,182 @@
1
+ import inspect
2
+ import json
3
+ from typing import Any, Callable, Coroutine, ParamSpec, Type, TypeAlias, TypeVar
4
+
5
+ from anthropic.types import (
6
+ Message as AnthropicMessage,
7
+ )
8
+ from google.genai.types import GenerateContentResponse
9
+ from mistralai import ChatCompletionResponse
10
+ from openai.types import CompletionUsage
11
+ from openai.types.chat import ChatCompletion
12
+ from telegram import Update, constants
13
+ from telegram.ext import ContextTypes
14
+
15
+ from chibi.config import gpt_settings
16
+ from chibi.models import User
17
+ from chibi.schemas.app import UsageSchema
18
+ from chibi.schemas.suno import SunoGetGenerationDetailsSchema
19
+ from chibi.utils.app import get_builtin_skill_names
20
+
21
+ T = TypeVar("T")
22
+ P = ParamSpec("P")
23
+ M = TypeVar("M", bound=Callable[..., Coroutine[Any, Any, Any]])
24
+ AsyncFunc: TypeAlias = Callable[P, Coroutine[Any, Any, T]]
25
+
26
+
27
def decorate_async_methods(decorator: Callable[[M], M]) -> Callable[[Type[T]], Type[T]]:
    """Class-decorator factory: apply *decorator* to every async method of a class.

    Args:
        decorator: Decorator to wrap each coroutine function with.

    Returns:
        A class decorator that replaces each coroutine-function attribute
        in-place and returns the (mutated) class.
    """

    def decorate(cls: Type[T]) -> Type[T]:
        # Snapshot the attribute names first: setattr() writes into
        # cls.__dict__, and mutating a dict while iterating it is unsafe.
        for attr in list(cls.__dict__):
            candidate = getattr(cls, attr)
            if inspect.iscoroutinefunction(candidate):
                # NOTE(review): getattr() unwraps staticmethod/classmethod, so
                # decorating an async staticmethod would drop its binding —
                # confirm none are used with this decorator.
                setattr(cls, attr, decorator(candidate))
        return cls

    return decorate
37
+
38
+
39
+ def escape_and_truncate(message: str | dict[str, Any] | list[dict[str, Any]] | None, limit: int = 50) -> str:
40
+ if not message:
41
+ return "no data"
42
+
43
+ if isinstance(message, dict):
44
+ return json.dumps({k: escape_and_truncate(message=v, limit=limit) for k, v in message.items()})
45
+
46
+ if isinstance(message, list):
47
+ return json.dumps([escape_and_truncate(message=m, limit=limit) for m in message])
48
+
49
+ escaped_message = str(message).replace("<", r"\<").replace(">", r"\>")
50
+ if len(escaped_message) < limit + 20:
51
+ return escaped_message
52
+ return f"{escaped_message[:limit]}... (truncated)"
53
+
54
+
55
async def prepare_system_prompt(base_system_prompt: str, user: User | None = None) -> str:
    """Build the final system prompt for the LLM.

    Without a user the base prompt is returned verbatim; otherwise it is
    wrapped in a JSON envelope enriched with per-user context (working
    directory, stored info, activated skills, ...).
    """
    if user is None:
        return base_system_prompt

    enriched = {
        "current_working_dir": user.working_dir,
        "user_id": user.id,
        "user_info": user.info,
        "system_prompt": base_system_prompt,
        "available_builtin_skills": get_builtin_skill_names(),
        "activated_skills": user.llm_skills,
    }
    return json.dumps(enriched)
68
+
69
+
70
async def send_llm_thoughts(thoughts: str, update: Update | None, context: ContextTypes.DEFAULT_TYPE | None) -> None:
    """Forward the model's reasoning text to the chat, if the feature is enabled."""
    if not gpt_settings.show_llm_thoughts:
        return None

    # Imported lazily — presumably to avoid a circular import with
    # chibi.utils.telegram; confirm before hoisting to module level.
    from chibi.utils.telegram import send_long_message

    if update is None or context is None:
        return None

    await send_long_message(
        message=f"💡💭 {thoughts}",
        update=update,
        context=context,
        parse_mode=constants.ParseMode.MARKDOWN_V2,
        reply=False,
    )
86
+
87
+
88
def get_usage_from_anthropic_response(response_message: AnthropicMessage) -> UsageSchema:
    """Translate Anthropic usage counters into the app-level UsageSchema."""
    usage = response_message.usage
    return UsageSchema(
        completion_tokens=usage.output_tokens,
        prompt_tokens=usage.input_tokens,
        cache_creation_input_tokens=usage.cache_creation_input_tokens or 0,
        cache_read_input_tokens=usage.cache_read_input_tokens or 0,
        total_tokens=usage.output_tokens + usage.input_tokens,
    )
96
+
97
+
98
def get_usage_from_openai_response(response_message: ChatCompletion) -> UsageSchema:
    """Translate OpenAI usage counters into the app-level UsageSchema.

    Returns an empty UsageSchema when the completion carries no usage data.
    """
    response_usage = response_message.usage
    if response_usage is None:
        return UsageSchema()

    usage = UsageSchema(
        completion_tokens=response_usage.completion_tokens,
        prompt_tokens=response_usage.prompt_tokens,
        total_tokens=response_usage.total_tokens,
    )
    prompt_cache = response_usage.prompt_tokens_details
    if prompt_cache:
        usage.cache_read_input_tokens = prompt_cache.cached_tokens or 0
    return usage
110
+
111
+
112
def get_usage_from_google_response(response_message: GenerateContentResponse) -> UsageSchema:
    """Translate Gemini usage metadata into the app-level UsageSchema.

    Returns an empty UsageSchema when the response has no usage metadata.
    """
    metadata = response_message.usage_metadata
    if not metadata:
        return UsageSchema()

    return UsageSchema(
        total_tokens=metadata.total_token_count or 0,
        completion_tokens=metadata.candidates_token_count or 0,
        prompt_tokens=metadata.prompt_token_count or 0,
        cache_read_input_tokens=metadata.cached_content_token_count or 0,
    )
122
+
123
+
124
def get_usage_from_mistral_response(response_message: ChatCompletionResponse) -> UsageSchema:
    """Translate Mistral usage counters into the app-level UsageSchema.

    Mistral reports no prompt-caching statistics, so both cache counters
    are fixed at zero.
    """
    usage = response_message.usage
    return UsageSchema(
        completion_tokens=usage.completion_tokens or 0,
        prompt_tokens=usage.prompt_tokens or 0,
        cache_creation_input_tokens=0,
        cache_read_input_tokens=0,
        total_tokens=usage.total_tokens or 0,
    )
132
+
133
+
134
def get_usage_msg(usage: UsageSchema | CompletionUsage | None) -> str:
    """Format a short human-readable token-usage summary; "" when unknown."""
    if usage is None:
        return ""
    # getattr with a default tolerates both usage schemas: only some carry
    # the cache counters.
    total = getattr(usage, "total_tokens", None) or "n/a"
    prompt = getattr(usage, "prompt_tokens", None)
    completion = getattr(usage, "completion_tokens", None)
    cached_read = getattr(usage, "cache_read_input_tokens", None) or 0
    cached_created = getattr(usage, "cache_creation_input_tokens", None) or 0
    return (
        f"Tokens used: {total} "
        f"({prompt} prompt, "
        f"{completion} completion, "
        f"{cached_read} cached read/prompt, "
        f"{cached_created} cached creation)"
    )
146
+
147
+
148
def suno_task_still_processing(task_data_response: SunoGetGenerationDetailsSchema) -> bool:
    """Report whether the Suno generation task has not finished yet."""
    return task_data_response.is_in_progress
150
+
151
+
152
+ # def limit_recursion(
153
+ # max_depth: int = application_settings.max_consecutive_tool_calls,
154
+ # ) -> Callable[[AsyncFunc[P, T]], AsyncFunc[P, T]]:
155
+ # def decorator(func: AsyncFunc[P, T]) -> AsyncFunc[P, T]:
156
+ # depth_var: ContextVar[int] = ContextVar(f"{func.__name__}_depth", default=0)
157
+ #
158
+ # @wraps(func)
159
+ # async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
160
+ # current_depth = depth_var.get()
161
+ # depth_var.set(current_depth + 1)
162
+ # if depth_var.get() > max_depth + 1:
163
+ # depth_var.set(current_depth)
164
+ # class_name = ""
165
+ # if args and hasattr(args[0], "__class__"):
166
+ # class_name = f"{args[0].__class__.__name__}."
167
+ # raise RecursionLimitExceeded(
168
+ # provider=class_name,
169
+ # model=cast(str, kwargs.get("model", "unknown")),
170
+ # detail=f"Recursion depth exceeded: {max_depth} (function: {class_name}{func.__name__})",
171
+ # exceeded_limit=max_depth,
172
+ # )
173
+ #
174
+ # try:
175
+ # result = await func(*args, **kwargs)
176
+ # return result
177
+ # finally:
178
+ # depth_var.set(current_depth)
179
+ #
180
+ # return async_wrapper
181
+ #
182
+ # return decorator
@@ -0,0 +1,93 @@
1
+ import asyncio
2
+ from typing import Any, Coroutine
3
+
4
+ from loguru import logger
5
+
6
+ from chibi.utils.app import SingletonMeta
7
+
8
+
9
class BackgroundTaskManager(metaclass=SingletonMeta):
    """Singleton registry for fire-and-forget asyncio tasks.

    Holds strong references to scheduled tasks, logs their failures via the
    done-callback, and drains everything on shutdown.
    """

    def __init__(self) -> None:
        """Initialize the task manager."""
        # Guard against re-initialization: SingletonMeta presumably invokes
        # __init__ again on repeated instantiation — TODO confirm. Without
        # this guard a second call would wipe the live task set.
        if not hasattr(self, "_tasks"):
            self._tasks: set[asyncio.Task] = set()
            self._shutting_down: bool = False

    async def _wrap_with_timeout(self, coro: Coroutine[Any, Any, Any], timeout: float) -> Any:
        """Wrap a coroutine with a timeout.

        Args:
            coro: The coroutine to wrap.
            timeout: Seconds allowed before the coroutine is cancelled.

        Returns:
            The wrapped coroutine result.

        Raises:
            asyncio.TimeoutError: If the coroutine exceeds *timeout* seconds
                (re-raised after logging a warning).
        """
        try:
            return await asyncio.wait_for(fut=coro, timeout=timeout)
        except asyncio.TimeoutError:
            logger.warning(f"Background task timed out after {timeout}s")
            raise

    def run_task(self, coro: Coroutine[Any, Any, Any], timeout: float | None = None) -> asyncio.Task | None:
        """Schedule a coroutine to run in the background.

        Args:
            coro: the coroutine to run
            timeout: optional timeout in seconds. If the task doesn't complete
                within this time, it will be cancelled with TimeoutError

        Returns:
            The created task, or None when the manager is shutting down
            (in that case the coroutine is closed, never run).
        """
        if self._shutting_down:
            logger.warning("Task manager is shutting down, refusing new task")
            coro.close()
            return None

        # Wrap with timeout if specified
        if timeout is not None:
            coro = self._wrap_with_timeout(coro, timeout)

        task = asyncio.create_task(coro)
        # Keep a strong reference: the event loop holds only weak references
        # to tasks, so an unreferenced task may be garbage-collected mid-run.
        self._tasks.add(task)
        task.add_done_callback(self._discard_task)
        return task

    def _discard_task(self, task: asyncio.Task) -> None:
        """Callback to remove a task from the set when it finishes."""
        try:
            # Retrieving the exception also silences asyncio's
            # "exception was never retrieved" warning.
            exc = task.exception()
            if exc:
                logger.error(
                    f"Background task '{task.get_name()}' failed: {exc.__class__.__name__} ({str(exc) or 'no details'})"
                )
        except asyncio.CancelledError:
            # Cancelled tasks carry no result or exception to report.
            pass
        except Exception as e:
            logger.error(f"Error checking background task result: {e}")
        finally:
            self._tasks.discard(task)

    async def shutdown(self, *args: Any) -> None:
        """
        Wait for all background tasks to complete with a timeout.
        If tasks do not finish within 5 seconds, they are cancelled.
        """
        logger.info("Shutting down background tasks...")
        # Flip the flag first so run_task() refuses any new work.
        self._shutting_down = True
        tasks_to_wait = list(self._tasks)
        if not tasks_to_wait:
            logger.info("No background tasks to wait, we're good.")
            return None

        logger.info(f"Waiting for {len(tasks_to_wait)} background tasks to complete...")
        try:
            await asyncio.wait_for(asyncio.gather(*tasks_to_wait, return_exceptions=True), timeout=5.0)
            logger.info("All background tasks completed.")

        except asyncio.TimeoutError:
            logger.warning("Timeout reached. Cancelling remaining background tasks...")
            remaining = [t for t in tasks_to_wait if not t.done()]
            for task in remaining:
                task.cancel()
            logger.info(f"Cancelled {len(remaining)} remaining tasks.")
91
+
92
+
93
+ task_manager = BackgroundTaskManager()
chibi/services/user.py ADDED
@@ -0,0 +1,269 @@
1
+ import datetime
2
+ import json
3
+ from copy import deepcopy
4
+ from datetime import timezone
5
+ from io import BytesIO
6
+ from typing import TYPE_CHECKING, Any, Optional
7
+
8
+ from aiocache import cached
9
+ from telegram import Update
10
+ from telegram.ext import ContextTypes
11
+
12
+ from chibi.config import gpt_settings
13
+ from chibi.models import Message, User
14
+ from chibi.schemas.app import ChatResponseSchema, ModelChangeSchema
15
+ from chibi.services.lock_manager import LockManager
16
+ from chibi.storage.abstract import Database
17
+ from chibi.storage.database import inject_database
18
+
19
+ if TYPE_CHECKING:
20
+ from chibi.services.providers.provider import Provider
21
+ from chibi.services.providers.tools import ToolResponse
22
+
23
+
24
@inject_database
async def get_chibi_user(db: Database, user_id: int) -> User:
    """Fetch the user's record, creating it on first sight."""
    return await db.get_or_create_user(user_id=user_id)
27
+
28
+
29
@inject_database
async def set_active_model(db: Database, user_id: int, model: ModelChangeSchema) -> None:
    """Persist the user's newly selected model (chat or image generation)."""
    user = await db.get_or_create_user(user_id=user_id)
    if model.image_generation:
        user.selected_image_model_name = model.name
        user.selected_image_provider_name = model.provider
    else:
        user.selected_gpt_model_name = model.name
        user.selected_gpt_provider_name = model.provider
    await db.save_user(user)
39
+
40
+
41
@inject_database
async def reset_chat_history(db: Database, user_id: int) -> None:
    """Delete the user's entire stored message history."""
    user = await db.get_or_create_user(user_id=user_id)
    await db.drop_messages(user=user)
45
+
46
+
47
@inject_database
async def emergency_summarization(db: Database, user_id: int) -> None:
    """Compress the whole chat history into a single summarized exchange.

    Asks the user's active LLM provider to summarize the conversation, then
    replaces the stored history with a synthetic two-message exchange that
    carries the summary.
    """
    user = await db.get_or_create_user(user_id=user_id)

    chat_history = await db.get_messages(user=user)
    # Tool-call plumbing messages carry no conversational content; skip them.
    plain_messages = [
        msg for msg in chat_history if not any((msg.get("tool_calls"), msg.get("tool_call_id")))
    ]
    # BUG FIX: the previous code passed a generator object straight to str(),
    # which produced "<generator object ...>" instead of the actual history.
    chat_history_string = str(plain_messages)
    user_messages: list[Message] = [Message(role="user", content=chat_history_string)]

    response, _ = await user.active_gpt_provider.get_chat_response(
        messages=user_messages,
        user=user,
        system_prompt="Summarize this conversation, keeping the most important and useful information using English.",
    )
    initial_message = Message(role="user", content="What we were talking about?")
    answer_message = Message(role="assistant", content=response.answer)
    await reset_chat_history(user_id=user_id)
    await db.add_message(user=user, message=initial_message, ttl=gpt_settings.messages_ttl)
    await db.add_message(user=user, message=answer_message, ttl=gpt_settings.messages_ttl)
65
+
66
+
67
@inject_database
async def get_llm_chat_completion_answer(
    db: Database,
    user_id: int,
    update: Update,
    context: ContextTypes.DEFAULT_TYPE,
    user_text_message: str | None = None,
    user_voice_message: BytesIO | None = None,
    tool_message: Optional["ToolResponse"] = None,
) -> ChatResponseSchema:
    """Run one chat-completion round for the user and persist the exchange.

    At least one prompt source must be supplied: plain text, a voice message
    (transcribed via the user's STT provider), or a finished tool response.

    Raises:
        ValueError: If no prompt data is given, or a voice message arrives
            while the user has no STT provider configured.
    """
    user = await db.get_or_create_user(user_id=user_id)
    # Per-user lock: serializes concurrent completions so history reads and
    # writes for the same user don't interleave.
    lock = await LockManager().get_lock(key=user_id)

    if not user_text_message and not user_voice_message and not tool_message:
        raise ValueError("No prompt data provided")

    if user_voice_message and not user.stt_provider:
        raise ValueError("Can't compute voice message: no STT provide available.")

    prompt: dict[str, Any]

    if tool_message:
        # A background tool finished; feed its result back to the model.
        prompt = {
            "type": "tool response",
            "desc": "background task is done",
            "tool_name": tool_message.tool_name,
            "tool_response": tool_message.model_dump(),
            "datetime_now": datetime.datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z%z"),
        }
    else:
        # Voice takes precedence: transcribe it; otherwise use the raw text.
        user_message = (
            await user.stt_provider.transcribe(audio=user_voice_message) if user_voice_message else user_text_message
        )
        assert user_message
        prompt = {
            "prompt": user_message,
            "datetime_now": datetime.datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z%z"),
            "type": "user message",
            "transcribed_from_voice_message": bool(user_voice_message),
        }

    async with lock:
        conversation_messages: list[Message] = await db.get_conversation_messages(user=user)
        new_message_to_llm = Message(role="user", content=json.dumps(prompt))
        conversation_messages.append(new_message_to_llm)

        chat_response, new_messages = await user.active_gpt_provider.get_chat_response(
            messages=conversation_messages,
            user=user,
            model=user.selected_gpt_model_name,
            update=update,
            context=context,
        )
        # Persist the user prompt plus every message the provider produced,
        # all with the configured TTL.
        await db.add_message(user=user, message=new_message_to_llm, ttl=gpt_settings.messages_ttl)
        for message in new_messages:
            await db.add_message(user=user, message=message, ttl=gpt_settings.messages_ttl)
        return chat_response
124
+
125
+
126
@inject_database
async def check_history_and_summarize(db: Database, user_id: int) -> bool:
    """Summarize the chat history when its estimated token count grows too big.

    Returns:
        True when a summarization was triggered, False otherwise.
    """
    user = await db.get_or_create_user(user_id=user_id)
    messages = await db.get_messages(user=user)
    # Rough heuristic: ~4 characters per token. Accurate tokenizer modules
    # would need to be built separately for armv7, which is difficult right
    # now (but will hopefully be done later).
    estimated_tokens = len(str(messages)) / 4
    if estimated_tokens >= gpt_settings.max_history_tokens:
        await emergency_summarization(user_id=user_id)
        return True
    return False
137
+
138
+
139
@inject_database
async def generate_image(
    db: Database, user_id: int, prompt: str, model: str | None = None, provider_name: str | None = None
) -> list[str] | list[BytesIO]:
    """Generate images for *prompt* and count them against the user's quota.

    Provider resolution: an explicit provider_name wins; otherwise the user's
    selected image provider (with their selected model as model fallback);
    otherwise the active image provider with no explicit model.

    Raises:
        ValueError: If no image provider can be resolved.
    """
    user = await db.get_or_create_user(user_id=user_id)

    if provider_name:
        provider = user.providers.get(provider_name)
        selected_model = model
    elif user.selected_image_provider_name:
        provider = user.active_image_provider
        selected_model = model or user.selected_image_model_name
    else:
        # NOTE(review): the `model` argument is silently discarded in this
        # branch — confirm that is intentional.
        provider = user.active_image_provider
        selected_model = None
    if not provider:
        raise ValueError(f"User {user_id}: no image provider available.")
    images = await provider.get_images(prompt=prompt, model=selected_model)
    # Whitelisted users generate for free; everyone else gets a tally mark.
    if user_id not in gpt_settings.image_generations_whitelist:
        await db.count_image(user_id)
    return images
160
+
161
+
162
@cached(ttl=300)
@inject_database
async def get_user_cached_models(db: Database, user_id: int, image_generation: bool = False) -> list[ModelChangeSchema]:
    """Return the user's available models, cached for 5 minutes.

    The aiocache layer keys on the call arguments, so each
    (user_id, image_generation) pair gets its own cache entry.
    """
    user = await db.get_or_create_user(user_id=user_id)
    return await user.get_available_models(image_generation=image_generation)
167
+
168
+
169
@inject_database
async def get_models_available(db: Database, user_id: int, image_generation: bool = False) -> list[ModelChangeSchema]:
    """Return the user's available models with the active one marked 🟢."""
    user = await db.get_or_create_user(user_id=user_id)
    user_models = await get_user_cached_models(user_id=user_id, image_generation=image_generation)

    if not user_models:
        return []

    # Deep-copy so the 🟢 marker never leaks into the cached model list.
    available_models = deepcopy(user_models)
    if image_generation:
        active_model = user.selected_image_model_name or user.active_image_provider.default_image_model
    else:
        # NOTE(review): falling back to `default_image_model` for the *chat*
        # branch looks like a copy-paste slip — should this be the provider's
        # default chat/GPT model? Confirm against Provider's attributes.
        active_model = user.selected_gpt_model_name or user.active_gpt_provider.default_image_model

    for model in available_models:
        if model.name == active_model:
            model.display_name = f"🟢 {model.display_name}️"
    return available_models
187
+
188
+
189
@inject_database
async def user_has_reached_images_generation_limit(db: Database, user_id: int) -> bool:
    """Tell whether the user has exhausted their image-generation quota."""
    user = await db.get_or_create_user(user_id=user_id)
    return user.has_reached_image_limits
193
+
194
+
195
@inject_database
async def set_api_key(db: Database, user_id: int, api_key: str, provider_name: str) -> None:
    """Store (or replace) the user's API key for the given provider."""
    user = await db.get_or_create_user(user_id=user_id)
    user.tokens[provider_name] = api_key
    await db.save_user(user)
201
+
202
+
203
@inject_database
async def get_info(db: Database, user_id: int) -> str:
    """Return the free-form info text stored for the user."""
    user = await db.get_or_create_user(user_id=user_id)
    return user.info
207
+
208
+
209
@inject_database
async def set_info(db: Database, user_id: int, new_info: str) -> None:
    """Replace the free-form info text stored for the user."""
    user = await db.get_or_create_user(user_id=user_id)
    user.info = new_info
    await db.save_user(user)
214
+
215
+
216
@inject_database
async def activate_llm_skill(db: Database, user_id: int, skill_name: str, skill_payload: str) -> None:
    """Activate (or update) an LLM skill for the user."""
    user = await db.get_or_create_user(user_id=user_id)
    user.llm_skills[skill_name] = skill_payload
    await db.save_user(user)
221
+
222
+
223
@inject_database
async def deactivate_llm_skill(db: Database, user_id: int, skill_name: str) -> None:
    """Remove a previously activated LLM skill from the user.

    Raises:
        ValueError: If the skill was never activated for this user.
    """
    user = await db.get_or_create_user(user_id=user_id)
    if skill_name not in user.llm_skills:
        raise ValueError(f"The skill {skill_name} seems never been activated")
    user.llm_skills.pop(skill_name)
    await db.save_user(user)
230
+
231
+
232
@inject_database
async def set_working_dir(db: Database, user_id: int, new_wd: str) -> None:
    """Persist a new working directory for the user."""
    user = await db.get_or_create_user(user_id=user_id)
    user.working_dir = new_wd
    await db.save_user(user)
237
+
238
+
239
@inject_database
async def get_cwd(db: Database, user_id: int) -> str:
    """Return the user's current working directory."""
    user = await db.get_or_create_user(user_id=user_id)
    return user.working_dir
243
+
244
+
245
@inject_database
async def get_moderation_provider(db: Database, user_id: int) -> "Provider":
    """Return the provider that handles content moderation for the user."""
    user = await db.get_or_create_user(user_id=user_id)
    return user.moderation_provider
249
+
250
+
251
@inject_database
async def drop_tool_call_history(db: Database, user_id: int) -> None:
    """Rewrite the stored history with all tool-call traces removed.

    Messages with role "tool" are dropped entirely; the remaining messages
    have their tool_calls / tool_call_id fields cleared before re-adding.
    """
    user = await db.get_or_create_user(user_id=user_id)
    chat_history: list[Message] = await db.get_conversation_messages(user=user)
    # The Database interface only exposes add/drop (no in-place update), so
    # rewrite via wipe-and-re-add.
    await reset_chat_history(user_id=user_id)
    for message in chat_history:
        if message.role == "tool":
            continue
        message.tool_calls = None
        message.tool_call_id = None
        await db.add_message(user=user, message=message, ttl=gpt_settings.messages_ttl)
262
+
263
+
264
@inject_database
async def summarize_history(db: Database, user_id: int) -> None:
    """Truncate the stored history down to its first message.

    NOTE(review): despite the name, no LLM summarization happens here — the
    history is simply reduced to the first stored message; presumably that
    message already carries the summary. Confirm against callers.
    """
    user = await db.get_or_create_user(user_id=user_id)
    chat_history: list[Message] = await db.get_conversation_messages(user=user)
    await reset_chat_history(user_id=user_id)
    if not chat_history:
        # BUG FIX: an empty history used to raise IndexError on chat_history[0].
        return None
    await db.add_message(user=user, message=chat_history[0], ttl=gpt_settings.messages_ttl)
@@ -0,0 +1,54 @@
1
+ import time
2
+ from abc import ABC, abstractmethod
3
+ from typing import Optional
4
+
5
+ from openai.types.chat import (
6
+ ChatCompletionAssistantMessageParam,
7
+ ChatCompletionFunctionMessageParam,
8
+ ChatCompletionToolMessageParam,
9
+ ChatCompletionUserMessageParam,
10
+ )
11
+
12
+ from chibi.models import ImageMeta, Message, User
13
+
14
# Maps a message's "role" string to the matching OpenAI chat-completion
# param class. Not referenced in this module — presumably used by concrete
# storage backends when rebuilding typed payloads; confirm before removing.
CHAT_COMPLETION_CLASSES = {
    "assistant": ChatCompletionAssistantMessageParam,
    "function": ChatCompletionFunctionMessageParam,
    "tool": ChatCompletionToolMessageParam,
    "user": ChatCompletionUserMessageParam,
}
20
+
21
+
22
class Database(ABC):
    """Abstract persistence backend for users and their chat messages."""

    async def get_or_create_user(self, user_id: int) -> User:
        """Return the stored user, creating a fresh record if none exists."""
        if user := await self.get_user(user_id=user_id):
            return user
        return await self.create_user(user_id=user_id)

    @abstractmethod
    async def save_user(self, user: User) -> None: ...

    @abstractmethod
    async def create_user(self, user_id: int) -> User: ...

    @abstractmethod
    async def get_user(self, user_id: int) -> User | None: ...

    @abstractmethod
    async def add_message(self, user: User, message: Message, ttl: Optional[int] = None) -> None: ...

    @abstractmethod
    async def get_messages(self, user: User) -> list[dict[str, str]]: ...

    @abstractmethod
    async def drop_messages(self, user: User) -> None: ...

    async def get_conversation_messages(self, user: User) -> list[Message]:
        """Load the user's history as Message models rather than raw dicts."""
        messages = await self.get_messages(user=user)
        return [Message(**msg) for msg in messages]

    async def count_image(self, user_id: int) -> None:
        """Record one image generation against the user's rolling quota."""
        user = await self.get_or_create_user(user_id=user_id)
        # BUG FIX: the expiry was `60 * 750` seconds (~12.5 hours) while the
        # comment promised "~ 1 month"; a factor of 60 appears to have been
        # dropped. 60 * 60 * 750 seconds = 750 hours ~= 31 days.
        expire_at = time.time() + 60 * 60 * 750  # ~ 1 month
        user.images.append(ImageMeta(expire_at=expire_at))
        await self.save_user(user)
@@ -0,0 +1,86 @@
1
+ import asyncio
2
+ from functools import wraps
3
+ from typing import Awaitable, Callable, Concatenate, Optional, ParamSpec, TypeVar, cast
4
+
5
+ from chibi.config.app import application_settings
6
+ from chibi.storage.abstract import Database
7
+ from chibi.storage.dynamodb import DynamoDBStorage
8
+ from chibi.storage.local import LocalStorage
9
+ from chibi.storage.redis import RedisStorage
10
+
11
+ R = TypeVar("R")
12
+ P = ParamSpec("P")
13
+
14
+
15
class DatabaseCache:
    """
    Caches a Database instance according to application settings.
    Supports 'local', 'redis', and 'dynamodb' backends.
    """

    def __init__(self) -> None:
        # No cached instance yet; created lazily on first get_database().
        self._cache: Optional[Database] = None
        # Serializes lazy initialization so concurrent first calls don't
        # create two backend instances.
        self._lock = asyncio.Lock()

    async def get_database(self) -> Database:
        """Get or create the Database instance based on storage_backend setting.

        Returns:
            Initialized Database instance.
        """
        async with self._lock:
            if self._cache is not None:
                return self._cache

            backend = application_settings.storage_backend.lower()
            if backend == "redis":
                # RedisStorage.create expects URL and password
                self._cache = await RedisStorage.create(
                    url=cast(str, application_settings.redis),
                    password=application_settings.redis_password,
                )
            elif backend == "dynamodb":
                # DynamoDBStorage.create expects region, access_key, secret_key, tables
                self._cache = await DynamoDBStorage.create(
                    region=application_settings.aws_region or "",
                    access_key=application_settings.aws_access_key_id,
                    secret_access_key=application_settings.aws_secret_access_key,
                    users_table=application_settings.ddb_users_table or "",
                    messages_table=application_settings.ddb_messages_table or "",
                )
            else:
                # default to local storage
                self._cache = LocalStorage(application_settings.local_data_path)

            return self._cache

    def clear_cache(self) -> None:
        """
        Clear the cached Database instance, forcing reinitialization on next use.

        NOTE(review): runs without taking self._lock, so it may race with a
        concurrent get_database() — confirm this is acceptable.
        """
        self._cache = None
62
+
63
+
64
+ _db_provider = DatabaseCache()
65
+
66
+
67
def inject_database(
    func: Callable[Concatenate[Database, P], Awaitable[R]],
) -> Callable[P, Awaitable[R]]:
    """Decorator that supplies the shared Database as the first argument.

    The decorated function keeps the rest of its signature; callers invoke
    it without passing a database instance.

    Args:
        func: An async function whose first parameter is a Database.

    Returns:
        An async wrapper with the Database parameter already bound.
    """

    @wraps(func)
    async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        database = await _db_provider.get_database()
        return await func(database, *args, **kwargs)

    return wrapper