reme_ai-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- reme_ai/__init__.py +6 -0
- reme_ai/app.py +17 -0
- reme_ai/config/__init__.py +0 -0
- reme_ai/config/config_parser.py +6 -0
- reme_ai/constants/__init__.py +7 -0
- reme_ai/constants/common_constants.py +48 -0
- reme_ai/constants/language_constants.py +215 -0
- reme_ai/enumeration/__init__.py +0 -0
- reme_ai/enumeration/language_constants.py +215 -0
- reme_ai/react/__init__.py +1 -0
- reme_ai/react/simple_react_op.py +21 -0
- reme_ai/retrieve/__init__.py +2 -0
- reme_ai/retrieve/personal/__init__.py +17 -0
- reme_ai/retrieve/personal/extract_time_op.py +97 -0
- reme_ai/retrieve/personal/fuse_rerank_op.py +180 -0
- reme_ai/retrieve/personal/print_memory_op.py +131 -0
- reme_ai/retrieve/personal/read_message_op.py +52 -0
- reme_ai/retrieve/personal/retrieve_memory_op.py +13 -0
- reme_ai/retrieve/personal/semantic_rank_op.py +170 -0
- reme_ai/retrieve/personal/set_query_op.py +37 -0
- reme_ai/retrieve/task/__init__.py +4 -0
- reme_ai/retrieve/task/build_query_op.py +38 -0
- reme_ai/retrieve/task/merge_memory_op.py +27 -0
- reme_ai/retrieve/task/rerank_memory_op.py +149 -0
- reme_ai/retrieve/task/rewrite_memory_op.py +149 -0
- reme_ai/schema/__init__.py +1 -0
- reme_ai/schema/memory.py +144 -0
- reme_ai/summary/__init__.py +2 -0
- reme_ai/summary/personal/__init__.py +8 -0
- reme_ai/summary/personal/contra_repeat_op.py +143 -0
- reme_ai/summary/personal/get_observation_op.py +147 -0
- reme_ai/summary/personal/get_observation_with_time_op.py +165 -0
- reme_ai/summary/personal/get_reflection_subject_op.py +179 -0
- reme_ai/summary/personal/info_filter_op.py +177 -0
- reme_ai/summary/personal/load_today_memory_op.py +117 -0
- reme_ai/summary/personal/long_contra_repeat_op.py +210 -0
- reme_ai/summary/personal/update_insight_op.py +244 -0
- reme_ai/summary/task/__init__.py +10 -0
- reme_ai/summary/task/comparative_extraction_op.py +233 -0
- reme_ai/summary/task/failure_extraction_op.py +73 -0
- reme_ai/summary/task/memory_deduplication_op.py +163 -0
- reme_ai/summary/task/memory_validation_op.py +108 -0
- reme_ai/summary/task/pdf_preprocess_op_wrapper.py +50 -0
- reme_ai/summary/task/simple_comparative_summary_op.py +71 -0
- reme_ai/summary/task/simple_summary_op.py +67 -0
- reme_ai/summary/task/success_extraction_op.py +73 -0
- reme_ai/summary/task/trajectory_preprocess_op.py +76 -0
- reme_ai/summary/task/trajectory_segmentation_op.py +118 -0
- reme_ai/utils/__init__.py +0 -0
- reme_ai/utils/datetime_handler.py +345 -0
- reme_ai/utils/miner_u_pdf_processor.py +726 -0
- reme_ai/utils/op_utils.py +115 -0
- reme_ai/vector_store/__init__.py +6 -0
- reme_ai/vector_store/delete_memory_op.py +25 -0
- reme_ai/vector_store/recall_vector_store_op.py +36 -0
- reme_ai/vector_store/update_memory_freq_op.py +33 -0
- reme_ai/vector_store/update_memory_utility_op.py +32 -0
- reme_ai/vector_store/update_vector_store_op.py +32 -0
- reme_ai/vector_store/vector_store_action_op.py +55 -0
- reme_ai-0.1.0.dist-info/METADATA +218 -0
- reme_ai-0.1.0.dist-info/RECORD +65 -0
- reme_ai-0.1.0.dist-info/WHEEL +5 -0
- reme_ai-0.1.0.dist-info/entry_points.txt +2 -0
- reme_ai-0.1.0.dist-info/licenses/LICENSE +201 -0
- reme_ai-0.1.0.dist-info/top_level.txt +1 -0
reme_ai/retrieve/task/rerank_memory_op.py
ADDED
@@ -0,0 +1,149 @@
import json
import re
from typing import List

from flowllm import C, BaseLLMOp
from flowllm.enumeration.role import Role
from flowllm.schema.message import Message
from loguru import logger

from reme_ai.schema.memory import BaseMemory


@C.register_op()
class RerankMemoryOp(BaseLLMOp):
    """
    Rerank and filter recalled experiences using LLM and score-based filtering
    """
    file_path: str = __file__

    def execute(self):
        """Execute rerank operation"""
        memory_list: List[BaseMemory] = self.context.response.metadata["memory_list"]
        retrieval_query: str = self.context.query
        enable_llm_rerank = self.op_params.get("enable_llm_rerank", True)
        enable_score_filter = self.op_params.get("enable_score_filter", False)
        min_score_threshold = self.op_params.get("min_score_threshold", 0.3)
        top_k = self.op_params.get("top_k", 5)

        logger.info(f"top_k: {top_k}")

        if not memory_list:
            logger.info("No recalled memory_list to rerank")
            return

        logger.info(f"Reranking {len(memory_list)} memories")

        # Step 1: LLM reranking (optional)
        if enable_llm_rerank:
            memory_list = self._llm_rerank(retrieval_query, memory_list)
            logger.info(f"After LLM reranking: {len(memory_list)} memories")

        # Step 2: Score-based filtering (optional)
        if enable_score_filter:
            memory_list = self._score_based_filter(memory_list, min_score_threshold)
            logger.info(f"After score filtering: {len(memory_list)} memories")

        # Step 3: Return top-k results
        reranked_memories = memory_list[:top_k]
        logger.info(f"Final reranked results: {len(reranked_memories)} memories")

        # Store results in context
        self.context.response.metadata["memory_list"] = reranked_memories

    def _llm_rerank(self, query: str, candidates: List[BaseMemory]) -> List[BaseMemory]:
        """LLM-based reranking of candidate experiences"""
        if not candidates:
            return candidates

        # Format candidates for LLM evaluation
        candidates_text = self._format_candidates_for_rerank(candidates)

        prompt = self.prompt_format(
            prompt_name="memory_rerank_prompt",
            query=query,
            candidates=candidates_text,
            num_candidates=len(candidates))

        response = self.llm.chat([Message(role=Role.USER, content=prompt)])

        # Parse reranking results
        reranked_indices = self._parse_rerank_response(response.content)

        # Reorder candidates based on LLM ranking
        if reranked_indices:
            reranked_candidates = []
            for idx in reranked_indices:
                if 0 <= idx < len(candidates):
                    reranked_candidates.append(candidates[idx])

            # Add any remaining candidates that weren't explicitly ranked
            ranked_indices_set = set(reranked_indices)
            for i, candidate in enumerate(candidates):
                if i not in ranked_indices_set:
                    reranked_candidates.append(candidate)

            return reranked_candidates

        return candidates

    @staticmethod
    def _score_based_filter(memories: List[BaseMemory], min_score: float) -> List[BaseMemory]:
        """Filter memories based on quality scores"""
        filtered_memories = []

        for memory in memories:
            # Get confidence score from metadata
            confidence = memory.metadata.get("confidence", 0.5)
            validation_score = memory.score or 0.5

            # Calculate combined score
            combined_score = (confidence + validation_score) / 2

            if combined_score >= min_score:
                filtered_memories.append(memory)
            else:
                logger.debug(f"Filtered out memory with score {combined_score:.2f}")

        logger.info(f"Score filtering: {len(filtered_memories)}/{len(memories)} memories retained")
        return filtered_memories

    @staticmethod
    def _format_candidates_for_rerank(candidates: List[BaseMemory]) -> str:
        """Format candidates for LLM reranking"""
        formatted_candidates = []

        for i, candidate in enumerate(candidates):
            condition = candidate.when_to_use
            content = candidate.content

            candidate_text = f"Candidate {i}:\n"
            candidate_text += f"Condition: {condition}\n"
            candidate_text += f"Experience: {content}\n"

            formatted_candidates.append(candidate_text)

        return "\n---\n".join(formatted_candidates)

    @staticmethod
    def _parse_rerank_response(response: str) -> List[int]:
        """Parse LLM reranking response to extract ranked indices"""
        try:
            # Try to extract JSON format
            json_pattern = r'```json\s*([\s\S]*?)\s*```'
            json_blocks = re.findall(json_pattern, response)

            if json_blocks:
                parsed = json.loads(json_blocks[0])
                if isinstance(parsed, dict) and "ranked_indices" in parsed:
                    return parsed["ranked_indices"]
                elif isinstance(parsed, list):
                    return parsed

            # Try to extract numbers from text
            numbers = re.findall(r'\b\d+\b', response)
            return [int(num) for num in numbers if int(num) < 100]  # Reasonable upper bound

        except Exception as e:
            logger.error(f"Error parsing rerank response: {e}")
            return []
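To make the parsing contract above concrete: a minimal, self-contained sketch (not part of the package) of the two response shapes `_parse_rerank_response` accepts. The sample LLM outputs are hypothetical.

import json
import re

# Shape 1: a fenced JSON block carrying "ranked_indices".
sample_json = '```json\n{"ranked_indices": [2, 0, 1]}\n```'
blocks = re.findall(r'```json\s*([\s\S]*?)\s*```', sample_json)
print(json.loads(blocks[0])["ranked_indices"])  # [2, 0, 1]

# Shape 2 (fallback): bare indices scattered in free text.
sample_text = "Best match is candidate 2, then 0, then 1."
print([int(n) for n in re.findall(r'\b\d+\b', sample_text)])  # [2, 0, 1]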
reme_ai/retrieve/task/rewrite_memory_op.py
ADDED
@@ -0,0 +1,149 @@
import json
import re
from typing import List

from flowllm import C, BaseLLMOp
from flowllm.enumeration.role import Role
from flowllm.schema.message import Message
from loguru import logger

from reme_ai.schema.memory import BaseMemory


@C.register_op()
class RewriteMemoryOp(BaseLLMOp):
    """
    Generate and rewrite context messages from reranked experiences
    """
    file_path: str = __file__

    def execute(self):
        """Execute rewrite operation"""
        memory_list: List[BaseMemory] = self.context.response.metadata["memory_list"]
        query: str = self.context.query
        messages: List[Message] = \
            [Message(**x) if isinstance(x, dict) else x for x in self.context.get('messages', [])]

        if not memory_list:
            logger.info("No reranked memories to rewrite")
            self.context.response.answer = ""
            return

        logger.info(f"Generating context from {len(memory_list)} memories")

        # Generate initial context message
        rewritten_memory = self._generate_context_message(query, messages, memory_list)

        # Store results in context
        self.context.response.answer = rewritten_memory
        self.context.response.metadata["memory_list"] = [memory.model_dump() for memory in memory_list]

    def _generate_context_message(self, query: str, messages: List[Message], memories: List[BaseMemory]) -> str:
        """Generate context message from retrieved memories"""
        if not memories:
            return ""

        try:
            logger.info("memories")
            # Format retrieved memories
            formatted_memories = self._format_memories_for_context(memories)

            if self.op_params.get("enable_llm_rewrite", True):
                context_content = self._rewrite_context(query, formatted_memories, messages)
            else:
                context_content = formatted_memories

            return context_content

        except Exception as e:
            logger.error(f"Error generating context message: {e}")
            return self._format_memories_for_context(memories)

    def _rewrite_context(self, query: str, context_content: str, messages: List[Message]) -> str:
        """LLM-based context rewriting to make experiences more relevant and actionable"""
        if not context_content:
            return context_content

        try:
            # Extract current context
            current_context = self._extract_context(messages)

            prompt = self.prompt_format(
                prompt_name="memory_rewrite_prompt",
                current_query=query,
                current_context=current_context,
                original_context=context_content)

            response = self.llm.chat([Message(role=Role.USER, content=prompt)])

            # Extract rewritten context
            rewritten_context = self._parse_json_response(response.content, "rewritten_context")

            if rewritten_context and rewritten_context.strip():
                logger.info("Context successfully rewritten for current task")
                return rewritten_context.strip()

            return context_content

        except Exception as e:
            logger.error(f"Error in context rewriting: {e}")
            return context_content

    @staticmethod
    def _format_memories_for_context(memories: List[BaseMemory]) -> str:
        """Format memories for context generation"""
        formatted_memories = []

        for i, memory in enumerate(memories, 1):
            condition = memory.when_to_use
            memory_content = memory.content
            memory_text = f"Memory {i} :\n When to use: {condition}\n Content: {memory_content}\n"

            formatted_memories.append(memory_text)

        return "\n".join(formatted_memories)

    @staticmethod
    def _extract_context(messages: List[Message]) -> str:
        """Extract relevant context from messages"""
        if not messages:
            return ""

        context_parts = []

        # Add recent messages if available
        recent_messages = messages[-3:]  # Last 3 messages
        message_summaries = []
        for message in recent_messages:
            content = message.content[:300] + "..." if len(message.content) > 300 else message.content
            message_summaries.append(f"- {message.role.value}: {content}")

        if message_summaries:
            context_parts.append("Recent conversation:\n" + "\n".join(message_summaries))

        return "\n\n".join(context_parts)

    @staticmethod
    def _parse_json_response(response: str, key: str) -> str:
        """Parse JSON response to extract specific key"""
        try:
            # Try to extract JSON blocks
            json_pattern = r'```json\s*([\s\S]*?)\s*```'
            json_blocks = re.findall(json_pattern, response)

            if json_blocks:
                parsed = json.loads(json_blocks[0])
                if isinstance(parsed, dict) and key in parsed:
                    return parsed[key]

            # Fallback: try to parse the entire response as JSON
            parsed = json.loads(response)
            if isinstance(parsed, dict) and key in parsed:
                return parsed[key]

        except json.JSONDecodeError:
            logger.warning(f"Failed to parse JSON response for key '{key}', using raw response")
            # If JSON parsing fails, return the response as-is for fallback
            return response.strip()

        return ""
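For illustration, a simplified standalone version of the fallback chain in `_parse_json_response` (fenced JSON block first, then the whole response as JSON, then the raw text); the sample responses are hypothetical:

import json
import re

def extract(response: str, key: str) -> str:
    # Prefer a fenced ```json block; otherwise try the whole response.
    blocks = re.findall(r'```json\s*([\s\S]*?)\s*```', response)
    try:
        parsed = json.loads(blocks[0] if blocks else response)
        if isinstance(parsed, dict) and key in parsed:
            return parsed[key]
    except json.JSONDecodeError:
        return response.strip()  # raw-text fallback, as in the op above
    return ""

print(extract('```json\n{"rewritten_context": "Reuse the cached index."}\n```', "rewritten_context"))
print(extract("plain prose, no JSON", "rewritten_context"))  # falls back to the raw text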
reme_ai/schema/__init__.py
ADDED
@@ -0,0 +1 @@
from flowllm.schema.message import Message, Role, Trajectory # noqa
reme_ai/schema/memory.py
ADDED
@@ -0,0 +1,144 @@
import datetime
from abc import ABC
from uuid import uuid4

from flowllm.schema.vector_node import VectorNode
from pydantic import BaseModel, Field


class BaseMemory(BaseModel, ABC):
    workspace_id: str = Field(default="")
    memory_id: str = Field(default_factory=lambda: uuid4().hex)
    memory_type: str = Field(default=...)

    when_to_use: str = Field(default="")
    content: str | bytes = Field(default="")
    score: float = Field(default=0)

    time_created: str = Field(default_factory=lambda: datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    time_modified: str = Field(default_factory=lambda: datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    author: str = Field(default="")

    metadata: dict = Field(default_factory=dict)

    def update_modified_time(self):
        self.time_modified = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    def update_metadata(self, new_metadata):
        self.metadata = new_metadata

    def to_vector_node(self) -> VectorNode:
        raise NotImplementedError

    @classmethod
    def from_vector_node(cls, node: VectorNode):
        raise NotImplementedError


class TaskMemory(BaseMemory):
    memory_type: str = Field(default="task")

    def to_vector_node(self) -> VectorNode:
        return VectorNode(unique_id=self.memory_id,
                          workspace_id=self.workspace_id,
                          content=self.when_to_use,
                          metadata={
                              "memory_type": self.memory_type,
                              "content": self.content,
                              "score": self.score,
                              "time_created": self.time_created,
                              "time_modified": self.time_modified,
                              "author": self.author,
                              "metadata": self.metadata,
                          })

    @classmethod
    def from_vector_node(cls, node: VectorNode) -> "TaskMemory":
        metadata = node.metadata.copy()
        return cls(workspace_id=node.workspace_id,
                   memory_id=node.unique_id,
                   memory_type=metadata.pop("memory_type"),
                   when_to_use=node.content,
                   content=metadata.pop("content"),
                   score=metadata.pop("score"),
                   time_created=metadata.pop("time_created"),
                   time_modified=metadata.pop("time_modified"),
                   author=metadata.pop("author"),
                   metadata=metadata.pop("metadata", {}))


class PersonalMemory(BaseMemory):
    memory_type: str = Field(default="personal")
    target: str = Field(default="")
    reflection_subject: str = Field(default="")  # For storing reflection subject attributes

    def to_vector_node(self) -> VectorNode:
        return VectorNode(unique_id=self.memory_id,
                          workspace_id=self.workspace_id,
                          content=self.when_to_use,
                          metadata={
                              "memory_type": self.memory_type,
                              "content": self.content,
                              "target": self.target,
                              "reflection_subject": self.reflection_subject,
                              "score": self.score,
                              "time_created": self.time_created,
                              "time_modified": self.time_modified,
                              "author": self.author,
                              "metadata": self.metadata,
                          })

    @classmethod
    def from_vector_node(cls, node: VectorNode) -> "PersonalMemory":
        metadata = node.metadata.copy()
        return cls(workspace_id=node.workspace_id,
                   memory_id=node.unique_id,
                   memory_type=metadata.pop("memory_type"),
                   when_to_use=node.content,
                   content=metadata.pop("content"),
                   target=metadata.pop("target", ""),
                   reflection_subject=metadata.pop("reflection_subject", ""),
                   score=metadata.pop("score"),
                   time_created=metadata.pop("time_created"),
                   time_modified=metadata.pop("time_modified"),
                   author=metadata.pop("author"),
                   metadata=metadata.pop("metadata", {}))


def vector_node_to_memory(node: VectorNode) -> BaseMemory:
    memory_type = node.metadata.get("memory_type")
    if memory_type == "task":
        return TaskMemory.from_vector_node(node)

    elif memory_type == "personal":
        return PersonalMemory.from_vector_node(node)

    else:
        raise RuntimeError(f"memory_type={memory_type} not supported!")


def dict_to_memory(memory_dict: dict):
    memory_type = memory_dict.get("memory_type", "task")
    if memory_type == "task":
        return TaskMemory(**memory_dict)

    elif memory_type == "personal":
        return PersonalMemory(**memory_dict)

    else:
        raise RuntimeError(f"memory_type={memory_type} not supported!")


if __name__ == "__main__":
    e1 = TaskMemory(
        workspace_id="w_1024",
        memory_id="123",
        when_to_use="test case use",
        content="test content",
        score=0.99,
        metadata={})
    print(e1.model_dump_json(indent=2))
    v1 = e1.to_vector_node()
    print(v1.model_dump_json(indent=2))
    e2 = vector_node_to_memory(v1)
    print(e2.model_dump_json(indent=2))
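As a usage note, `dict_to_memory` reverses the `model_dump()` serialization that `RewriteMemoryOp` writes into response metadata. A minimal sketch, assuming `reme_ai` and its pydantic dependency are installed:

from reme_ai.schema.memory import PersonalMemory, dict_to_memory

p1 = PersonalMemory(when_to_use="user asks about diet", content="prefers vegetarian food")
payload = p1.model_dump()  # plain dict, e.g. for transport in response metadata
p2 = dict_to_memory(payload)  # dispatches on memory_type == "personal"
assert isinstance(p2, PersonalMemory) and p2.memory_id == p1.memory_id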
reme_ai/summary/personal/__init__.py
ADDED
@@ -0,0 +1,8 @@
from .contra_repeat_op import ContraRepeatOp
from .get_observation_op import GetObservationOp
from .get_observation_with_time_op import GetObservationWithTimeOp
from .get_reflection_subject_op import GetReflectionSubjectOp
from .info_filter_op import InfoFilterOp
from .load_today_memory_op import LoadTodayMemoryOp
from .long_contra_repeat_op import LongContraRepeatOp
from .update_insight_op import UpdateInsightOp
reme_ai/summary/personal/contra_repeat_op.py
ADDED
@@ -0,0 +1,143 @@
import json
import re
from typing import List, Tuple

from flowllm import C, BaseLLMOp
from flowllm.enumeration.role import Role
from flowllm.schema.message import Message
from loguru import logger

from reme_ai.schema.memory import BaseMemory


@C.register_op()
class ContraRepeatOp(BaseLLMOp):
    """
    The `ContraRepeatOp` class specializes in processing memory nodes to identify and handle
    contradictory and repetitive information. It extends the base functionality of `BaseLLMOp`.

    Responsibilities:
    - Collects observation memories from context.
    - Constructs a prompt with these observations for language model analysis.
    - Parses the model's response to detect contradictions or redundancies.
    - Filters and returns the processed memories.
    """
    file_path: str = __file__

    def execute(self):
        """
        Executes the primary routine of the ContraRepeatOp which involves:
        1. Gets memory list from context
        2. Constructs a prompt with these memories for language model analysis
        3. Parses the model's response to detect contradictions or redundancies
        4. Filters and returns the processed memories
        """
        # Get memory list from context - standardized key
        memory_list: List[BaseMemory] = []
        memory_list.extend(self.context.get("observation_memories", []))
        memory_list.extend(self.context.get("observation_memories_with_time", []))
        memory_list.extend(self.context.get("today_memories", []))

        self.context.response.metadata["memory_list"] = memory_list

        if not memory_list:
            logger.info("memory_list is empty!")
            self.context.response.metadata["deleted_memory_ids"] = []
            return

        # Get operation parameters
        contra_repeat_max_count: int = self.op_params.get("contra_repeat_max_count", 50)
        enable_contra_repeat: bool = self.op_params.get("enable_contra_repeat", True)

        if not enable_contra_repeat:
            logger.warning("contra_repeat is not enabled!")
            self.context.response.metadata["deleted_memory_ids"] = []
            return

        # Sort and limit memories by count
        sorted_memories = sorted(memory_list, key=lambda x: x.time_created, reverse=True)[:contra_repeat_max_count]

        if len(sorted_memories) <= 1:
            logger.info("sorted_memories.size<=1, stop.")
            self.context.response.metadata["memory_list"] = sorted_memories
            self.context.response.metadata["deleted_memory_ids"] = []
            return

        # Build prompt
        user_query_list = []
        for i, memory in enumerate(sorted_memories):
            user_query_list.append(f"{i + 1} {memory.content}")

        user_name = self.context.get("user_name", "user")

        # Create prompt using the new pattern
        system_prompt = self.prompt_format(prompt_name="contra_repeat_system",
                                           num_obs=len(user_query_list),
                                           user_name=user_name)
        few_shot = self.prompt_format(prompt_name="contra_repeat_few_shot", user_name=user_name)
        user_query = self.prompt_format(prompt_name="contra_repeat_user_query",
                                        user_query="\n".join(user_query_list))

        full_prompt = f"{system_prompt}\n\n{few_shot}\n\n{user_query}"
        logger.info(f"contra_repeat_prompt={full_prompt}")

        # Call LLM
        response = self.llm.chat([Message(role=Role.USER, content=full_prompt)])

        # Return if empty
        if not response or not response.content:
            logger.warning("Empty response from LLM")
            self.context.response.metadata["memory_list"] = sorted_memories
            self.context.response.metadata["deleted_memory_ids"] = []
            return

        response_text = response.content
        logger.info(f"contra_repeat_response={response_text}")

        # Parse response and filter memories
        filtered_memories, deleted_memory_ids = self._parse_and_filter_memories(response_text, sorted_memories)

        # Update context with filtered memories and deleted memory IDs - standardized keys
        self.context.response.metadata["memory_list"] = filtered_memories
        self.context.response.metadata["deleted_memory_ids"] = deleted_memory_ids
        logger.info(f"Filtered {len(memory_list)} memories to {len(filtered_memories)} memories")
        logger.info(f"Deleted memory IDs: {json.dumps(deleted_memory_ids, indent=2)}")

    @staticmethod
    def _parse_and_filter_memories(response_text: str, memories: List[BaseMemory]) -> Tuple[
        List[BaseMemory], List[str]]:
        """Parse LLM response and filter memories based on contradiction/containment analysis"""

        # Parse the response to extract judgments
        pattern = r"<(\d+)>\s*<(矛盾|被包含|无|Contradiction|Contained|None)>"
        matches = re.findall(pattern, response_text, re.IGNORECASE)

        if not matches:
            logger.warning("No valid judgments found in response")
            return memories, []

        # Create a set of indices to remove (contradictory or contained memories)
        indices_to_remove = set()
        deleted_memory_ids = []

        for idx_str, judgment in matches:
            try:
                idx = int(idx_str) - 1  # Convert to 0-based index
                if idx >= len(memories):
                    logger.warning(f"Invalid index {idx} for memories list of length {len(memories)}")
                    continue

                judgment_lower = judgment.lower()
                if judgment_lower in ['矛盾', 'contradiction', '被包含', 'contained']:
                    indices_to_remove.add(idx)
                    deleted_memory_ids.append(memories[idx].memory_id)
                    logger.info(f"Marking memory {idx + 1} for removal: {judgment} - {memories[idx].content[:100]}...")

            except ValueError:
                logger.warning(f"Invalid index format: {idx_str}")
                continue

        # Filter out the memories marked for removal
        filtered_memories = [memory for i, memory in enumerate(memories) if i not in indices_to_remove]

        return filtered_memories, deleted_memory_ids
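A self-contained sketch of the judgment format `_parse_and_filter_memories` expects; the sample response is hypothetical, and the pattern accepts both the Chinese and English labels:

import re

pattern = r"<(\d+)>\s*<(矛盾|被包含|无|Contradiction|Contained|None)>"
sample = "<1> <None> <2> <Contradiction> <3> <被包含>"
print(re.findall(pattern, sample, re.IGNORECASE))
# [('1', 'None'), ('2', 'Contradiction'), ('3', '被包含')]
# Entries 2 and 3 would be removed; entry 1 is kept.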