amrita_core 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- amrita_core/__init__.py +101 -0
- amrita_core/builtins/__init__.py +7 -0
- amrita_core/builtins/adapter.py +148 -0
- amrita_core/builtins/agent.py +415 -0
- amrita_core/builtins/tools.py +64 -0
- amrita_core/chatmanager.py +896 -0
- amrita_core/config.py +159 -0
- amrita_core/hook/event.py +90 -0
- amrita_core/hook/exception.py +14 -0
- amrita_core/hook/matcher.py +213 -0
- amrita_core/hook/on.py +14 -0
- amrita_core/libchat.py +189 -0
- amrita_core/logging.py +71 -0
- amrita_core/preset.py +166 -0
- amrita_core/protocol.py +101 -0
- amrita_core/tokenizer.py +115 -0
- amrita_core/tools/manager.py +163 -0
- amrita_core/tools/mcp.py +338 -0
- amrita_core/tools/models.py +353 -0
- amrita_core/types.py +274 -0
- amrita_core/utils.py +66 -0
- amrita_core-0.1.0.dist-info/METADATA +73 -0
- amrita_core-0.1.0.dist-info/RECORD +26 -0
- amrita_core-0.1.0.dist-info/WHEEL +5 -0
- amrita_core-0.1.0.dist-info/licenses/LICENSE +661 -0
- amrita_core-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,896 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import copy
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from asyncio import Lock, Task
|
|
5
|
+
from collections import defaultdict
|
|
6
|
+
from collections.abc import AsyncGenerator
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
|
+
from datetime import datetime
|
|
9
|
+
from io import BytesIO, StringIO
|
|
10
|
+
from typing import Any, TypedDict
|
|
11
|
+
from uuid import uuid4
|
|
12
|
+
|
|
13
|
+
import pytz
|
|
14
|
+
from pydantic import BaseModel, Field
|
|
15
|
+
from pytz import utc
|
|
16
|
+
from typing_extensions import Self
|
|
17
|
+
|
|
18
|
+
from .config import AmritaConfig, get_config
|
|
19
|
+
from .hook.event import CompletionEvent, PreCompletionEvent
|
|
20
|
+
from .hook.matcher import MatcherManager
|
|
21
|
+
from .libchat import call_completion, get_last_response, get_tokens, text_generator
|
|
22
|
+
from .logging import logger
|
|
23
|
+
from .tokenizer import hybrid_token_count
|
|
24
|
+
from .types import (
|
|
25
|
+
CONTENT_LIST_TYPE,
|
|
26
|
+
CT_MAP,
|
|
27
|
+
USER_INPUT,
|
|
28
|
+
Content,
|
|
29
|
+
ImageContent,
|
|
30
|
+
Message,
|
|
31
|
+
SendMessageWrap,
|
|
32
|
+
TextContent,
|
|
33
|
+
ToolResult,
|
|
34
|
+
UniResponse,
|
|
35
|
+
UniResponseUsage,
|
|
36
|
+
)
|
|
37
|
+
from .types import (
|
|
38
|
+
MemoryModel as Memory,
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
# Global asyncio lock serializing coroutine access in the chat manager (asyncio locks are not thread-safe)
|
|
42
|
+
LOCK = Lock()
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def get_current_datetime_timestamp(utc_time: None | datetime = None):
    """Format a moment in time as an Asia/Shanghai date, weekday and time string.

    Args:
        utc_time: Moment to format. Timezone-aware values are converted from
            their own zone; naive values are interpreted as UTC. If None, the
            current time is used.

    Returns:
        Formatted timestamp string in the format "[YYYY-MM-DD Weekday HH:MM:SS]"
    """
    if utc_time is None:
        utc_time = datetime.now(pytz.utc)
    elif utc_time.tzinfo is None:
        # astimezone() presumes a naive datetime is in the *system* timezone,
        # which would silently shift the result on hosts not running in UTC.
        utc_time = utc_time.replace(tzinfo=pytz.utc)
    shanghai_now = utc_time.astimezone(pytz.timezone("Asia/Shanghai"))
    return shanghai_now.strftime("[%Y-%m-%d %A %H:%M:%S]")
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class MessageContent(ABC):
    """Base type for every payload the chat manager can yield.

    Lets the chat manager emit more than plain strings; each concrete
    subclass decides what its payload is and how to expose it.
    """

    def __init__(self, content_type: str):
        # Short tag naming the concrete payload kind.
        self.type = content_type

    @abstractmethod
    def get_content(self):
        """Return the payload carried by this message."""
        raise NotImplementedError("Subclasses must implement get_content method")
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class StringMessageContent(MessageContent):
    """Message content that carries a plain text string."""

    def __init__(self, text: str):
        self.text = text
        super().__init__("string")

    def get_content(self):
        """Return the stored text."""
        return self.text
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class RawMessageContent(MessageContent):
    """Message content that carries an arbitrary object unchanged."""

    def __init__(self, raw_data: Any):
        self.raw_data = raw_data
        super().__init__("raw")

    def get_content(self):
        """Return the stored object exactly as it was given."""
        return self.raw_data
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class MessageMetadata(TypedDict):
    """Dictionary shape returned by ``MessageWithMetadata.get_full_content``."""

    content: str  # Message payload
    metadata: dict[str, Any]  # Extra key/value data attached to the payload
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class MessageWithMetadata(MessageContent):
    """Message content paired with a metadata dictionary."""

    def __init__(self, content: Any, metadata: dict):
        self.metadata = metadata
        self.content = content
        super().__init__("metadata")

    def get_content(self) -> Any:
        """Return only the payload, without its metadata."""
        return self.content

    def get_metadata(self) -> dict:
        """Return only the metadata dictionary."""
        return self.metadata

    def get_full_content(self) -> MessageMetadata:
        """Return payload and metadata together as a ``MessageMetadata`` dict."""
        return MessageMetadata(content=self.content, metadata=self.metadata)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class ImageMessage(MessageContent):
    """Message content that carries an image as a string or an in-memory buffer."""

    def __init__(self, image: str | BytesIO):
        self.image: str | BytesIO = image
        super().__init__("image")

    def get_content(self) -> BytesIO | str:
        """Return the image exactly as it was supplied."""
        return self.image
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
class ChatObjectMeta(BaseModel):
    """Pydantic model describing a chat object.

    Holds the identifiers, the user input and the timing information of a
    chat object.
    """

    # Identifier of the chat stream
    stream_id: str
    # Identifier of the session
    session_id: str
    # User input, either plain text or a list of text/image content parts
    user_input: list[TextContent | ImageContent] | str
    # Creation time; defaults to datetime.now() (naive local time)
    time: datetime = Field(default_factory=datetime.now)
    # Last call time; defaults to datetime.now() (naive local time)
    last_call: datetime = Field(default_factory=datetime.now)
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
class MemoryLimiter:
    """Context processor

    Keeps the chat context within the configured message-count and token
    limits by removing the oldest messages and, when enabled, asking the LLM
    to summarize what was removed. Must be used as an async context manager;
    if the body raises, the message list is rolled back to its state on entry.
    """

    config: AmritaConfig  # Configuration object
    usage: UniResponseUsage | None = None  # Token usage of the summarization call, None if none ran
    _train: dict[str, str]  # System prompt message as a dict (validated as a Message)
    _dropped_messages: list[Message[str] | ToolResult]  # Messages removed during this run
    _copied_messages: Memory  # Deep copy of the memory taken on entry (for rollback)
    _abstract_instruction = """<<SYS>>
You are a professional context summarizer, strictly following user instructions to perform summarization tasks.
<</SYS>>

<<INSTRUCTIONS>>
1. Directly summarize the user-provided content
2. Maintain core information and key details from the original
3. Do not generate any additional content, explanations, or comments
4. Summaries should be concise, accurate, complete
<</INSTRUCTIONS>>

<<RULE>>
- Only summarize the text provided by the user
- Do not add any explanations, comments, or supplementary information
- Do not alter the main meaning of the original
- Maintain an objective and neutral tone
<</RULE>>

<<FORMATTING>>
User input → Direct summary output
<</FORMATTING>>"""

    def __init__(self, memory: Memory, train: dict[str, str]) -> None:
        """Initialize context processor

        Args:
            memory: Memory model to process (modified in place)
            train: System prompt message as a dict
        """
        self.memory = memory
        self.config = get_config()
        self._train = train

    async def __aenter__(self) -> Self:
        """Async context manager entry, initialize processing state

        Returns:
            This instance, with an empty dropped-message list and a deep copy
            of the memory kept for rollback.
        """
        self._dropped_messages = []
        self._copied_messages = copy.deepcopy(self.memory)
        logger.debug(
            f"MemoryLimiter initialized, message count: {len(self.memory.messages)}"
        )
        return self

    async def _make_abstract(self):
        """Generate context summary

        Asks the LLM to summarize the dropped messages (plus a configured
        proportion of the remaining ones) and stores the result in
        ``memory.abstract``. Token usage of the call is stored in ``usage``.
        """
        logger.debug("Starting context summarization..")
        proportion = self.config.llm.memory_abstract_proportion  # Summary proportion
        dropped_part: CONTENT_LIST_TYPE = copy.deepcopy(self._dropped_messages)
        # Number of additional messages to fold into the summary, never negative.
        index = max(int(len(self.memory.messages) * proportion) - len(dropped_part), 0)
        idx: int | None = None
        if index:
            for idx, element in enumerate(self.memory.messages):
                dropped_part.append(element)
                if getattr(element, "tool_calls", None) is not None:
                    # A message carrying tool calls is taken together with what follows it.
                    continue
                if idx >= index:
                    break
        # NOTE(review): the slice starts AT idx, so the last summarized message
        # also stays in memory. Left unchanged because slicing from idx + 1 could
        # remove the newest user message — confirm the intended boundary.
        self.memory.messages = self.memory.messages[
            (idx if idx is not None else index) :
        ]
        if not dropped_part:
            logger.debug("Context summarization skipped")
            return

        # The header and the code fence must wrap the transcript once. Using the
        # header as the str.join separator placed it *between* messages and
        # left the fence unopened.
        transcript = "".join(
            f"{it}\n" for it in text_generator(dropped_part, split_role=True)
        )
        msg_list: CONTENT_LIST_TYPE = [
            Message[str](role="system", content=self._abstract_instruction),
            Message[str](
                role="user",
                content="Message list:\n```text\n" + transcript + "\n```",
            ),
        ]
        logger.debug("Performing context summarization...")
        response = await get_last_response(call_completion(msg_list))
        self.usage = await get_tokens(msg_list, response)
        logger.debug(f"Context summary received: {response.content}")
        self.memory.abstract = response.content
        logger.debug("Context summarization completed")

    def _drop_message(self):
        """Remove the oldest message from memory and record it as dropped.

        Does nothing when fewer than two messages remain. Every tool message
        that directly follows the removed message is removed and recorded as
        well, so no tool result is left at the head of the context without
        the message that preceded it.
        """
        data = self.memory
        if len(data.messages) < 2:
            return
        self._dropped_messages.append(data.messages.pop(0))
        # There may be several tool results in a row (parallel tool calls).
        while data.messages and data.messages[0].role == "tool":
            self._dropped_messages.append(data.messages.pop(0))

    async def run_enforce(self):
        """Execute memory limitation processing

        Applies the message-count limit, then the token limit, then (when
        enabled and something was dropped) the context summarization.
        This method must be used within an async context manager.

        Raises:
            RuntimeError: Thrown when not used in an async context manager
        """
        logger.debug("Starting memory limitation processing..")
        # Both attributes are created in __aenter__; missing either one means
        # the limiter was not entered.
        if not hasattr(self, "_dropped_messages") or not hasattr(
            self, "_copied_messages"
        ):
            raise RuntimeError(
                "MemoryLimiter is not initialized, please use `async with MemoryLimiter(memory)` before calling."
            )
        await self._limit_length()
        await self._limit_tokens()
        if self.config.llm.enable_memory_abstract and self._dropped_messages:
            await self._make_abstract()
        logger.debug("Memory limitation processing completed")

    async def _limit_length(self):
        """Control memory length, remove old messages that exceed the limit, remove unsupported messages."""
        logger.debug("Starting memory length limitation..")
        is_multimodal = get_config().llm.enable_multi_modal
        data: Memory = self.memory

        # When multimodal input is disabled, collapse list-style user content
        # into the concatenation of its text parts.
        for message in data.messages:
            if (
                isinstance(message.content, list)
                and not is_multimodal
                and message.role == "user"
            ):
                message_text = ""
                for content_part in message.content:
                    if isinstance(content_part, dict):
                        validator = CT_MAP.get(content_part["type"])
                        if not validator:
                            raise ValueError(
                                f"Invalid content type: {content_part['type']}"
                            )
                        content_part: Content = validator.model_validate(content_part)
                    # NOTE(review): presumes Content models support item access — verify.
                    if content_part["type"] == "text":
                        message_text += content_part["text"]
                message.content = message_text

        # Enforce memory length limit
        initial_count = len(data.messages)
        while len(data.messages) >= 2:
            if data.messages[0].role == "tool":
                # A tool message at the head is discarded (not recorded as dropped).
                data.messages.pop(0)
            elif len(data.messages) > self.config.llm.memory_lenth_limit:
                self._drop_message()
            else:
                break
        final_count = len(data.messages)
        logger.debug(
            f"Memory length limitation completed, removed {initial_count - final_count} messages"
        )

    async def _limit_tokens(self):
        """Control token count, remove old messages that exceed the limit

        Counts the tokens of the system prompt plus the current messages and,
        while the total exceeds the configured session window, removes the
        oldest messages. Skipped when token limiting is disabled or when the
        system prompt alone exceeds the window.
        """

        def get_token(memory: CONTENT_LIST_TYPE) -> int:
            """Return the total token count of the given message list."""
            return sum(
                hybrid_token_count(msg, self.config.llm.tokens_count_mode)
                for msg in text_generator(memory)
            )

        train = self._train
        train_model = Message.model_validate(train)
        data = self.memory
        logger.debug("Starting token count limitation..")
        if not self.config.llm.enable_tokens_limit:
            logger.debug("Token limitation disabled, skipping processing")
            return
        window = self.config.llm.session_tokens_windows
        # Count the prompt with the same mode used for the messages.
        prompt_length = hybrid_token_count(
            train["content"], self.config.llm.tokens_count_mode
        )
        if prompt_length > window:
            logger.warning(
                f"Prompt size too large! It's {prompt_length}>{window}! Please adjusts the prompt or settings!"
            )
            return
        tk_tmp: int = get_token([train_model, *data.messages])

        initial_count = len(data.messages)
        while tk_tmp > window and len(data.messages) >= 2:
            self._drop_message()
            # Recount on the list as it is AFTER the drop; counting the stale
            # list first made the loop remove one message more than needed.
            tk_tmp = get_token([train_model, *data.messages])
            await asyncio.sleep(
                0
            )  # CPU intensive tasks may cause performance issues, yielding control here
        final_count = len(data.messages)
        logger.debug(
            f"Token count limitation completed, removed {initial_count - final_count} messages"
        )
        logger.debug(f"Final token count: {tk_tmp}")

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Async context manager exit, handle rollback in case of exceptions

        When the body raised, the message list is restored from the copy taken
        on entry. The exception is not suppressed.

        Args:
            exc_type: Exception type
            exc_val: Exception value
            exc_tb: Exception traceback
        """
        if exc_type is not None:
            logger.warning("An exception occurred, rolling back messages...")
            self.memory.messages = self._copied_messages.messages
        return
|
|
411
|
+
|
|
412
|
+
|
|
413
|
+
class ChatObject:
    """Chat processing object

    Handles one chat session end to end: takes the user input, manages the
    context, calls the model and streams the response through its queues.
    """

    # Chat object ID (regenerated when processing starts)
    stream_id: str
    # Timestamp string (for LLM)
    timestamp: str
    # Creation time
    time: datetime
    # Set when processing finishes
    end_at: datetime | None = None
    # (lateinit) Memory file
    data: Memory
    user_input: USER_INPUT
    # (lateinit) User message
    user_message: Message[USER_INPUT]
    context_wrap: SendMessageWrap
    # System message
    train: dict[str, str]
    # Last internal function call time
    last_call: datetime
    session_id: str
    response: UniResponse[str, None]
    # Primary queue for streamed response items
    response_queue: asyncio.Queue[Any]
    # Used when the primary queue is full
    overflow_queue: asyncio.Queue[Any]
    # Whether it is running
    _is_running: bool = False
    # Whether it has completed
    _is_done: bool = False
    # Background task, created on first call
    _task: Task[None]
    _has_task: bool = False
    # Value returned by get_exception()
    _err: BaseException | None = None
    # Whether calling the object awaits the task (run_blocking)
    _wait: bool = True
    # Value returned by queue_closed()
    _queue_done: bool = False
    # Sentinel placed on the queue to mark the end of the stream
    __done_marker = object()
|
|
442
|
+
|
|
443
|
+
def __init__(
    self,
    train: dict[str, str],
    user_input: USER_INPUT,
    context: Memory,
    session_id: str,
    run_blocking: bool = True,
    queue_size: int = 25,
    overflow_queue_size: int = 45,
) -> None:
    """Initialize chat object

    Args:
        train: Training data (system prompts)
        user_input: Input from the user
        context: Memory context for the session
        session_id: Unique identifier for the session
        run_blocking: Whether calling the object waits for processing to finish
        queue_size: Capacity of the primary response queue
        overflow_queue_size: Capacity of the overflow response queue
    """
    # Inputs
    self.train = train
    self.data = context
    self.session_id = session_id
    self.user_input = user_input
    self.user_message = Message(role="user", content=user_input)
    self._wait = run_blocking

    # Timing and configuration
    self.timestamp = get_current_datetime_timestamp()
    self.time = datetime.now(utc)
    self.config: AmritaConfig = get_config()
    self.last_call = datetime.now(utc)

    # Bounded queues used to stream the response to consumers
    self.response_queue = asyncio.Queue(queue_size)
    self.overflow_queue = asyncio.Queue(overflow_queue_size)
    self.stream_id = uuid4().hex
|
|
477
|
+
|
|
478
|
+
def get_exception(self) -> BaseException | None:
|
|
479
|
+
"""
|
|
480
|
+
Get exceptions that occurred during task execution
|
|
481
|
+
|
|
482
|
+
Returns:
|
|
483
|
+
BaseException | None: Returns exception object if an exception occurred during task execution, otherwise returns None
|
|
484
|
+
"""
|
|
485
|
+
return self._err
|
|
486
|
+
|
|
487
|
+
def call(self):
    """Invoke this chat object without using call syntax.

    Returns:
        Whatever ``self.__call__()`` returns; ``__call__`` is a coroutine
        function, so the result has to be awaited.
    """
    pending = self.__call__()
    return pending
|
|
495
|
+
|
|
496
|
+
def is_running(self) -> bool:
    """Report whether processing is currently in progress.

    Returns:
        bool: True while the running flag is set, otherwise False
    """
    running = self._is_running
    return running
|
|
504
|
+
|
|
505
|
+
def is_done(self) -> bool:
    """Report whether processing has finished.

    Returns:
        bool: True once the done flag is set, otherwise False
    """
    finished = self._is_done
    return finished
|
|
513
|
+
|
|
514
|
+
def terminate(self) -> None:
    """Terminate task execution

    Marks the chat object as done and not running, and cancels the internal
    task if one was created. Safe to call on an object that was never started.
    """
    self._is_done = True
    self._is_running = False
    # `_task` only exists after the object has been called; terminating an
    # object that never started must not raise AttributeError.
    task = getattr(self, "_task", None)
    if task is not None:
        task.cancel()
|
|
522
|
+
|
|
523
|
+
def __await__(self):
|
|
524
|
+
if not hasattr(self, "_task"):
|
|
525
|
+
raise RuntimeError("ChatObject not running")
|
|
526
|
+
return self._task.__await__()
|
|
527
|
+
|
|
528
|
+
async def __call__(self) -> None:
    """Call chat object to process messages

    The first call wraps a second invocation of this method in an asyncio
    task. In blocking mode that task is awaited; otherwise the call returns
    immediately. The second invocation (inside the task) runs the actual
    chat flow.

    Raises:
        RuntimeError: If the chat object is already running or done
    """
    if not self._has_task:
        logger.debug("Starting chat object task...")
        self._has_task = True
        self._task = asyncio.create_task(self.__call__())
        return await self._task if self._wait else None

    if self._is_running or self._is_done:
        raise RuntimeError(
            f"ChatObject of {self.stream_id} is already running or done"
        )

    self.stream_id = uuid4().hex
    logger.debug(f"Starting chat processing, stream ID:{self.stream_id}")

    try:
        self._is_running = True
        self.last_call = datetime.now(utc)
        await chat_manager.add_chat_object(self)

        await self._run()
    except BaseException as exc:
        # Record the failure (including cancellation) so get_exception()
        # can report it, then let it propagate unchanged.
        self._err = exc
        raise
    finally:
        self._is_running = False
        self._is_done = True
        self.end_at = datetime.now(utc)
        # If processing failed or was cancelled, no done marker was queued.
        # Closing the queue lets the response generator stop once it has
        # drained the queues instead of polling forever.
        self._queue_done = True
        chat_manager.running_chat_object_id2map.pop(self.stream_id, None)
        logger.debug("Chat event processing completed")
|
|
562
|
+
|
|
563
|
+
async def _run(self):
    """Run chat processing flow

    Appends the user message to memory, rebuilds the system prompt, applies
    the memory limits, sends the context to the model and finally marks the
    response queue as done.
    """
    logger.debug("Starting chat processing flow..")
    data = self.data
    config = self.config

    data.messages.append(self.user_message)

    logger.debug(
        f"Added user message to memory, current message count: {len(data.messages)}"
    )

    hidden_block = (
        f"<HIDDEN>{config.cookie.cookie}</HIDDEN>\n"
        if config.cookie.enable_cookie
        else ""
    )
    # NOTE(review): the <SYSTEM_INSTRUCTIONS> section is emitted empty and the
    # incoming train["content"] is overwritten rather than embedded — confirm
    # that the character setting is injected elsewhere.
    self.train["content"] = "".join(
        (
            "<SCHEMA>\n",
            hidden_block,
            "Please participate in the discussion in your own character identity. Try not to use similar phrases when responding to different topics. User's messages are contained within user inputs.",
            "Your character setting is in the <SYSTEM_INSTRUCTIONS> tags, and the summary of previous conversations is in the <SUMMARY> tags.",
            "\n</SCHEMA>\n",
            "<SYSTEM_INSTRUCTIONS>\n",
            "\n</SYSTEM_INSTRUCTIONS>",
            f"\n<SUMMARY>\n{data.abstract if config.llm.enable_memory_abstract else ''}\n</SUMMARY>",
        )
    )
    logger.debug(self.train["content"])

    logger.debug("Starting applying memory limitations..")
    async with MemoryLimiter(self.data, self.train) as limiter:
        await limiter.run_enforce()
        abs_usage = limiter.usage
        self.data = limiter.memory
    logger.debug("Memory limitation application completed")

    send_messages = self._prepare_send_messages()
    self.context_wrap = SendMessageWrap.validate_messages(send_messages)
    logger.debug(
        f"Preparing sending messages completed, message count: {len(send_messages)}"
    )
    response: UniResponse[str, None] = await self._process_chat(send_messages)
    if response.usage and abs_usage:
        # Add the tokens spent on summarization to the reported usage.
        for counter in ("completion_tokens", "prompt_tokens", "total_tokens"):
            setattr(
                response.usage,
                counter,
                getattr(response.usage, counter) + getattr(abs_usage, counter),
            )

    logger.debug("Chat processing completed, preparing to send response")
    await self.set_queue_done()
|
|
616
|
+
|
|
617
|
+
def queue_closed(self) -> bool:
    """Report whether the response queue has been marked as closed.

    Returns:
        True if the queue is closed, False otherwise
    """
    closed = self._queue_done
    return closed
|
|
624
|
+
|
|
625
|
+
async def set_queue_done(self) -> None:
    """Close the response queue and enqueue the done marker.

    Does nothing when the queue is already closed, so the marker is enqueued
    at most once.
    """
    if self.queue_closed():
        return
    # Set the flag before enqueueing. queue_closed() then reports the closed
    # state, and yield_response() rejects late writes. If the marker cannot
    # be enqueued, the response generator still stops once it has drained
    # the queues.
    self._queue_done = True
    await self._put_to_queue(self.__done_marker)
|
|
629
|
+
|
|
630
|
+
async def _put_to_queue(self, item):
    """Put an item to the queue, using overflow mechanism if primary queue is full

    Items are tried on the primary queue first and on the overflow queue when
    the primary one is full. While the overflow queue still holds items, new
    items are appended behind them instead of going to the primary queue,
    so a consumer that drains the primary queue before the overflow queue
    receives items in the order they were put. When no queue can take the
    item, the attempt is repeated once per second for up to five seconds.

    Args:
        item: Item to put in the queue

    Raises:
        RuntimeError: If the item could not be enqueued after waiting
    """

    def try_put() -> bool:
        """Try to enqueue the item once without blocking; return True on success."""
        if self.overflow_queue.empty():
            candidates = (self.response_queue, self.overflow_queue)
        else:
            # Earlier items are still waiting in overflow; putting this one on
            # the primary queue would let it overtake them.
            candidates = (self.overflow_queue,)
        for queue in candidates:
            try:
                queue.put_nowait(item)
            except asyncio.QueueFull:
                continue
            return True
        return False

    if try_put():
        return
    for _ in range(5):
        await asyncio.sleep(1)
        if try_put():
            return

    # After waiting, if still full, raise an exception
    raise RuntimeError("Both primary and overflow queues are full after waiting")
|
|
660
|
+
|
|
661
|
+
async def yield_response(self, response: str | MessageContent):
    """Enqueue one response item for consumers of the response stream.

    Args:
        response: Either a string or MessageContent object to be sent to the queue

    Raises:
        RuntimeError: If the queue has already been closed
    """
    if self.queue_closed():
        raise RuntimeError("Queue is closed.")
    await self._put_to_queue(response)
|
|
671
|
+
|
|
672
|
+
async def yield_response_iteration(
    self, iterator: AsyncGenerator[str | MessageContent, None]
):
    """Enqueue every item produced by an async generator, in order.

    Args:
        iterator: An async generator that yields either strings or MessageContent objects
    """
    async for piece in iterator:
        await self.yield_response(piece)
|
|
682
|
+
|
|
683
|
+
def get_response_generator(self) -> AsyncGenerator[str | MessageContent, None]:
    """Create an async generator over the queued response items.

    Returns:
        A new generator obtained from ``_response_generator()``
    """
    stream = self._response_generator()
    return stream
|
|
690
|
+
|
|
691
|
+
async def full_response(self) -> str:
    """Consume the whole response stream and return it as one string.

    String items are used as they are; MessageContent items contribute the
    string form of their content. Any other item is ignored.

    Returns:
        Complete response string combining all chunks in the queue
    """
    pieces: list[str] = []
    async for item in self.get_response_generator():
        if isinstance(item, str):
            pieces.append(item)
        elif isinstance(item, MessageContent):
            pieces.append(str(item.get_content()))
    return "".join(pieces)
|
|
704
|
+
|
|
705
|
+
async def _response_generator(self) -> AsyncGenerator[str | MessageContent, None]:
    """Yield queued response items until the stream ends.

    The primary queue is drained first. Then one item is taken from the
    overflow queue before the primary queue is checked again. The stream ends
    when the done marker is read, or when the queue is closed and both queues
    are empty. While nothing is available the generator sleeps briefly and
    retries.

    Yields:
        Items from the response queue until the done marker is encountered
    """
    while True:
        # Drain everything currently waiting in the primary queue.
        while not self.response_queue.empty():
            item = await self.response_queue.get()
            self.response_queue.task_done()
            if item is self.__done_marker:
                return
            yield item

        if self.overflow_queue.empty():
            if (
                self.queue_closed()
                and self.response_queue.empty()
                and self.overflow_queue.empty()
            ):
                return
            # Nothing to deliver yet; wait a bit before checking again.
            await asyncio.sleep(0.01)
            continue

        # Primary queue is empty: take a single item from the overflow queue.
        item = await self.overflow_queue.get()
        self.overflow_queue.task_done()
        if item is self.__done_marker:
            return
        yield item
|
|
737
|
+
|
|
738
|
+
async def _process_chat(
    self,
    send_messages: CONTENT_LIST_TYPE,
) -> UniResponse[str, None]:
    """Call chat model to generate response and trigger related events.

    Fires a PreCompletionEvent, takes the (possibly modified) context from it,
    streams the model output to the response queue, fires a CompletionEvent
    and appends the final assistant message to memory.

    Args:
        send_messages: Send message list

    Returns:
        Model response

    Raises:
        RuntimeError: If the completion stream ends without a final response
    """
    self.last_call = datetime.now(utc)

    data = self.data
    logger.debug(
        f"Starting chat processing, sending message count: {len(send_messages)}"
    )

    logger.debug("Triggering matcher functions..")
    messages = self.context_wrap
    pre_event = PreCompletionEvent(
        chat_object=self,
        user_input=self.user_input,
        original_context=messages,  # Use the complete message list, including the system message
    )
    await MatcherManager.trigger_event(pre_event)
    # Take over whatever context the event ended up with: memory keeps the
    # messages without the system message, the request uses the full list.
    self.data.messages = pre_event.get_context_messages().unwrap(
        exclude_system=True
    )
    send_messages = pre_event.get_context_messages().unwrap()

    logger.debug("Calling chat model..")
    response: UniResponse[str, None] | None = None
    async for chunk in call_completion(send_messages):
        if isinstance(chunk, str):
            # Plain text chunks are wrapped before being queued.
            await self.yield_response(StringMessageContent(chunk))
        elif isinstance(chunk, UniResponse):
            # Kept as the final response; not forwarded to the queue.
            response = chunk
        elif isinstance(chunk, MessageContent):
            await self.yield_response(chunk)
    if response is None:
        raise RuntimeError("No final response from chat adapter.")
    self.response = response

    logger.debug("Triggering chat events..")
    post_event = CompletionEvent(self.user_input, messages, self, response.content)
    await MatcherManager.trigger_event(post_event)
    # The event's model_response replaces the response content.
    response.content = post_event.model_response
    assistant_message = Message[str](
        content=response.content,
        role="assistant",
    )
    data.messages.append(assistant_message)
    logger.debug(
        f"Added assistant response to memory, current message count: {len(data.messages)}"
    )

    logger.debug("Chat processing completed")
    return response
|
|
799
|
+
|
|
800
|
+
def _prepare_send_messages(
    self,
) -> list:
    """Build the message list to hand to the chat model.

    The list starts with ``self.train`` validated as a ``Message[str]`` and
    continues with a deep copy of the stored messages, so changes made to the
    returned list do not affect ``self.data.messages``.

    Returns:
        A new list: the validated ``train`` message followed by copies of
        the stored messages.
    """
    self.last_call = datetime.now(utc)
    logger.debug("Preparing messages to send..")

    leading_message: Message[str] = Message[str].model_validate(self.train)
    prepared = [leading_message]
    # Deep copy keeps the stored history isolated from the outgoing list.
    prepared.extend(copy.deepcopy(self.data.messages))

    logger.debug(f"Messages preparation completed, total {len(prepared)} messages")
    return prepared
|
|
815
|
+
|
|
816
|
+
def get_snapshot(self) -> ChatObjectMeta:
    """Build a ``ChatObjectMeta`` from this object's current attributes.

    Returns:
        Metadata validated from this chat object via ``from_attributes``.
    """
    snapshot = ChatObjectMeta.model_validate(self, from_attributes=True)
    return snapshot
|
|
823
|
+
|
|
824
|
+
|
|
825
|
+
@dataclass
class ChatManager:
    """Registry of chat objects, grouped by session and indexed by stream ID."""

    # Not used by any method of this class.
    # NOTE(review): intended contents are not visible from here.
    custom_menu: list[dict[str, str]] = field(default_factory=list)
    # Not used by any method of this class.
    # NOTE(review): intended contents are not visible from here.
    running_messages_poke: dict[str, Any] = field(default_factory=dict)
    # session_id -> chat objects for that session; new objects go to the front.
    running_chat_object: defaultdict[str, list[ChatObject]] = field(
        default_factory=lambda: defaultdict(list)
    )
    # stream_id -> metadata snapshot taken when the chat object was added.
    running_chat_object_id2map: dict[str, ChatObjectMeta] = field(default_factory=dict)

    def clean_obj(self, k: str, maxitems: int = 10):
        """Trim the chat objects stored under ``k`` to roughly ``maxitems``.

        The first ``maxitems`` objects are always kept. Objects beyond that
        are kept only while ``is_done()`` is false (they are moved to the
        front of the list); finished ones are dropped and their metadata is
        removed from ``running_chat_object_id2map``.

        Args:
            k (str): Key into ``running_chat_object`` (the session ID).
            maxitems (int, optional): Number of leading objects that are
                always retained. Defaults to 10.
        """
        current = self.running_chat_object[k]
        if len(current) <= maxitems:
            return

        retained_head = current[:maxitems]
        finished: list = []
        unfinished: list = []
        for candidate in current[maxitems:]:
            (finished if candidate.is_done() else unfinished).append(candidate)

        # Forget the metadata of every finished object that is being dropped.
        for stale in finished:
            self.running_chat_object_id2map.pop(stale.stream_id, None)
        self.running_chat_object[k] = unfinished + retained_head

    def get_all_objs(self) -> list[ChatObjectMeta]:
        """Return the metadata snapshots of all tracked chat objects.

        Returns:
            list[ChatObjectMeta]: A new list holding every value currently in
            ``running_chat_object_id2map``.
        """
        return [*self.running_chat_object_id2map.values()]

    def get_objs(self, session_id: str) -> list[ChatObject]:
        """Return the chat objects stored for a session.

        Because the underlying mapping is a ``defaultdict(list)``, asking for
        an unknown session stores and returns a new empty list.

        Args:
            session_id (str): User session ID.

        Returns:
            list[ChatObject]: The list held for ``session_id`` (not a copy).
        """
        session_objects = self.running_chat_object[session_id]
        return session_objects

    async def clean_chat_objects(self, maxitems: int = 10) -> None:
        """Run ``clean_obj`` for every session key while holding ``LOCK``.

        Args:
            maxitems (int, optional): Passed through to ``clean_obj``.
                Defaults to 10.
        """
        async with LOCK:
            for session_key in self.running_chat_object:
                self.clean_obj(session_key, maxitems)

    async def add_chat_object(self, chat_object: ChatObject) -> None:
        """Register a chat object and trim its session's list.

        While holding ``LOCK``, a snapshot of the object is stored under its
        ``stream_id``, the object is inserted at the front of its session's
        list, and ``clean_obj`` is run for that session.

        Args:
            chat_object (ChatObject): Chat object instance to register.
        """
        async with LOCK:
            snapshot: ChatObjectMeta = chat_object.get_snapshot()
            self.running_chat_object_id2map[chat_object.stream_id] = snapshot
            session_key = chat_object.session_id
            self.running_chat_object[session_key].insert(0, chat_object)
            self.clean_obj(session_key)
|
|
894
|
+
|
|
895
|
+
|
|
896
|
+
# Module-level ChatManager instance, created when this module is imported.
chat_manager = ChatManager()
|