chibi-bot 1.6.0b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chibi/__init__.py +0 -0
- chibi/__main__.py +343 -0
- chibi/cli.py +90 -0
- chibi/config/__init__.py +6 -0
- chibi/config/app.py +123 -0
- chibi/config/gpt.py +108 -0
- chibi/config/logging.py +15 -0
- chibi/config/telegram.py +43 -0
- chibi/config_generator.py +233 -0
- chibi/constants.py +362 -0
- chibi/exceptions.py +58 -0
- chibi/models.py +496 -0
- chibi/schemas/__init__.py +0 -0
- chibi/schemas/anthropic.py +20 -0
- chibi/schemas/app.py +54 -0
- chibi/schemas/cloudflare.py +65 -0
- chibi/schemas/mistralai.py +56 -0
- chibi/schemas/suno.py +83 -0
- chibi/service.py +135 -0
- chibi/services/bot.py +276 -0
- chibi/services/lock_manager.py +20 -0
- chibi/services/mcp/manager.py +242 -0
- chibi/services/metrics.py +54 -0
- chibi/services/providers/__init__.py +16 -0
- chibi/services/providers/alibaba.py +79 -0
- chibi/services/providers/anthropic.py +40 -0
- chibi/services/providers/cloudflare.py +98 -0
- chibi/services/providers/constants/suno.py +2 -0
- chibi/services/providers/customopenai.py +11 -0
- chibi/services/providers/deepseek.py +15 -0
- chibi/services/providers/eleven_labs.py +85 -0
- chibi/services/providers/gemini_native.py +489 -0
- chibi/services/providers/grok.py +40 -0
- chibi/services/providers/minimax.py +96 -0
- chibi/services/providers/mistralai_native.py +312 -0
- chibi/services/providers/moonshotai.py +20 -0
- chibi/services/providers/openai.py +74 -0
- chibi/services/providers/provider.py +892 -0
- chibi/services/providers/suno.py +130 -0
- chibi/services/providers/tools/__init__.py +23 -0
- chibi/services/providers/tools/cmd.py +132 -0
- chibi/services/providers/tools/common.py +127 -0
- chibi/services/providers/tools/constants.py +78 -0
- chibi/services/providers/tools/exceptions.py +1 -0
- chibi/services/providers/tools/file_editor.py +875 -0
- chibi/services/providers/tools/mcp_management.py +274 -0
- chibi/services/providers/tools/mcp_simple.py +72 -0
- chibi/services/providers/tools/media.py +451 -0
- chibi/services/providers/tools/memory.py +252 -0
- chibi/services/providers/tools/schemas.py +10 -0
- chibi/services/providers/tools/send.py +435 -0
- chibi/services/providers/tools/tool.py +163 -0
- chibi/services/providers/tools/utils.py +146 -0
- chibi/services/providers/tools/web.py +261 -0
- chibi/services/providers/utils.py +182 -0
- chibi/services/task_manager.py +93 -0
- chibi/services/user.py +269 -0
- chibi/storage/abstract.py +54 -0
- chibi/storage/database.py +86 -0
- chibi/storage/dynamodb.py +257 -0
- chibi/storage/local.py +70 -0
- chibi/storage/redis.py +91 -0
- chibi/utils/__init__.py +0 -0
- chibi/utils/app.py +249 -0
- chibi/utils/telegram.py +521 -0
- chibi_bot-1.6.0b0.dist-info/LICENSE +21 -0
- chibi_bot-1.6.0b0.dist-info/METADATA +340 -0
- chibi_bot-1.6.0b0.dist-info/RECORD +70 -0
- chibi_bot-1.6.0b0.dist-info/WHEEL +4 -0
- chibi_bot-1.6.0b0.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,312 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
import random
|
|
4
|
+
from asyncio import sleep
|
|
5
|
+
from typing import Any, Union
|
|
6
|
+
|
|
7
|
+
from loguru import logger
|
|
8
|
+
from mistralai import ChatCompletionResponse, JSONSchemaTypedDict, Mistral, ResponseFormatTypedDict, TextChunk
|
|
9
|
+
from mistralai.models import (
|
|
10
|
+
AssistantMessage,
|
|
11
|
+
FunctionCall,
|
|
12
|
+
SystemMessage,
|
|
13
|
+
ToolCall,
|
|
14
|
+
ToolMessage,
|
|
15
|
+
UserMessage,
|
|
16
|
+
)
|
|
17
|
+
from openai.types.chat import ChatCompletionToolParam
|
|
18
|
+
from telegram import Update
|
|
19
|
+
from telegram.ext import ContextTypes
|
|
20
|
+
|
|
21
|
+
from chibi.config import application_settings, gpt_settings
|
|
22
|
+
from chibi.exceptions import NoApiKeyProvidedError, NoResponseError
|
|
23
|
+
from chibi.models import Message, User
|
|
24
|
+
from chibi.schemas.app import ChatResponseSchema, ModelChangeSchema, ModeratorsAnswer
|
|
25
|
+
from chibi.services.metrics import MetricsService
|
|
26
|
+
from chibi.services.providers.provider import RestApiFriendlyProvider
|
|
27
|
+
from chibi.services.providers.tools import RegisteredChibiTools
|
|
28
|
+
from chibi.services.providers.tools.constants import MODERATOR_PROMPT
|
|
29
|
+
from chibi.services.providers.utils import (
|
|
30
|
+
get_usage_from_mistral_response,
|
|
31
|
+
get_usage_msg,
|
|
32
|
+
prepare_system_prompt,
|
|
33
|
+
send_llm_thoughts,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
# Union of the Mistral SDK message models that this module puts into the `messages` list
# it sends to `chat.complete_async` (system prompt, user turns, assistant turns, tool results).
MistralMessageParam = Union[SystemMessage, UserMessage, AssistantMessage, ToolMessage]
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class MistralAI(RestApiFriendlyProvider):
    """Chat and moderation provider that talks to MistralAI through the native ``mistralai`` SDK.

    Tool calls requested by the model are executed via ``RegisteredChibiTools``. Their results are sent
    back to the model, and this repeats until the model answers without requesting tools.
    """

    api_key = gpt_settings.mistralai_key
    chat_ready = True
    moderation_ready = True

    name = "MistralAI"
    model_name_keywords = ["mistral", "mixtral", "ministral"]
    model_name_keywords_exclude = ["embed", "moderation", "ocr"]
    default_model = "mistral-medium-latest"
    default_moderation_model = "mistral-small-latest"
    # NOTE(review): frequency_penalty / presence_penalty are not passed by `_generate_content` below;
    # confirm whether that is intentional.
    frequency_penalty: float | None = 0.6
    max_tokens: int = gpt_settings.max_tokens
    presence_penalty: float | None = 0.3
    temperature: float = 0.3

    def __init__(self, token: str) -> None:
        # The SDK client is created lazily by the `client` property on first use.
        self._client: Mistral | None = None
        super().__init__(token=token)

    @property
    def tools_list(self) -> list[ChatCompletionToolParam]:
        """Return tools in OpenAI-compatible format (which Mistral uses)."""
        return RegisteredChibiTools.get_tool_definitions()

    @property
    def client(self) -> Mistral:
        """Return the cached ``Mistral`` client, creating it on first access.

        Raises:
            NoApiKeyProvidedError: If no token is configured for this provider.
        """
        if self._client is not None:
            return self._client

        if not self.token:
            raise NoApiKeyProvidedError(provider=self.name)

        self._client = Mistral(api_key=self.token)
        return self._client

    def get_thoughts(self, assistant_message: AssistantMessage) -> str | None:
        """Return the text that accompanies an assistant message, if any.

        If the content is a plain string, it is returned as is. If the content is a list of chunks,
        the text of the first ``TextChunk`` is returned. Otherwise ``None`` is returned.
        """
        if not assistant_message.content:
            return None
        content = assistant_message.content
        if isinstance(content, str):
            return content

        if not isinstance(content, list):
            return None

        for chunk in content:
            if isinstance(chunk, TextChunk):
                return chunk.text
        return None

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten assistant content (a plain string or a list of chunks) into a single string.

        Only ``TextChunk`` parts contribute text. Empty or unrecognised content yields ``""``.
        """
        if not content:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            return "".join(chunk.text for chunk in content if isinstance(chunk, TextChunk))
        return ""

    async def _generate_content(
        self,
        model: str,
        messages: list[MistralMessageParam],
    ) -> ChatCompletionResponse:
        """Request a chat completion, retrying with exponential backoff and jitter on empty responses.

        Only empty responses (no response object or no ``choices``) are retried. Exceptions raised by
        the SDK propagate to the caller unchanged.

        Raises:
            NoResponseError: If all ``gpt_settings.retries`` attempts returned an empty response.
        """
        attempts = gpt_settings.retries
        for attempt in range(attempts):
            response = await self.client.chat.complete_async(
                model=model,
                messages=messages,
                max_tokens=self.max_tokens,
                temperature=self.temperature,
                tools=self.tools_list,  # type: ignore[arg-type]
                tool_choice="auto",
                http_headers={"Cache-Control": "max-age=86400"},
            )

            if response is not None and response.choices:
                return response

            if attempt + 1 >= attempts:
                # The last attempt failed: raise right away instead of sleeping for nothing.
                break

            delay = gpt_settings.backoff_factor * (2**attempt)
            jitter = delay * random.uniform(0.1, 0.5)
            total_delay = delay + jitter

            logger.warning(
                f"Attempt #{attempt + 1}. Unexpected (empty) response received. Retrying in {total_delay} seconds..."
            )
            await sleep(total_delay)
        raise NoResponseError(provider=self.name, model=model, detail="Unexpected (empty) response received")

    async def get_chat_response(
        self,
        messages: list[Message],
        user: User | None = None,
        model: str | None = None,
        system_prompt: str = gpt_settings.assistant_prompt,
        update: Update | None = None,
        context: ContextTypes.DEFAULT_TYPE | None = None,
    ) -> tuple[ChatResponseSchema, list[Message]]:
        """Get the model's answer to a conversation.

        Returns:
            A tuple of the chat response and the messages produced during this call: assistant turns,
            tool calls and tool results. System messages are excluded.
        """
        model = model or self.default_model
        initial_messages = [msg.to_mistral() for msg in messages]

        # `_get_chat_completion_response` keeps exactly one SystemMessage at index 0. It replaces an
        # existing one or prepends a new one, and only ever appends after that. Everything past this
        # prefix is therefore new. Slicing by position (instead of filtering by value equality) keeps
        # new messages that happen to equal an older one.
        has_system_message = bool(initial_messages) and isinstance(initial_messages[0], SystemMessage)
        history_length = len(initial_messages) if has_system_message else len(initial_messages) + 1

        chat_response, updated_messages = await self._get_chat_completion_response(
            messages=initial_messages.copy(),
            user=user,
            model=model,
            system_prompt=system_prompt,
            context=context,
            update=update,
        )
        new_messages = updated_messages[history_length:]
        return (
            chat_response,
            [Message.from_mistral(msg) for msg in new_messages if not isinstance(msg, SystemMessage)],
        )

    async def _get_chat_completion_response(
        self,
        messages: list[MistralMessageParam],
        model: str,
        user: User | None = None,
        system_prompt: str = gpt_settings.assistant_prompt,
        update: Update | None = None,
        context: ContextTypes.DEFAULT_TYPE | None = None,
    ) -> tuple[ChatResponseSchema, list[MistralMessageParam]]:
        """Run one completion round and, if the model requested tools, execute them and recurse.

        Returns:
            The final chat response, plus the full message list. The list starts with the prepared
            system prompt and ends with everything appended during the tool rounds.
        """
        prepared_system_prompt = await prepare_system_prompt(base_system_prompt=system_prompt, user=user)
        if not messages or not isinstance(messages[0], SystemMessage):
            messages = [SystemMessage(content=prepared_system_prompt, role="system")] + messages
        else:
            messages = [SystemMessage(content=prepared_system_prompt, role="system")] + messages[1:]

        response: ChatCompletionResponse = await self._generate_content(model=model, messages=messages)
        usage = get_usage_from_mistral_response(response_message=response)

        if application_settings.is_influx_configured:
            MetricsService.send_usage_metrics(metric=usage, user=user, model=model, provider=self.name)

        message_data = response.choices[0].message

        tool_calls = message_data.tool_calls
        if not tool_calls:
            messages.append(
                AssistantMessage(
                    content=message_data.content or "",
                )
            )
            return ChatResponseSchema(
                # Content may be a list of chunks; the response schema gets plain text.
                answer=self._content_to_text(message_data.content) or "no data",
                provider=self.name,
                model=model,
                usage=usage,
            ), messages

        # Tool calls handling
        logger.log("CALL", f"{model} requested the call of {len(tool_calls)} tools.")

        thoughts = self.get_thoughts(assistant_message=message_data)
        if thoughts:
            await send_llm_thoughts(thoughts=thoughts, context=context, update=update)

        logger.log("THINK", f"{model}: {thoughts}. {get_usage_msg(usage=usage)}")
        tool_context: dict[str, Any] = {
            "user_id": user.id if user else None,
            "telegram_context": context,
            "telegram_update": update,
            "model": model,
        }

        tool_coroutines = []
        for tool_call in tool_calls:
            raw_arguments = tool_call.function.arguments
            if isinstance(raw_arguments, str):
                # json.loads("") raises, so treat an empty/blank argument string as "no arguments".
                function_args = json.loads(raw_arguments) if raw_arguments.strip() else {}
            else:
                function_args = raw_arguments
            # NOTE(review): model-supplied arguments override the trusted tool_context keys
            # (e.g. "user_id") on collision; confirm no tool relies on that before reversing the merge.
            tool_coroutines.append(
                RegisteredChibiTools.call(tool_name=tool_call.function.name, tools_args=tool_context | function_args)
            )
        results = await asyncio.gather(*tool_coroutines)

        for tool_call, result in zip(tool_calls, results):
            tool_call_message = AssistantMessage(
                content=message_data.content or "",
                tool_calls=[
                    ToolCall(
                        id=tool_call.id,
                        function=FunctionCall(
                            name=tool_call.function.name,
                            arguments=tool_call.function.arguments,
                        ),
                    )
                ],
            )
            tool_result_message = ToolMessage(
                name=tool_call.function.name,
                content=result.model_dump_json(),
                tool_call_id=tool_call.id,
            )
            messages.append(tool_call_message)
            messages.append(tool_result_message)

        logger.log("CALL", "All the function results have been obtained. Returning them to the LLM...")
        # NOTE(review): there is no depth limit on this recursion; a model that keeps requesting tools
        # keeps the loop going.
        return await self._get_chat_completion_response(
            messages=messages,
            model=model,
            user=user,
            system_prompt=system_prompt,
            context=context,
            update=update,
        )

    async def moderate_command(self, cmd: str, model: str | None = None) -> ModeratorsAnswer:
        """Ask the moderation model for a JSON verdict on ``cmd``.

        Any failure (no choices, empty content, unparsable JSON) yields a
        ``status="error", verdict="declined"`` answer instead of raising.
        """
        moderator_model = model or self.default_moderation_model or self.default_model
        messages = [
            SystemMessage(content=MODERATOR_PROMPT, role="system"),
            Message(role="user", content=cmd).to_mistral(),
        ]
        response = await self.client.chat.complete_async(
            model=moderator_model,
            messages=messages,
            max_tokens=1024,
            temperature=0.2,
            response_format=ResponseFormatTypedDict(
                type="json_schema",
                json_schema=JSONSchemaTypedDict(
                    name="moderator_verdict",
                    schema_definition={
                        "type": "object",
                        "properties": {
                            "verdict": {"type": "string"},
                            "reason": {"type": "string"},
                            "status": {"type": "string", "default": "ok"},
                        },
                        "required": ["verdict"],
                    },
                    strict=True,
                ),
            ),
        )
        if not response or not response.choices:
            return ModeratorsAnswer(status="error", verdict="declined", reason="no response from moderator received")

        usage = get_usage_from_mistral_response(response_message=response)

        if application_settings.is_influx_configured:
            MetricsService.send_usage_metrics(metric=usage, model=moderator_model, provider=self.name)

        message_data = response.choices[0].message
        # Content may be a list of chunks; flatten it to text before JSON parsing.
        answer = self._content_to_text(message_data.content)
        if not answer:
            return ModeratorsAnswer(status="error", verdict="declined", reason="no response from moderator received")

        try:
            payload = json.loads(answer)
            # Drop keys the model does not declare, so extra fields are ignored. This does not depend on
            # the `extra=` validation argument, which older pydantic 2.x releases lack.
            known_fields = {key: value for key, value in payload.items() if key in ModeratorsAnswer.model_fields}
            return ModeratorsAnswer.model_validate(known_fields)
        except Exception as e:
            msg = f"Error parsing moderator's response: {answer}. Error: {e}"
            logger.error(msg)
            return ModeratorsAnswer(verdict="declined", reason=msg, status="error")

    async def get_available_models(self, image_generation: bool = False) -> list[ModelChangeSchema]:
        """List the chat models exposed by the API, sorted by id and filtered by the whitelist if set.

        Returns an empty list for image generation (unsupported) or when the API call fails.
        """
        if image_generation:
            return []

        try:
            response = await self.client.models.list_async()
            mistral_models = response.data or []
        except Exception as e:
            logger.error(f"Failed to get available models for provider {self.name} due to exception: {e}")
            return []

        all_models = [
            ModelChangeSchema(
                provider=self.name,
                name=model.id,
                display_name=model.id.replace("-", " ").title(),
                image_generation=False,
            )
            for model in mistral_models
        ]
        all_models.sort(key=lambda model: model.name)

        if gpt_settings.models_whitelist:
            return [model for model in all_models if model.name in gpt_settings.models_whitelist]

        return all_models
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from openai import NOT_GIVEN
|
|
2
|
+
|
|
3
|
+
from chibi.config import gpt_settings
|
|
4
|
+
from chibi.services.providers.provider import OpenAIFriendlyProvider
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class MoonshotAI(OpenAIFriendlyProvider):
    """Moonshot AI (Kimi) chat and moderation provider built on ``OpenAIFriendlyProvider``."""

    # Identity, credentials and endpoint.
    name = "MoonshotAI"
    api_key = gpt_settings.moonshotai_key
    base_url = "https://api.moonshot.cn/v1"

    # Supported capabilities.
    chat_ready = True
    moderation_ready = True

    # Which listed models belong to this provider, and which ones are used by default.
    model_name_keywords = ["moonshot", "kimi"]
    model_name_keywords_exclude = ["vision"]
    default_model = "kimi-latest"
    default_moderation_model = "kimi-k2-turbo-preview"

    # Sampling temperature; image options are set to the openai NOT_GIVEN sentinel.
    temperature = 0.3
    image_quality = NOT_GIVEN
    image_size = NOT_GIVEN
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
from io import BytesIO
|
|
2
|
+
|
|
3
|
+
from loguru import logger
|
|
4
|
+
from openai import NOT_GIVEN
|
|
5
|
+
from openai.types import ImagesResponse
|
|
6
|
+
|
|
7
|
+
from chibi.config import gpt_settings
|
|
8
|
+
from chibi.constants import OPENAI_TTS_INSTRUCTIONS
|
|
9
|
+
from chibi.services.providers.provider import OpenAIFriendlyProvider
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class OpenAI(OpenAIFriendlyProvider):
    """OpenAI provider: chat, speech-to-text, text-to-speech, DALL·E image generation and moderation."""

    api_key = gpt_settings.openai_key
    chat_ready = True
    tts_ready = True
    stt_ready = True
    image_generation_ready = True
    moderation_ready = True

    name = "OpenAI"
    model_name_prefixes = ["gpt", "o1", "o3", "o4"]
    model_name_keywords_exclude = ["audio", "realtime", "transcribe", "tts", "image"]
    base_url = "https://api.openai.com/v1"
    max_tokens = NOT_GIVEN
    default_model = "gpt-5.2"
    default_image_model = "dall-e-3"
    default_moderation_model = "gpt-5-mini"
    default_stt_model = "whisper-1"
    default_tts_model = "gpt-4o-mini-tts"
    default_tts_voice = "nova"

    async def transcribe(self, audio: BytesIO, model: str | None = None) -> str:
        """Transcribe an audio buffer (uploaded as ``voice.ogg``) and return the recognised text.

        Raises:
            ValueError: If the API returns no response object.
        """
        model = model or self.default_stt_model
        logger.info(f"Transcribing audio with model {model}...")
        response = await self.client.audio.transcriptions.create(
            model=model,
            file=("voice.ogg", audio.getvalue()),
        )
        if response:
            logger.info(f"Transcribed text: {response.text}")
            return response.text
        raise ValueError("Could not transcribe audio message")

    async def speech(self, text: str, voice: str | None = None, model: str | None = None) -> bytes:
        """Synthesise ``text`` into speech and return the raw audio bytes."""
        voice = voice or self.default_tts_voice
        model = model or self.default_tts_model
        logger.info(f"Recording a voice message with model {model}...")
        # The API documents `instructions` as not working with tts-1 / tts-1-hd, so only send it to
        # models that support it.
        instructions = NOT_GIVEN if model.startswith("tts-1") else OPENAI_TTS_INSTRUCTIONS
        response = await self.client.audio.speech.create(
            model=model,
            voice=voice,
            input=text,
            instructions=instructions,
        )
        return await response.aread()

    @classmethod
    def is_image_ready_model(cls, model_name: str) -> bool:
        """Return True for DALL·E models, the only image models this provider uses."""
        return "dall-e" in model_name

    async def _get_image_generation_response(self, prompt: str, model: str) -> ImagesResponse:
        """Request image(s) for ``prompt`` and return the URL-format API response.

        Only dall-e-2 requests ``gpt_settings.image_n_choices`` images; other models request one.
        ``size`` is sent only for dall-e-3.
        """
        return await self.client.images.generate(
            model=model,
            prompt=prompt,
            n=gpt_settings.image_n_choices if "dall-e-2" in model else 1,
            quality=self.image_quality,
            size=self.image_size if "dall-e-3" in model else NOT_GIVEN,
            timeout=gpt_settings.timeout,
            response_format="url",
        )

    def get_model_display_name(self, model_name: str) -> str:
        """Return a human-friendly model name (e.g. ``dall-e-3`` -> ``DALL·E 3``, dashes -> spaces)."""
        if "dall" in model_name:
            return model_name.replace("dall-e-", "DALL·E ")
        return model_name.replace("-", " ")
|