mcp-memory-server 0.2.0 (mcp_memory_server-0.2.0-py3-none-any.whl)
This diff shows the content of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between versions as they appear in their respective public registries.
- aimemory/__init__.py +3 -0
- aimemory/config.py +267 -0
- aimemory/dataset/__init__.py +16 -0
- aimemory/dataset/builder.py +270 -0
- aimemory/dataset/splitter.py +191 -0
- aimemory/dataset/stats.py +300 -0
- aimemory/i18n/__init__.py +95 -0
- aimemory/i18n/en.py +157 -0
- aimemory/i18n/ko.py +179 -0
- aimemory/mcp/__init__.py +5 -0
- aimemory/mcp/__main__.py +5 -0
- aimemory/mcp/bridge.py +487 -0
- aimemory/mcp/server.py +314 -0
- aimemory/memory/__init__.py +50 -0
- aimemory/memory/composer.py +157 -0
- aimemory/memory/consolidation.py +197 -0
- aimemory/memory/forgetting.py +178 -0
- aimemory/memory/graph_retriever.py +143 -0
- aimemory/memory/graph_store.py +492 -0
- aimemory/memory/knowledge_graph.py +184 -0
- aimemory/memory/resolution.py +140 -0
- aimemory/memory/sleep_cycle.py +289 -0
- aimemory/online/__init__.py +35 -0
- aimemory/online/ab_comparator.py +182 -0
- aimemory/online/autonomy.py +115 -0
- aimemory/online/enhanced_encoder.py +53 -0
- aimemory/online/enhanced_policy.py +164 -0
- aimemory/online/gossip.py +280 -0
- aimemory/online/policy.py +438 -0
- aimemory/online/replay_buffer.py +53 -0
- aimemory/online/reranker.py +570 -0
- aimemory/online/rule_verifier.py +57 -0
- aimemory/online/transport.py +162 -0
- aimemory/reward/__init__.py +31 -0
- aimemory/reward/calculator.py +258 -0
- aimemory/reward/feedback_detector.py +209 -0
- aimemory/reward/implicit_detector.py +66 -0
- aimemory/reward/korean_patterns.py +269 -0
- aimemory/reward/signals.py +524 -0
- aimemory/schemas.py +169 -0
- aimemory/selfplay/__init__.py +31 -0
- aimemory/selfplay/engine.py +427 -0
- aimemory/selfplay/llm_client.py +172 -0
- aimemory/selfplay/memory_agent.py +381 -0
- aimemory/selfplay/scenarios.py +201 -0
- mcp_memory_server-0.2.0.dist-info/METADATA +219 -0
- mcp_memory_server-0.2.0.dist-info/RECORD +50 -0
- mcp_memory_server-0.2.0.dist-info/WHEEL +4 -0
- mcp_memory_server-0.2.0.dist-info/entry_points.txt +2 -0
- mcp_memory_server-0.2.0.dist-info/licenses/LICENSE +21 -0
aimemory/__init__.py
ADDED
aimemory/config.py
ADDED
@@ -0,0 +1,267 @@
"""Configuration for the AI Memory System."""

from __future__ import annotations

from pathlib import Path

from pydantic import BaseModel, Field


PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent


class OllamaConfig(BaseModel):
    """Ollama LLM configuration."""

    base_url: str = "http://localhost:11434"
    model: str = "exaone3.5:7.8b"
    timeout: float = 120.0
    max_retries: int = 3
    temperature: float = 0.7
    top_p: float = 0.9
    max_tokens: int = 384


class SelfPlayConfig(BaseModel):
    """Self-play engine configuration."""

    min_turns: int = 4
    max_turns: int = 8
    memory_test_probability: float = 0.3
    checkpoint_interval: int = 10
    user_system_prompt: str = (
        "## 절대 규칙\n"
        "- 반드시 한국어로만 말하세요. 영어/중국어/일본어 문장 금지.\n"
        "- 코드 블록(```) 절대 금지.\n\n"
        "당신은 20~30대 한국인 역할입니다. 친구에게 자기 이야기를 하듯 말하세요.\n\n"
        "## 반드시 지켜야 할 규칙\n"
        "- 항상 '저는', '제가', '요즘 제가' 로 시작하세요.\n"
        "- 자신의 경험, 취향, 습관, 감정을 구체적으로 말하세요.\n"
        "- 상대방에게 질문하지 마세요. 물음표(?)를 쓰지 마세요.\n"
        "- 추천, 조언, 설명을 하지 마세요.\n"
        "- 괄호(), 메타 설명, 지시문을 절대 쓰지 마세요.\n"
        "- 1~2문장으로 짧게 말하세요.\n\n"
        "## 좋은 예시\n"
        "- '저는 매일 아침 조깅을 해요. 한강 근처를 삼십 분 정도 뛰어요.'\n"
        "- '제가 좋아하는 음식은 김치찌개예요. 엄마가 해주시는 게 제일 맛있어요.'\n"
        "- '요즘 기타를 배우고 있어요. 아직 코드 세 개밖에 못 쳐요.'\n"
        "\n## 대화 흐름 예시 (질문 없이 이야기하기)\n"
        "어시스턴트: '어떤 음식을 좋아하세요?'\n"
        "사용자: '저는 된장찌개를 제일 좋아해요. 엄마가 해주시는 게 최고예요.'\n"
        "어시스턴트: '된장찌개를 좋아하시는군요! 자주 드시나요?'\n"
        "사용자: '네, 일주일에 두세 번은 꼭 먹어요. 요즘은 직접 끓여 먹기도 해요.'\n"
    )
    assistant_system_prompt: str = (
        "## 절대 규칙\n"
        "- 반드시 한국어로만 답변하세요.\n"
        "- 코드 블록 금지, 번호 목록/마크다운 금지.\n\n"
        "당신은 한국어 AI 어시스턴트입니다.\n"
        "- 사용자의 말에 공감하고 짧게 반응하세요.\n"
        "- 1~2문장으로 답변하세요.\n"
        "- 번호 목록이나 마크다운을 쓰지 마세요.\n"
        "- 마지막에 질문 하나를 추가해서 사용자가 더 이야기하도록 유도하세요.\n"
    )


class RewardConfig(BaseModel):
    """Reward calculation configuration."""

    weights: dict[str, float] = Field(default_factory=lambda: {
        "r1_keyword_reappearance": 1.0,
        "r2_repeated_question_penalty": 1.0,
        "r3_efficiency": 0.8,
        "r4_retrieval_relevance": 1.2,
        "r5_speech_act_weight": 1.0,
        "r6_self_reference": 1.0,
        "r7_info_density": 0.8,
        "r8_preference_constraint": 1.2,
        "r9_emotional_salience": 0.6,
        "r10_topic_boundary": 1.0,
        "r11_user_feedback": 1.0,
    })
    proper_noun_multiplier: float = 3.0
    common_noun_multiplier: float = 0.3


class DatasetConfig(BaseModel):
    """Dataset building configuration."""

    context_window: int = 6  # number of recent turns for state
    train_ratio: float = 0.8
    val_ratio: float = 0.1
    test_ratio: float = 0.1
    random_seed: int = 42


class OnlinePolicyConfig(BaseModel):
    """Online policy configuration (rule-based + MLP bandit)."""

    # MLP bandit (existing OnlinePolicy params)
    feature_dim: int = 10
    hidden_dim: int = 64
    n_actions: int = 3
    lr: float = 0.01
    epsilon: float = 0.1

    # Rule-based thresholds
    save_threshold: float = 0.7
    skip_threshold: float = 0.1

    # Importance weights
    personal_weight: float = 0.4
    preference_weight: float = 0.35
    tech_weight: float = 0.3
    emotion_weight: float = 0.2
    keyword_weight: float = 0.15

    # Retrieval
    retrieve_top_k: int = 3

    # Enhanced policy (opt-in)
    use_enhanced_policy: bool = False
    use_progressive_autonomy: bool = False
    autonomy_confidence_threshold: int = 50

    # Sentence-transformer model (GraphMemoryStore)
    st_model: str = "intfloat/multilingual-e5-small"

    # Language
    language: str = "ko"


class DataPaths(BaseModel):
    """Data directory paths."""

    root: Path = PROJECT_ROOT / "data"
    raw_episodes: Path = PROJECT_ROOT / "data" / "raw" / "episodes"
    splits: Path = PROJECT_ROOT / "data" / "splits"
    embeddings: Path = PROJECT_ROOT / "data" / "embeddings"

    def ensure_dirs(self) -> None:
        self.raw_episodes.mkdir(parents=True, exist_ok=True)
        self.splits.mkdir(parents=True, exist_ok=True)
        self.embeddings.mkdir(parents=True, exist_ok=True)


class SecurityConfig(BaseModel):
    """Security filtering configuration."""

    block_passwords: bool = True
    block_api_keys: bool = True
    block_medical_info: bool = True
    require_source_turn_id: bool = False
    respect_life_dignity: bool = True
    no_harm_to_humans: bool = True
    recognize_creator: bool = True


class ForgettingConfig(BaseModel):
    """Forgetting / decay configuration."""

    decay_lambda: float = 0.05
    threshold_compress: float = 0.3
    threshold_deactivate: float = 0.1
    deactivation_days: int = 30
    related_boost: float = 0.1


class SleepCycleConfig(BaseModel):
    """Sleep cycle (periodic maintenance) configuration."""

    enable_consolidation: bool = True
    enable_resolution_regen: bool = True
    enable_forgetting: bool = True
    enable_checkpoint: bool = True
    consolidation_threshold: float = 0.92
    max_consolidation_pairs: int = 50
    forgetting_decay_lambda: float = 0.05
    forgetting_threshold_compress: float = 0.3
    forgetting_threshold_deactivate: float = 0.1
    forgetting_deactivation_days: int = 30
    forgetting_related_boost: float = 0.1
    checkpoint_dir: str = "checkpoints/sleep_cycle"
    report_dir: str = "data/reports/sleep_cycle"


class ComposerConfig(BaseModel):
    """Context composer configuration."""

    default_token_budget: int = 1024
    top_k: int = 10
    level0_avg_tokens: int = 60
    level1_avg_tokens: int = 25
    level2_avg_tokens: int = 10


class GossipConfig(BaseModel):
    """P2P gossip protocol configuration."""

    max_norm: float = 1.0
    gossip_interval: int = 50
    dp_epsilon: float = 1.0
    dp_delta: float = 1e-5
    dp_enabled: bool = True
    transport_host: str = "0.0.0.0"
    transport_port: int = 9400
    rule_hash_verify: bool = True


class ReRankerConfig(BaseModel):
    """RL Re-ranker configuration."""

    # Feature extraction
    feature_dim: int = 11  # 8 → 11 (original 8 + 3 graph features)
    use_graph_features: bool = False  # enables knowledge-graph features when True

    # Model
    hidden_dim: int = 32
    lr: float = 0.005
    epsilon: float = 0.15

    # Re-ranking
    candidate_k: int = 10  # number of candidates fetched from ChromaDB
    select_k: int = 3  # number of memories kept after re-ranking

    # Latency budget
    max_latency_ms: float = 20.0  # maximum allowed re-ranking latency (ms)

    # Enable/disable
    enabled: bool = True  # if False, the raw ChromaDB ordering is used


class MCPServerConfig(BaseModel):
    """MCP server configuration."""

    persist_directory: str = "./memory_db"
    collection_name: str = "memories"
    embedding_model: str = "intfloat/multilingual-e5-small"
    token_budget: int = 1024
    top_k: int = 5
    reranker_pool_size: int = 20
    min_relevance: float = 0.6
    policy_checkpoint: str | None = None
    log_level: str = "INFO"

    # Enhanced policy / GraphRAG (opt-in)
    use_enhanced_policy: bool = False
    use_graph_rag: bool = False


class AppConfig(BaseModel):
    """Top-level application configuration."""

    ollama: OllamaConfig = Field(default_factory=OllamaConfig)
    selfplay: SelfPlayConfig = Field(default_factory=SelfPlayConfig)
    reward: RewardConfig = Field(default_factory=RewardConfig)
    dataset: DatasetConfig = Field(default_factory=DatasetConfig)
    online_policy: OnlinePolicyConfig = Field(default_factory=OnlinePolicyConfig)
    paths: DataPaths = Field(default_factory=DataPaths)
    security: SecurityConfig = Field(default_factory=SecurityConfig)
    forgetting: ForgettingConfig = Field(default_factory=ForgettingConfig)
    composer: ComposerConfig = Field(default_factory=ComposerConfig)
    sleep_cycle: SleepCycleConfig = Field(default_factory=SleepCycleConfig)
    gossip: GossipConfig = Field(default_factory=GossipConfig)
    reranker: ReRankerConfig = Field(default_factory=ReRankerConfig)
    mcp: MCPServerConfig = Field(default_factory=MCPServerConfig)
    num_episodes: int = 1000
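Every sub-config is attached to `AppConfig` via `Field(default_factory=...)`, so a bare `AppConfig()` comes fully populated with the defaults above. The self-play prompts are Korean by design: the engine generates Korean dialogue (cf. `language: str = "ko"` and `aimemory/reward/korean_patterns.py`). A minimal usage sketch follows; it is illustrative, not shipped with the package, and assumes only that the wheel is installed:

from aimemory.config import AppConfig, OllamaConfig

# Override one nested config; every other section keeps the defaults above.
cfg = AppConfig(ollama=OllamaConfig(temperature=0.5, max_tokens=256))

cfg.paths.ensure_dirs()  # creates data/raw/episodes, data/splits, data/embeddings
print(cfg.reward.weights["r4_retrieval_relevance"])  # 1.2
print(cfg.mcp.persist_directory)                     # ./memory_db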
aimemory/dataset/__init__.py
ADDED

@@ -0,0 +1,16 @@
"""Dataset module for the AI Memory System.

Provides Episode→SARTriple conversion, train/val/test splitting, and statistics.
"""

from aimemory.dataset.builder import EpisodeBuilder
from aimemory.dataset.splitter import EpisodeSplitter, SplitResult
from aimemory.dataset.stats import DatasetStats, StatsComputer

__all__ = [
    "EpisodeBuilder",
    "EpisodeSplitter",
    "SplitResult",
    "DatasetStats",
    "StatsComputer",
]
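Because the subpackage re-exports its public names, client code never needs the deeper module paths; an equivalent import (illustrative):

# Equivalent to importing from aimemory.dataset.builder / .splitter / .stats directly.
from aimemory.dataset import (
    DatasetStats,
    EpisodeBuilder,
    EpisodeSplitter,
    SplitResult,
    StatsComputer,
)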
aimemory/dataset/builder.py
ADDED

@@ -0,0 +1,270 @@
"""Episode → SARTriple conversion for RL training dataset.

Converts raw episodes to State-Action-Reward triples with:
- State: last 6 turns dialogue window + current memory summary
- Next state for TD learning
- Edge case handling (first/last turns, empty memory)
"""

from __future__ import annotations

import json
import logging
from pathlib import Path
from typing import Iterator

from aimemory.config import DatasetConfig
from aimemory.schemas import (
    Action,
    Episode,
    MemoryActionType,
    RewardBreakdown,
    SARTriple,
    State,
    Turn,
)

logger = logging.getLogger(__name__)


class EpisodeBuilder:
    """Converts Episodes into SARTriple sequences for RL training."""

    def __init__(self, config: DatasetConfig | None = None) -> None:
        self.config = config or DatasetConfig()

    def _build_state(
        self,
        episode: Episode,
        turn_id: int,
        memory_entries_up_to_turn: list,
    ) -> State:
        """Build RL State from episode context at a given turn.

        Args:
            episode: The source episode.
            turn_id: Current turn index (0-based).
            memory_entries_up_to_turn: Memory entries saved up to (but not including) this turn.

        Returns:
            State with recent dialogue window and memory summary.
        """
        window = self.config.context_window  # default 6

        # Get the last `window` turns up to and including turn_id
        start = max(0, turn_id - window + 1)
        recent_turns: list[Turn] = episode.turns[start : turn_id + 1]

        # Memory summary: list of content strings from saved entries
        memory_summary = [entry.content for entry in memory_entries_up_to_turn]

        # Normalized position in episode (0.0 ~ 1.0)
        num_turns = max(1, len(episode.turns))
        turn_position = turn_id / (num_turns - 1) if num_turns > 1 else 0.0

        return State(
            episode_id=episode.episode_id,
            turn_id=turn_id,
            recent_turns=recent_turns,
            current_memory_summary=memory_summary,
            memory_count=len(memory_entries_up_to_turn),
            turn_position=turn_position,
        )

    def _decision_to_action(self, decision) -> Action:
        """Convert a MemoryDecision to an RL Action."""
        action_type = decision.action

        saved_content = None
        saved_keywords: list[str] = []
        retrieved_count = 0

        if action_type == MemoryActionType.SAVE and decision.memory_entry is not None:
            saved_content = decision.memory_entry.content
            saved_keywords = decision.memory_entry.keywords

        if action_type == MemoryActionType.RETRIEVE:
            retrieved_count = len(decision.retrieved_memories)

        return Action(
            action_type=action_type,
            saved_content=saved_content,
            saved_keywords=saved_keywords,
            retrieved_count=retrieved_count,
        )

    def episode_to_sar_triples(
        self,
        episode: Episode,
        reward_map: dict[int, RewardBreakdown] | None = None,
    ) -> list[SARTriple]:
        """Convert an Episode into a list of SARTriples.

        Args:
            episode: Source episode to convert.
            reward_map: Optional mapping of turn_id → RewardBreakdown.
                If None, zero rewards are used (to be filled later).

        Returns:
            List of SARTriples, one per memory decision in the episode.
        """
        if not episode.memory_decisions:
            logger.warning("Episode %s has no memory decisions", episode.episode_id)
            return []

        # Build a turn_id → decision lookup
        decision_by_turn: dict[int, object] = {
            d.turn_id: d for d in episode.memory_decisions
        }

        # Track cumulative memory entries as we step through decisions
        cumulative_memory: list = []

        triples: list[SARTriple] = []
        decisions_sorted = sorted(episode.memory_decisions, key=lambda d: d.turn_id)

        for step_index, decision in enumerate(decisions_sorted):
            turn_id = decision.turn_id

            # Validate turn_id is within episode
            if turn_id >= len(episode.turns):
                logger.warning(
                    "Decision turn_id %d out of range for episode %s (len=%d)",
                    turn_id,
                    episode.episode_id,
                    len(episode.turns),
                )
                continue

            # State: memory entries before this decision
            state = self._build_state(episode, turn_id, list(cumulative_memory))

            # Action
            action = self._decision_to_action(decision)

            # Reward: use provided map or zero default
            if reward_map and turn_id in reward_map:
                reward = reward_map[turn_id]
            else:
                reward = RewardBreakdown()

            # Update cumulative memory AFTER building current state
            if decision.action == MemoryActionType.SAVE and decision.memory_entry:
                cumulative_memory.append(decision.memory_entry)

            # Next state: state at the next decision, or None if last
            done = step_index == len(decisions_sorted) - 1
            if not done:
                next_decision = decisions_sorted[step_index + 1]
                next_turn_id = next_decision.turn_id
                if next_turn_id < len(episode.turns):
                    next_state = self._build_state(
                        episode, next_turn_id, list(cumulative_memory)
                    )
                else:
                    next_state = None
                    done = True
            else:
                next_state = None

            triple = SARTriple(
                episode_id=episode.episode_id,
                step_index=step_index,
                state=state,
                action=action,
                reward=reward,
                next_state=next_state,
                done=done,
            )
            triples.append(triple)

        return triples

    def build_from_episodes(
        self,
        episodes: list[Episode],
        reward_maps: dict[str, dict[int, RewardBreakdown]] | None = None,
    ) -> list[SARTriple]:
        """Build SARTriples from a list of episodes.

        Args:
            episodes: List of episodes to process.
            reward_maps: Optional mapping episode_id → {turn_id → RewardBreakdown}.

        Returns:
            Flat list of all SARTriples.
        """
        all_triples: list[SARTriple] = []
        for episode in episodes:
            reward_map = None
            if reward_maps:
                reward_map = reward_maps.get(episode.episode_id)
            triples = self.episode_to_sar_triples(episode, reward_map=reward_map)
            all_triples.extend(triples)
            logger.debug(
                "Episode %s → %d triples", episode.episode_id, len(triples)
            )
        return all_triples

    def iter_episodes_from_jsonl(self, path: Path) -> Iterator[Episode]:
        """Iterate over Episodes from a JSONL file."""
        with open(path, encoding="utf-8") as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                    yield Episode.model_validate(data)
                except Exception as exc:
                    logger.error("Line %d parse error in %s: %s", line_num, path, exc)

    def triples_to_parquet_rows(self, triples: list[SARTriple]) -> list[dict]:
        """Convert SARTriples to flat dicts suitable for Parquet storage.

        Scalar fields become columns; nested structures become JSON string columns.
        """
        rows = []
        for t in triples:
            row = {
                "triple_id": t.triple_id,
                "episode_id": t.episode_id,
                "step_index": t.step_index,
                "done": t.done,
                # State scalars
                "state_turn_id": t.state.turn_id,
                "state_memory_count": t.state.memory_count,
                "state_turn_position": t.state.turn_position,
                # State nested as JSON
                "state_recent_turns_json": json.dumps(
                    [turn.model_dump(mode="json") for turn in t.state.recent_turns]
                ),
                "state_memory_summary_json": json.dumps(
                    t.state.current_memory_summary
                ),
                # Action scalars
                "action_type": t.action.action_type.value,
                "action_retrieved_count": t.action.retrieved_count,
                # Action nested as JSON
                "action_saved_content": t.action.saved_content or "",
                "action_saved_keywords_json": json.dumps(t.action.saved_keywords),
                # Reward scalars
                "reward_r1": t.reward.r1_keyword_reappearance,
                "reward_r2": t.reward.r2_repeated_question_penalty,
                "reward_r3": t.reward.r3_efficiency,
                "reward_r4": t.reward.r4_retrieval_relevance,
                "reward_r5": t.reward.r5_speech_act_weight,
                "reward_r6": t.reward.r6_self_reference,
                "reward_r7": t.reward.r7_info_density,
                "reward_r8": t.reward.r8_preference_constraint,
                "reward_r9": t.reward.r9_emotional_salience,
                "reward_r10": t.reward.r10_topic_boundary,
                "reward_r11": t.reward.r11_user_feedback,
                "reward_total": t.reward.total,
                # Next state (nullable)
                "next_state_json": json.dumps(
                    t.next_state.model_dump(mode="json") if t.next_state else None
                ),
            }
            rows.append(row)
        return rows
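End to end, the builder streams episodes from JSONL, converts each into SAR triples, and flattens them for columnar storage. A sketch of that pipeline follows; it is illustrative, the file names are assumptions, and writing Parquet additionally assumes pandas with a Parquet engine such as pyarrow is installed:

from pathlib import Path

import pandas as pd  # assumption: pandas + pyarrow available for Parquet output

from aimemory.dataset import EpisodeBuilder

builder = EpisodeBuilder()  # default DatasetConfig → 6-turn context window

# Stream and convert; rewards stay zero until a reward_maps dict is supplied.
episodes = list(builder.iter_episodes_from_jsonl(Path("data/raw/episodes/batch_001.jsonl")))
triples = builder.build_from_episodes(episodes)

# Flatten to one row per memory decision and persist as a columnar table.
rows = builder.triples_to_parquet_rows(triples)
pd.DataFrame(rows).to_parquet("data/splits/sar_triples.parquet", index=False)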