codegnipy 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
codegnipy/memory.py ADDED
@@ -0,0 +1,276 @@
1
+ """
2
+ Codegnipy 记忆存储模块
3
+
4
+ 提供会话记忆的存储、检索和管理功能。
5
+ """
6
+
7
+ import json
8
+ import time
9
+ from abc import ABC, abstractmethod
10
+ from dataclasses import dataclass, field
11
+ from pathlib import Path
12
+ from typing import Optional, List, Dict, Any
13
+ from enum import Enum
14
+
15
+
16
class MessageRole(Enum):
    """Author role attached to every Message; the value is the wire string."""

    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    # Reflection message; Message.to_openai_format sends it with role "system".
    REFLECTION = "reflection"
22
+
23
+
24
@dataclass
class Message:
    """A single conversation message.

    Attributes:
        role: Author role of the message.
        content: Message text.
        timestamp: Creation time from ``time.time()`` (seconds since the epoch).
        metadata: Free-form extra data. The stores in this module write the
            id they assign under the ``"_id"`` key.
    """
    role: MessageRole
    content: str
    timestamp: float = field(default_factory=time.time)
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Serialize to a JSON-compatible dict (the inverse of ``from_dict``).

        ``metadata`` is shallow-copied, so mutating the returned dict does not
        change this message.
        """
        return {
            "role": self.role.value,
            "content": self.content,
            "timestamp": self.timestamp,
            "metadata": dict(self.metadata),
        }

    @classmethod
    def from_dict(cls, data: dict) -> "Message":
        """Build a Message from a dict shaped like the output of ``to_dict``.

        A missing ``"timestamp"`` defaults to the current time. A missing or
        null ``"metadata"`` defaults to an empty dict.

        Raises:
            KeyError: if ``"role"`` or ``"content"`` is missing.
            ValueError: if ``"role"`` is not a valid MessageRole value.
        """
        metadata = data.get("metadata")
        return cls(
            role=MessageRole(data["role"]),
            content=data["content"],
            timestamp=data.get("timestamp", time.time()),
            # Copy, so that a store writing "_id" into the message's metadata
            # does not mutate the caller's input dict.
            metadata=dict(metadata) if metadata is not None else {},
        )

    def to_openai_format(self) -> dict:
        """Convert to an OpenAI API message dict ``{"role", "content"}``.

        REFLECTION messages are emitted with role ``"system"``.
        """
        return {
            "role": self.role.value if self.role != MessageRole.REFLECTION else "system",
            "content": self.content
        }
55
+
56
+
57
class MemoryStore(ABC):
    """Abstract interface for conversation-message storage.

    Concrete stores implement the storage primitives. The ``add_*`` helpers
    and ``to_openai_messages`` are built on ``add`` and ``get_all``.
    """

    @abstractmethod
    def add(self, message: Message) -> str:
        """Store *message* and return the id assigned to it."""
        ...

    @abstractmethod
    def get(self, message_id: str) -> Optional[Message]:
        """Return the message with id *message_id*, or None if absent."""
        ...

    @abstractmethod
    def get_all(self) -> List[Message]:
        """Return every stored message."""
        ...

    @abstractmethod
    def get_recent(self, n: int) -> List[Message]:
        """Return the last *n* stored messages."""
        ...

    @abstractmethod
    def clear(self) -> None:
        """Remove all stored messages."""
        ...

    @abstractmethod
    def count(self) -> int:
        """Return how many messages are stored."""
        ...

    def add_user_message(self, content: str, **metadata) -> str:
        """Store a USER message; keyword arguments become its metadata."""
        user_msg = Message(role=MessageRole.USER, content=content, metadata=metadata)
        return self.add(user_msg)

    def add_assistant_message(self, content: str, **metadata) -> str:
        """Store an ASSISTANT message; keyword arguments become its metadata."""
        assistant_msg = Message(role=MessageRole.ASSISTANT, content=content, metadata=metadata)
        return self.add(assistant_msg)

    def add_reflection(self, content: str, **metadata) -> str:
        """Store a REFLECTION message; keyword arguments become its metadata."""
        reflection_msg = Message(role=MessageRole.REFLECTION, content=content, metadata=metadata)
        return self.add(reflection_msg)

    def to_openai_messages(self, include_reflections: bool = True) -> List[dict]:
        """Return all messages in OpenAI API format, in storage order.

        Args:
            include_reflections: when False, REFLECTION messages are skipped.
        """
        return [
            stored.to_openai_format()
            for stored in self.get_all()
            if include_reflections or stored.role != MessageRole.REFLECTION
        ]
110
+
111
+
112
class InMemoryStore(MemoryStore):
    """Non-persistent MemoryStore backed by a Python list.

    Ids are sequential decimal strings written into ``metadata["_id"]``.
    ``clear()`` empties the list but does not reset the id counter.
    """

    def __init__(self):
        self._messages: List[Message] = []
        self._counter = 0

    def add(self, message: Message) -> str:
        """Assign the next id to *message*, append it, and return the id."""
        self._counter += 1
        new_id = str(self._counter)
        message.metadata["_id"] = new_id
        self._messages.append(message)
        return new_id

    def get(self, message_id: str) -> Optional[Message]:
        """Linear search for the first message whose id equals *message_id*."""
        return next(
            (candidate for candidate in self._messages
             if candidate.metadata.get("_id") == message_id),
            None,
        )

    def get_all(self) -> List[Message]:
        """Return a shallow copy of the stored list."""
        return list(self._messages)

    def get_recent(self, n: int) -> List[Message]:
        """Return the last *n* messages; an empty list when n <= 0."""
        if n <= 0:
            return []
        return self._messages[-n:]

    def clear(self) -> None:
        """Drop all messages in place (the id counter is kept)."""
        del self._messages[:]

    def count(self) -> int:
        """Return the number of stored messages."""
        return len(self._messages)
142
+
143
+
144
class FileStore(MemoryStore):
    """MemoryStore persisted to a JSON file.

    The whole history is rewritten on every ``add`` and ``clear``. The file
    layout is ``{"messages": [<Message.to_dict()>, ...], "counter": <int>}``.
    """

    def __init__(self, filepath: str):
        self.filepath = Path(filepath)
        self._messages: List[Message] = []
        self._counter = 0
        self._load()

    def _load(self) -> None:
        """Load the history from ``self.filepath`` if the file exists.

        This is best effort: malformed content (invalid JSON or encoding, an
        unknown role, wrong JSON types, missing keys) yields an empty history.
        Errors from opening the file (OSError) still propagate.
        """
        if not self.filepath.exists():
            return
        try:
            with open(self.filepath, "r", encoding="utf-8") as f:
                data = json.load(f)
            messages = [Message.from_dict(m) for m in data.get("messages", [])]
            counter = int(data.get("counter", 0))
            # Never hand out an id that already exists in the file, even if
            # "counter" is missing or stale.
            for msg in messages:
                existing_id = msg.metadata.get("_id")
                if isinstance(existing_id, str) and existing_id.isdecimal():
                    counter = max(counter, int(existing_id))
        except (ValueError, KeyError, TypeError, AttributeError):
            # ValueError also covers json.JSONDecodeError, UnicodeDecodeError
            # and unknown MessageRole values. TypeError/AttributeError come
            # from a top-level value or entry that is not a JSON object.
            self._messages = []
            self._counter = 0
            return
        self._messages = messages
        self._counter = counter

    def _save(self) -> None:
        """Write the full history to ``self.filepath``.

        The data goes to a sibling ``.tmp`` file first, which then replaces the
        target. If serialization fails (e.g. non-JSON-serializable metadata),
        the previously saved file is left intact instead of truncated.
        """
        self.filepath.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self.filepath.with_name(self.filepath.name + ".tmp")
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump({
                    "messages": [m.to_dict() for m in self._messages],
                    "counter": self._counter
                }, f, ensure_ascii=False, indent=2)
            tmp_path.replace(self.filepath)
        except BaseException:
            # Clean up the partial temp file, then let the error propagate.
            try:
                tmp_path.unlink()
            except OSError:
                pass
            raise

    def add(self, message: Message) -> str:
        """Assign the next id to *message*, append it, persist, and return the id.

        If persisting fails, the in-memory append is undone before the error
        propagates. This keeps memory consistent with the file and stops one
        unsaveable message from making every later save fail.
        """
        self._counter += 1
        message.metadata["_id"] = str(self._counter)
        self._messages.append(message)
        try:
            self._save()
        except BaseException:
            self._messages.pop()
            self._counter -= 1
            raise
        return message.metadata["_id"]

    def get(self, message_id: str) -> Optional[Message]:
        """Return the message whose ``metadata["_id"]`` equals *message_id*, else None."""
        for msg in self._messages:
            if msg.metadata.get("_id") == message_id:
                return msg
        return None

    def get_all(self) -> List[Message]:
        """Return a shallow copy of all stored messages."""
        return self._messages.copy()

    def get_recent(self, n: int) -> List[Message]:
        """Return the last *n* messages; an empty list when n <= 0."""
        return self._messages[-n:] if n > 0 else []

    def clear(self) -> None:
        """Remove all messages, reset the id counter, and persist the empty state."""
        self._messages.clear()
        self._counter = 0
        self._save()

    def count(self) -> int:
        """Return the number of stored messages."""
        return len(self._messages)
200
+
201
+
202
class ContextCompressor:
    """Shrink a message history by summarizing its older part.

    When the estimated token count exceeds ``max_tokens``, the oldest
    ``compression_ratio`` fraction of messages is replaced by one SYSTEM
    summary message; the remaining recent messages are kept verbatim.
    """

    def __init__(self, max_tokens: int = 4000, compression_ratio: float = 0.5):
        """
        Args:
            max_tokens: Approximate token budget (see ``estimate_tokens``).
            compression_ratio: Fraction of messages, oldest first, to fold
                into the summary. Clamped to [0, 1] when compressing.
        """
        self.max_tokens = max_tokens
        self.compression_ratio = compression_ratio

    def estimate_tokens(self, text: str) -> int:
        """Roughly estimate the token count of *text* (about 4 characters per token)."""
        return len(text) // 4

    def needs_compression(self, messages: List[Message]) -> bool:
        """Return True if the summed token estimate of all contents exceeds ``max_tokens``."""
        total = sum(self.estimate_tokens(m.content) for m in messages)
        return total > self.max_tokens

    def compress(self, messages: List[Message], summarizer=None) -> List[Message]:
        """
        Compress a message history.

        Args:
            messages: Messages in chronological order.
            summarizer: Optional callable that takes the list of old messages
                and returns a summary string. Defaults to ``_simple_summarize``.

        Returns:
            ``messages`` itself if no compression is needed. Otherwise a new
            list: a SYSTEM summary message (metadata ``compressed=True`` and
            ``original_count``) followed by the kept recent messages.
        """
        if not self.needs_compression(messages):
            return messages

        # Keep the most recent messages verbatim.
        ratio = min(max(self.compression_ratio, 0.0), 1.0)
        keep_recent = int(len(messages) * (1 - ratio))
        # Split on an explicit index. The previous messages[-keep_recent:]
        # form turned into messages[-0:] (the whole list) when keep_recent
        # was 0, so nothing was ever compressed in that case.
        split = len(messages) - keep_recent
        old_messages = messages[:split]
        recent_messages = messages[split:]

        if not old_messages:
            return recent_messages

        # Build the summary text.
        if summarizer:
            summary = summarizer(old_messages)
        else:
            summary = self._simple_summarize(old_messages)

        summary_msg = Message(
            role=MessageRole.SYSTEM,
            content=f"[历史摘要] {summary}",
            metadata={"compressed": True, "original_count": len(old_messages)}
        )

        return [summary_msg] + recent_messages

    def _simple_summarize(self, messages: List[Message]) -> str:
        """Build a fallback summary from USER/ASSISTANT messages.

        Consecutive USER/ASSISTANT messages are grouped in pairs, each shown
        as ``role: <first 100 chars>``. Only the first five pairs are included.
        Other roles are skipped but still counted in the total.
        """
        interactions = []
        current_turn = []

        for msg in messages:
            if msg.role in (MessageRole.USER, MessageRole.ASSISTANT):
                text = msg.content
                # Mark truncation only when the content was actually cut.
                preview = text if len(text) <= 100 else text[:100] + "..."
                current_turn.append(f"{msg.role.value}: {preview}")
                if len(current_turn) == 2:
                    interactions.append(" | ".join(current_turn))
                    current_turn = []

        if current_turn:
            interactions.append(" | ".join(current_turn))

        return f"之前的 {len(messages)} 条消息已压缩。主要交互: {'; '.join(interactions[:5])}"