amrita_core 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,896 @@
+ import asyncio
+ import copy
+ from abc import ABC, abstractmethod
+ from asyncio import Lock, Task
+ from collections import defaultdict
+ from collections.abc import AsyncGenerator
+ from dataclasses import dataclass, field
+ from datetime import datetime
+ from io import BytesIO, StringIO
+ from typing import Any, TypedDict
+ from uuid import uuid4
+
+ import pytz
+ from pydantic import BaseModel, Field
+ from pytz import utc
+ from typing_extensions import Self
+
+ from .config import AmritaConfig, get_config
+ from .hook.event import CompletionEvent, PreCompletionEvent
+ from .hook.matcher import MatcherManager
+ from .libchat import call_completion, get_last_response, get_tokens, text_generator
+ from .logging import logger
+ from .tokenizer import hybrid_token_count
+ from .types import (
+     CONTENT_LIST_TYPE,
+     CT_MAP,
+     USER_INPUT,
+     Content,
+     ImageContent,
+     Message,
+     SendMessageWrap,
+     TextContent,
+     ToolResult,
+     UniResponse,
+     UniResponseUsage,
+ )
+ from .types import (
+     MemoryModel as Memory,
+ )
+
+ # Global lock for thread-safe operations in the chat manager
+ LOCK = Lock()
+
+
+ def get_current_datetime_timestamp(utc_time: None | datetime = None):
+     """Get the current time formatted as a date, weekday and time string in the Asia/Shanghai timezone.
+
+     Args:
+         utc_time: Optional datetime object in UTC. If None, the current time is used.
+
+     Returns:
+         Formatted timestamp string in the format "[YYYY-MM-DD Weekday HH:MM:SS]"
+     """
+     utc_time = utc_time or datetime.now(pytz.utc)
+     asia_shanghai = pytz.timezone("Asia/Shanghai")
+     now = utc_time.astimezone(asia_shanghai)
+     formatted_date = now.strftime("%Y-%m-%d")
+     formatted_weekday = now.strftime("%A")
+     formatted_time = now.strftime("%H:%M:%S")
+     return f"[{formatted_date} {formatted_weekday} {formatted_time}]"
+
+
+ class MessageContent(ABC):
+     """Abstract base class for different types of message content
+
+     This allows for various types of content to be yielded by the chat manager,
+     not just strings. Subclasses should implement their own representation.
+     """
+
+     def __init__(self, content_type: str):
+         self.type = content_type
+
+     @abstractmethod
+     def get_content(self):
+         """Return the actual content of the message"""
+         raise NotImplementedError("Subclasses must implement get_content method")
+
+
+ class StringMessageContent(MessageContent):
+     """String type message content implementation"""
+
+     def __init__(self, text: str):
+         super().__init__("string")
+         self.text = text
+
+     def get_content(self):
+         return self.text
+
+
+ class RawMessageContent(MessageContent):
+     """Raw message content implementation"""
+
+     def __init__(self, raw_data: Any):
+         super().__init__("raw")
+         self.raw_data = raw_data
+
+     def get_content(self):
+         return self.raw_data
+
+
+ class MessageMetadata(TypedDict):
+     content: str
+     metadata: dict[str, Any]
+
+
+ class MessageWithMetadata(MessageContent):
+     """Message with additional metadata"""
+
+     def __init__(self, content: Any, metadata: dict):
+         super().__init__("metadata")
+         self.content = content
+         self.metadata = metadata
+
+     def get_content(self) -> Any:
+         return self.content
+
+     def get_metadata(self) -> dict:
+         return self.metadata
+
+     def get_full_content(self) -> MessageMetadata:
+         return MessageMetadata(content=self.content, metadata=self.metadata)
+
+
+ class ImageMessage(MessageContent):
+     """Image message"""
+
+     def __init__(self, image: str | BytesIO):
+         super().__init__("image")
+         self.image: str | BytesIO = image
+
+     def get_content(self) -> BytesIO | str:
+         return self.image
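+
+ # Examples (illustrative): the content variants a consumer may receive, e.g.
+ #     StringMessageContent("hello").get_content()            -> "hello"
+ #     MessageWithMetadata("hi", {"lang": "en"}).get_full_content()
+ #     ImageMessage(BytesIO(b"...")).get_content()            -> BytesIO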
+
+
+ class ChatObjectMeta(BaseModel):
+     """Metadata model for chat object
+
+     Used to store identification, events, and time information for the chat object.
+     """
+
+     stream_id: str  # Chat stream ID
+     session_id: str  # Session ID
+     user_input: list[TextContent | ImageContent] | str
+     time: datetime = Field(default_factory=datetime.now)  # Creation time
+     last_call: datetime = Field(default_factory=datetime.now)  # Last call time
+
+
+ class MemoryLimiter:
+     """Context processor
+
+     This class is responsible for handling context memory length and token count limits,
+     ensuring the chat context remains within predefined constraints by summarizing
+     context and removing messages to avoid exceeding the model processing capacity.
+     """
+
+     config: AmritaConfig  # Configuration object
+     usage: UniResponseUsage | None = None  # Token usage, initially None
+     _train: dict[str, str]  # Training data (system prompts)
+     _dropped_messages: list[Message[str] | ToolResult]  # List of removed messages
+     _copied_messages: Memory  # Original message copies (for rollback on exceptions)
+     _abstract_instruction = """<<SYS>>
+ You are a professional context summarizer, strictly following user instructions to perform summarization tasks.
+ <</SYS>>
+
+ <<INSTRUCTIONS>>
+ 1. Directly summarize the user-provided content
+ 2. Maintain core information and key details from the original
+ 3. Do not generate any additional content, explanations, or comments
+ 4. Summaries should be concise, accurate, complete
+ <</INSTRUCTIONS>>
+
+ <<RULE>>
+ - Only summarize the text provided by the user
+ - Do not add any explanations, comments, or supplementary information
+ - Do not alter the main meaning of the original
+ - Maintain an objective and neutral tone
+ <</RULE>>
+
+ <<FORMATTING>>
+ User input → Direct summary output
+ <</FORMATTING>>"""
+
+     def __init__(self, memory: Memory, train: dict[str, str]) -> None:
+         """Initialize the context processor
+
+         Args:
+             memory: Memory model to process
+             train: Training data (system prompts)
+         """
+         self.memory = memory
+         self.config = get_config()
+         self._train = train
+
+     async def __aenter__(self) -> Self:
+         """Async context manager entry, initialize processing state
+
+         Returns:
+             The limiter instance, ready for use
+         """
+         self._dropped_messages = []
+         self._copied_messages = copy.deepcopy(self.memory)
+         logger.debug(
+             f"MemoryLimiter initialized, message count: {len(self.memory.messages)}"
+         )
+         return self
+
+     async def _make_abstract(self):
+         """Generate a context summary
+
+         Calls the LLM to condense the dropped messages into a brief summary,
+         reducing context length while preserving key information.
+         """
+         logger.debug("Starting context summarization...")
+         proportion = self.config.llm.memory_abstract_proportion  # Summary proportion
+         dropped_part: CONTENT_LIST_TYPE = copy.deepcopy(self._dropped_messages)
+         index = int(len(self.memory.messages) * proportion) - len(dropped_part)
+         if index < 0:
+             index = 0
+         idx: int | None = None
+         if index:
+             for idx, element in enumerate(self.memory.messages):
+                 dropped_part.append(element)
+                 if (
+                     getattr(element, "tool_calls", None) is not None
+                 ):  # Remove along with tool calls (system(tool_call), tool_call)
+                     continue
+                 elif idx >= index:
+                     break
+             self.memory.messages = self.memory.messages[
+                 (idx + 1 if idx is not None else index):  # Drop the summarized prefix
+             ]
+         if dropped_part:
+             msg_list: CONTENT_LIST_TYPE = [
+                 Message[str](role="system", content=self._abstract_instruction),
+                 Message[str](
+                     role="user",
+                     content=(
+                         "Message list:\n```text\n"
+                         + "".join(
+                             f"{it}\n"
+                             for it in text_generator(
+                                 dropped_part,
+                                 split_role=True,
+                             )
+                         )
+                         + "\n```"
+                     ),
+                 ),
+             ]
+             logger.debug("Performing context summarization...")
+             response = await get_last_response(call_completion(msg_list))
+             usage = await get_tokens(msg_list, response)
+             self.usage = usage
+             logger.debug(f"Context summary received: {response.content}")
+             self.memory.abstract = response.content
+             logger.debug("Context summarization completed")
+         else:
+             logger.debug("Context summarization skipped")
+
+     def _drop_message(self):
+         """Remove the oldest message from memory and add it to the dropped messages list.
+
+         This method removes the first message from the memory and adds it to the
+         dropped messages list. If the next message is a tool message, it is also
+         removed and added to the dropped messages list.
+         """
+         data = self.memory
+         if len(data.messages) < 2:
+             return
+         self._dropped_messages.append(data.messages.pop(0))
+         if data.messages[0].role == "tool":
+             self._dropped_messages.append(data.messages.pop(0))
+
+     async def run_enforce(self):
+         """Execute memory limitation processing
+
+         Executes the memory length limitation and token count limitation in sequence,
+         ensuring the chat context stays within predefined ranges.
+         This method must be used within an async context manager.
+
+         Raises:
+             RuntimeError: Raised when not used within an async context manager
+         """
+         logger.debug("Starting memory limitation processing...")
+         if not hasattr(self, "_dropped_messages") or not hasattr(
+             self, "_copied_messages"
+         ):
+             raise RuntimeError(
+                 "MemoryLimiter is not initialized, please use `async with MemoryLimiter(memory, train)` before calling."
+             )
+         await self._limit_length()
+         await self._limit_tokens()
+         if self.config.llm.enable_memory_abstract and self._dropped_messages:
+             await self._make_abstract()
+         logger.debug("Memory limitation processing completed")
+
+     async def _limit_length(self):
+         """Control memory length: remove old messages beyond the limit and strip unsupported content."""
+         logger.debug("Starting memory length limitation...")
+         is_multimodal = get_config().llm.enable_multi_modal
+         data: Memory = self.memory
+
+         # Flatten multimodal messages to plain text when multimodal input is disabled
+         for message in data.messages:
+             if (
+                 isinstance(message.content, list)
+                 and not is_multimodal
+                 and message.role == "user"
+             ):
+                 message_text = ""
+                 for content_part in message.content:
+                     if isinstance(content_part, dict):
+                         validator = CT_MAP.get(content_part["type"])
+                         if not validator:
+                             raise ValueError(
+                                 f"Invalid content type: {content_part['type']}"
+                             )
+                         content_part: Content = validator.model_validate(content_part)
+                     if content_part["type"] == "text":
+                         message_text += content_part["text"]
+                 message.content = message_text
+
+         # Enforce memory length limit
+         initial_count = len(data.messages)
+         while len(data.messages) >= 2:
+             if data.messages[0].role == "tool":
+                 data.messages.pop(0)
+             elif len(data.messages) > self.config.llm.memory_lenth_limit:
+                 self._drop_message()
+             else:
+                 break
+         final_count = len(data.messages)
+         logger.debug(
+             f"Memory length limitation completed, removed {initial_count - final_count} messages"
+         )
+
+     async def _limit_tokens(self):
+         """Control token count: remove old messages that exceed the limit
+
+         Calculates the token count of the current message list; when it exceeds the
+         configured session token window, the earliest messages are removed one by one
+         until the limit is satisfied.
+         """
+
+         def get_token(memory: CONTENT_LIST_TYPE) -> int:
+             """Calculate the total token count for a given message list
+
+             Args:
+                 memory: List of messages to calculate the token count for
+
+             Returns:
+                 Total token count for the messages
+             """
+             tk_tmp: int = 0
+             for msg in text_generator(memory):
+                 tk_tmp += hybrid_token_count(
+                     msg,
+                     self.config.llm.tokens_count_mode,
+                 )
+             return tk_tmp
+
+         train = self._train
+         train_model = Message.model_validate(train)
+         data = self.memory
+         logger.debug("Starting token count limitation...")
+         memory_l: CONTENT_LIST_TYPE = [train_model, *data.messages]
+         if not self.config.llm.enable_tokens_limit:
+             logger.debug("Token limitation disabled, skipping processing")
+             return
+         prompt_length = hybrid_token_count(train["content"])
+         if prompt_length > self.config.llm.session_tokens_windows:
+             logger.warning(
+                 f"Prompt size too large: {prompt_length} > {self.config.llm.session_tokens_windows}. Please adjust the prompt or settings!"
+             )
+             return
+         tk_tmp: int = get_token(memory_l)
+
+         initial_count = len(data.messages)
+         while tk_tmp > self.config.llm.session_tokens_windows:
+             if len(data.messages) >= 2:
+                 self._drop_message()
+             else:
+                 break
+
+             memory_l = [train_model, *data.messages]
+             tk_tmp = get_token(memory_l)
+             await asyncio.sleep(
+                 0
+             )  # Token counting is CPU-intensive; yield control to the event loop
+         final_count = len(data.messages)
+         logger.debug(
+             f"Token count limitation completed, removed {initial_count - final_count} messages"
+         )
+         logger.debug(f"Final token count: {tk_tmp}")
+
+     async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
+         """Async context manager exit, handle rollback in case of exceptions
+
+         In case of exceptions, restore messages to the state before processing,
+         ensuring data consistency.
+
+         Args:
+             exc_type: Exception type
+             exc_val: Exception value
+             exc_tb: Exception traceback
+         """
+         if exc_type is not None:
+             logger.warning("An exception occurred, rolling back messages...")
+             self.memory.messages = self._copied_messages.messages
+         return
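+
+     # Usage sketch (illustrative, not part of the package): the limiter is an
+     # async context manager, and run_enforce() must be called inside it, e.g.
+     #
+     #     async with MemoryLimiter(memory, train) as lim:
+     #         await lim.run_enforce()
+     #     # lim.memory now holds the trimmed (and possibly summarized) context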
+
+
+ class ChatObject:
+     """Chat processing object
+
+     This class is responsible for processing a single chat session, including message
+     receiving, context management, model calling, and response sending.
+     """
+
+     stream_id: str  # Chat object ID
+     timestamp: str  # Timestamp (for the LLM)
+     time: datetime  # Creation time
+     end_at: datetime | None = None
+     data: Memory  # (lateinit) Memory file
+     user_input: USER_INPUT
+     user_message: Message[USER_INPUT]  # (lateinit) User message
+     context_wrap: SendMessageWrap
+     train: dict[str, str]  # System message
+     last_call: datetime  # Last internal function call time
+     session_id: str
+     response: UniResponse[str, None]
+     response_queue: asyncio.Queue[Any]
+     overflow_queue: asyncio.Queue[Any]
+     _is_running: bool = False  # Whether it is running
+     _is_done: bool = False  # Whether it has completed
+     _task: Task[None]
+     _has_task: bool = False
+     _err: BaseException | None = None
+     _wait: bool = True
+     _queue_done: bool = False
+     __done_marker = object()
+
+     def __init__(
+         self,
+         train: dict[str, str],
+         user_input: USER_INPUT,
+         context: Memory,
+         session_id: str,
+         run_blocking: bool = True,
+         queue_size: int = 25,
+         overflow_queue_size: int = 45,
+     ) -> None:
+         """Initialize the chat object
+
+         Args:
+             train: Training data (system prompts)
+             user_input: Input from the user
+             context: Memory context for the session
+             session_id: Unique identifier for the session
+             run_blocking: Whether to run in blocking mode
+             queue_size: Capacity of the primary response queue
+             overflow_queue_size: Capacity of the overflow response queue
+         """
+         self.train = train
+         self.data = context
+         self.session_id = session_id
+         self.user_input = user_input
+         self.user_message = Message(role="user", content=user_input)
+         self.timestamp = get_current_datetime_timestamp()
+         self.time = datetime.now(utc)
+         self.config: AmritaConfig = get_config()
+         self.last_call = datetime.now(utc)
+         self._wait = run_blocking
+
+         # Initialize async queues for streaming responses
+         self.response_queue = asyncio.Queue(queue_size)
+         self.overflow_queue = asyncio.Queue(overflow_queue_size)
+         self.stream_id = uuid4().hex
+
+     def get_exception(self) -> BaseException | None:
+         """
+         Get the exception that occurred during task execution
+
+         Returns:
+             BaseException | None: The exception raised during task execution, or None
+         """
+         return self._err
+
+     def call(self):
+         """
+         Invoke the chat object
+
+         Returns:
+             The coroutine returned by __call__()
+         """
+         return self.__call__()
+
+     def is_running(self) -> bool:
+         """
+         Check if the task is running
+
+         Returns:
+             bool: True if the task is running, otherwise False
+         """
+         return self._is_running
+
+     def is_done(self) -> bool:
+         """
+         Check if the task has completed
+
+         Returns:
+             bool: True if the task has completed, otherwise False
+         """
+         return self._is_done
+
+     def terminate(self) -> None:
+         """
+         Terminate task execution
+
+         Sets the task status to completed and cancels the internal task.
+         """
+         self._is_done = True
+         self._is_running = False
+         if self._has_task:  # The task only exists after the first call
+             self._task.cancel()
+
+     def __await__(self):
+         if not hasattr(self, "_task"):
+             raise RuntimeError("ChatObject not running")
+         return self._task.__await__()
+
+     async def __call__(self) -> None:
+         """Run the chat object, scheduling the processing task on first invocation
+
+         The first call creates the internal task; the task itself re-enters this
+         method to execute the actual processing flow.
+         """
+         if not self._has_task:
+             logger.debug("Starting chat object task...")
+             self._has_task = True
+             self._task = asyncio.create_task(self.__call__())
+             return await self._task if self._wait else None
+         if not self._is_running and not self._is_done:
+             self.stream_id = uuid4().hex
+             logger.debug(f"Starting chat processing, stream ID: {self.stream_id}")
+
+             try:
+                 self._is_running = True
+                 self.last_call = datetime.now(utc)
+                 await chat_manager.add_chat_object(self)
+
+                 await self._run()
+             finally:
+                 self._is_running = False
+                 self._is_done = True
+                 self.end_at = datetime.now(utc)
+                 chat_manager.running_chat_object_id2map.pop(self.stream_id, None)
+                 logger.debug("Chat event processing completed")
+
+         else:
+             raise RuntimeError(
+                 f"ChatObject of {self.stream_id} is already running or done"
+             )
+
+     async def _run(self):
+         """Run the chat processing flow
+
+         Executes the main logic of message processing: appending the user message,
+         managing context length and token limitations, calling the model, and
+         queueing the response.
+         """
+         logger.debug("Starting chat processing flow...")
+         data = self.data
+         config = self.config
+
+         data.messages.append(self.user_message)
+
+         logger.debug(
+             f"Added user message to memory, current message count: {len(data.messages)}"
+         )
+
+         self.train["content"] = (
+             "<SCHEMA>\n"
+             + (
+                 f"<HIDDEN>{config.cookie.cookie}</HIDDEN>\n"
+                 if config.cookie.enable_cookie
+                 else ""
+             )
+             + "Please participate in the discussion in your own character identity. Try not to use similar phrases when responding to different topics. User's messages are contained within user inputs."
+             + " Your character setting is in the <SYSTEM_INSTRUCTIONS> tags, and the summary of previous conversations is in the <SUMMARY> tags."
+             + "\n</SCHEMA>\n"
+             + "<SYSTEM_INSTRUCTIONS>\n"
+             + self.train["content"]  # Embed the original system prompt as the character setting
+             + "\n</SYSTEM_INSTRUCTIONS>"
+             + f"\n<SUMMARY>\n{data.abstract if config.llm.enable_memory_abstract else ''}\n</SUMMARY>"
+         )
+         logger.debug(self.train["content"])
+
+         logger.debug("Starting applying memory limitations...")
+         async with MemoryLimiter(self.data, self.train) as lim:
+             await lim.run_enforce()
+             abs_usage = lim.usage
+             self.data = lim.memory
+         logger.debug("Memory limitation application completed")
+
+         send_messages = self._prepare_send_messages()
+         self.context_wrap = SendMessageWrap.validate_messages(send_messages)
+         logger.debug(
+             f"Preparing sending messages completed, message count: {len(send_messages)}"
+         )
+         response: UniResponse[str, None] = await self._process_chat(send_messages)
+         if response.usage and abs_usage:
+             response.usage.completion_tokens += abs_usage.completion_tokens
+             response.usage.prompt_tokens += abs_usage.prompt_tokens
+             response.usage.total_tokens += abs_usage.total_tokens
+
+         logger.debug("Chat processing completed, preparing to send response")
+         await self.set_queue_done()
+
+     def queue_closed(self) -> bool:
+         """Check if the response queue is closed
+
+         Returns:
+             True if the queue is closed, False otherwise
+         """
+         return self._queue_done
+
+     async def set_queue_done(self) -> None:
+         """Mark the response queue as done by putting the done marker"""
+         if not self.queue_closed():
+             self._queue_done = True  # Close the queue so further yields are rejected
+             await self._put_to_queue(self.__done_marker)
+
+     async def _put_to_queue(self, item):
+         """Put an item on the queue, falling back to the overflow queue if the primary queue is full
+
+         Args:
+             item: Item to put in the queue
+         """
+         try:
+             self.response_queue.put_nowait(item)
+         except asyncio.QueueFull:
+             try:
+                 self.overflow_queue.put_nowait(item)
+             except asyncio.QueueFull:
+                 timeout = 5
+                 while timeout > 0:
+                     await asyncio.sleep(1)
+                     timeout -= 1
+                     try:
+                         self.response_queue.put_nowait(item)
+                         return
+                     except asyncio.QueueFull:
+                         try:
+                             self.overflow_queue.put_nowait(item)
+                             return
+                         except asyncio.QueueFull:
+                             continue
+
+                 # After waiting, if both queues are still full, raise an exception
+                 raise RuntimeError(
+                     "Both primary and overflow queues are full after waiting"
+                 )
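+
+     # Note (descriptive): the bounded primary queue provides backpressure for
+     # consumers; the overflow queue absorbs short bursts, and after ~5 seconds
+     # of retries a RuntimeError signals that nothing is draining the stream.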
+
+     async def yield_response(self, response: str | MessageContent):
+         """Send a chat model response chunk to the queue; both str and MessageContent are accepted.
+
+         Args:
+             response: Either a string or a MessageContent object to be sent to the queue
+         """
+         if not self.queue_closed():
+             await self._put_to_queue(response)
+         else:
+             raise RuntimeError("Queue is closed.")
+
+     async def yield_response_iteration(
+         self, iterator: AsyncGenerator[str | MessageContent, None]
+     ):
+         """Forward every chunk from an async iterator to the response queue.
+
+         Args:
+             iterator: An async generator that yields either strings or MessageContent objects
+         """
+         async for chunk in iterator:
+             await self.yield_response(chunk)
+
+     def get_response_generator(self) -> AsyncGenerator[str | MessageContent, None]:
+         """Return an async generator that iterates over responses from the queue.
+
+         Yields:
+             Either a string or a MessageContent object from the response queue
+         """
+         return self._response_generator()
+
+     async def full_response(self) -> str:
+         """Return the full response from the queue as a single string.
+
+         Returns:
+             Complete response string combining all chunks in the queue
+         """
+         builder = StringIO()
+         async for item in self.get_response_generator():
+             if isinstance(item, str):
+                 builder.write(item)
+             elif isinstance(item, MessageContent):
+                 builder.write(str(item.get_content()))
+         return builder.getvalue()
+
+     async def _response_generator(self) -> AsyncGenerator[str | MessageContent, None]:
+         """Asynchronously yield items from the queues until the done marker is reached.
+
+         Yields:
+             Items from the response queues until the done marker is encountered
+         """
+         while True:
+             # Drain the primary queue first
+             while not self.response_queue.empty():
+                 item = await self.response_queue.get()
+                 self.response_queue.task_done()
+                 if item is self.__done_marker:
+                     return
+                 yield item
+
+             # If the primary queue is empty, check the overflow queue
+             if not self.overflow_queue.empty():
+                 item = await self.overflow_queue.get()
+                 self.overflow_queue.task_done()
+                 if item is self.__done_marker:
+                     return
+                 yield item
+             else:
+                 if (
+                     self.queue_closed()
+                     and self.response_queue.empty()
+                     and self.overflow_queue.empty()
+                 ):
+                     break
+                 # Otherwise, wait a bit before checking again
+                 await asyncio.sleep(0.01)
+
+     async def _process_chat(
+         self,
+         send_messages: CONTENT_LIST_TYPE,
+     ) -> UniResponse[str, None]:
+         """Call the chat model to generate a response and trigger the related events.
+
+         Args:
+             send_messages: Message list to send
+
+         Returns:
+             Model response
+         """
+         self.last_call = datetime.now(utc)
+
+         data = self.data
+         logger.debug(
+             f"Starting chat processing, sending message count: {len(send_messages)}"
+         )
+
+         logger.debug("Triggering matcher functions...")
+         messages = self.context_wrap
+         chat_event = PreCompletionEvent(
+             chat_object=self,
+             user_input=self.user_input,
+             original_context=messages,  # Use the full message list, including the system message
+         )
+         await MatcherManager.trigger_event(chat_event)
+         self.data.messages = chat_event.get_context_messages().unwrap(
+             exclude_system=True
+         )
+         send_messages = chat_event.get_context_messages().unwrap()
+
+         logger.debug("Calling chat model...")
+         response: UniResponse[str, None] | None = None
+         async for chunk in call_completion(send_messages):
+             if isinstance(chunk, str):
+                 await self.yield_response(StringMessageContent(chunk))
+             elif isinstance(chunk, UniResponse):
+                 response = chunk
+             elif isinstance(chunk, MessageContent):
+                 await self.yield_response(chunk)
+         if response is None:
+             raise RuntimeError("No final response from chat adapter.")
+         self.response = response
+         logger.debug("Triggering chat events...")
+         chat_event = CompletionEvent(self.user_input, messages, self, response.content)
+         await MatcherManager.trigger_event(chat_event)
+         response.content = chat_event.model_response
+         data.messages.append(
+             Message[str](
+                 content=response.content,
+                 role="assistant",
+             )
+         )
+         logger.debug(
+             f"Added assistant response to memory, current message count: {len(data.messages)}"
+         )
+
+         logger.debug("Chat processing completed")
+         return response
+
+     def _prepare_send_messages(
+         self,
+     ) -> list:
+         """Prepare the message list to send to the chat model, including the system prompt and context.
+
+         Returns:
+             Prepared message list to send
+         """
+         self.last_call = datetime.now(utc)
+         logger.debug("Preparing messages to send...")
+         train: Message[str] = Message[str].model_validate(self.train)
+         data = self.data
+         messages = [train, *copy.deepcopy(data.messages)]
+         logger.debug(f"Messages preparation completed, total {len(messages)} messages")
+         return messages
+
+     def get_snapshot(self) -> ChatObjectMeta:
+         """Get a snapshot of the chat object
+
+         Returns:
+             Chat object metadata
+         """
+         return ChatObjectMeta.model_validate(self, from_attributes=True)
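+
+     # Usage sketch (illustrative, not part of the package): stream a reply by
+     # running the object in non-blocking mode and draining the queue, e.g.
+     #
+     #     chat = ChatObject(train, user_input, context, session_id, run_blocking=False)
+     #     await chat.call()  # schedules the background task and returns immediately
+     #     async for chunk in chat.get_response_generator():
+     #         ...            # str / MessageContent chunks as the model streams
+     #     # or: text = await chat.full_response()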
+
+
+ @dataclass
+ class ChatManager:
+     custom_menu: list[dict[str, str]] = field(default_factory=list)
+     running_messages_poke: dict[str, Any] = field(default_factory=dict)
+     running_chat_object: defaultdict[str, list[ChatObject]] = field(
+         default_factory=lambda: defaultdict(list)
+     )
+     running_chat_object_id2map: dict[str, ChatObjectMeta] = field(default_factory=dict)
+
+     def clean_obj(self, k: str, maxitems: int = 10):
+         """
+         Clean up the running chat objects under the specified key, keeping the first
+         `maxitems` objects plus any excess objects that are still unfinished
+
+         Args:
+             k (str): Session key
+             maxitems (int, optional): Maximum number of objects. Defaults to 10.
+         """
+         objs = self.running_chat_object[k]
+         if len(objs) > maxitems:
+             dropped_obj = objs[maxitems:]
+             objs = [obj for obj in dropped_obj if not obj.is_done()] + objs[:maxitems]
+             dropped_obj = [obj for obj in dropped_obj if obj.is_done()]
+             for obj in dropped_obj:
+                 self.running_chat_object_id2map.pop(obj.stream_id, None)
+             self.running_chat_object[k] = objs
+
+     def get_all_objs(self) -> list[ChatObjectMeta]:
+         """
+         Get the metadata of all running chat objects
+
+         Returns:
+             list[ChatObjectMeta]: List of all running chat object metadata
+         """
+         return list(self.running_chat_object_id2map.values())
+
+     def get_objs(self, session_id: str) -> list[ChatObject]:
+         """
+         Get the list of chat objects for the given session ID
+
+         Args:
+             session_id (str): User session ID
+
+         Returns:
+             list[ChatObject]: List of chat objects
+         """
+         return self.running_chat_object[session_id]
+
+     async def clean_chat_objects(self, maxitems: int = 10) -> None:
+         """
+         Asynchronously clean up all running chat objects, limiting each key to at most
+         `maxitems` objects
+         """
+         async with LOCK:
+             for key in self.running_chat_object.keys():
+                 self.clean_obj(key, maxitems)
+
+     async def add_chat_object(self, chat_object: ChatObject) -> None:
+         """
+         Add a new chat object to the running list
+
+         Args:
+             chat_object (ChatObject): Chat object instance
+         """
+         async with LOCK:
+             meta: ChatObjectMeta = chat_object.get_snapshot()
+             self.running_chat_object_id2map[chat_object.stream_id] = meta
+             key = chat_object.session_id
+             self.running_chat_object[key].insert(0, chat_object)
+             self.clean_obj(key)
+
+
+ chat_manager = ChatManager()
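+
+ # Module-level singleton: ChatObject instances register themselves here via
+ # add_chat_object() as they start running, and stale entries are pruned by
+ # clean_obj() / clean_chat_objects().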