airtrain 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +148 -2
- airtrain/__main__.py +4 -0
- airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/agents/__init__.py +45 -0
- airtrain/agents/example_agent.py +348 -0
- airtrain/agents/groq_agent.py +289 -0
- airtrain/agents/memory.py +663 -0
- airtrain/agents/registry.py +465 -0
- airtrain/builder/__init__.py +3 -0
- airtrain/builder/agent_builder.py +122 -0
- airtrain/cli/__init__.py +0 -0
- airtrain/cli/builder.py +23 -0
- airtrain/cli/main.py +120 -0
- airtrain/contrib/__init__.py +29 -0
- airtrain/contrib/travel/__init__.py +35 -0
- airtrain/contrib/travel/agents.py +243 -0
- airtrain/contrib/travel/models.py +59 -0
- airtrain/core/__init__.py +7 -0
- airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
- airtrain/core/credentials.py +171 -0
- airtrain/core/schemas.py +237 -0
- airtrain/core/skills.py +269 -0
- airtrain/integrations/__init__.py +74 -0
- airtrain/integrations/anthropic/__init__.py +33 -0
- airtrain/integrations/anthropic/credentials.py +32 -0
- airtrain/integrations/anthropic/list_models.py +110 -0
- airtrain/integrations/anthropic/models_config.py +100 -0
- airtrain/integrations/anthropic/skills.py +155 -0
- airtrain/integrations/aws/__init__.py +6 -0
- airtrain/integrations/aws/credentials.py +36 -0
- airtrain/integrations/aws/skills.py +98 -0
- airtrain/integrations/cerebras/__init__.py +6 -0
- airtrain/integrations/cerebras/credentials.py +19 -0
- airtrain/integrations/cerebras/skills.py +127 -0
- airtrain/integrations/combined/__init__.py +21 -0
- airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
- airtrain/integrations/combined/list_models_factory.py +210 -0
- airtrain/integrations/fireworks/__init__.py +21 -0
- airtrain/integrations/fireworks/completion_skills.py +147 -0
- airtrain/integrations/fireworks/conversation_manager.py +109 -0
- airtrain/integrations/fireworks/credentials.py +26 -0
- airtrain/integrations/fireworks/list_models.py +128 -0
- airtrain/integrations/fireworks/models.py +139 -0
- airtrain/integrations/fireworks/requests_skills.py +207 -0
- airtrain/integrations/fireworks/skills.py +181 -0
- airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
- airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
- airtrain/integrations/fireworks/structured_skills.py +102 -0
- airtrain/integrations/google/__init__.py +7 -0
- airtrain/integrations/google/credentials.py +58 -0
- airtrain/integrations/google/skills.py +122 -0
- airtrain/integrations/groq/__init__.py +23 -0
- airtrain/integrations/groq/credentials.py +24 -0
- airtrain/integrations/groq/models_config.py +162 -0
- airtrain/integrations/groq/skills.py +201 -0
- airtrain/integrations/ollama/__init__.py +6 -0
- airtrain/integrations/ollama/credentials.py +26 -0
- airtrain/integrations/ollama/skills.py +41 -0
- airtrain/integrations/openai/__init__.py +37 -0
- airtrain/integrations/openai/chinese_assistant.py +42 -0
- airtrain/integrations/openai/credentials.py +39 -0
- airtrain/integrations/openai/list_models.py +112 -0
- airtrain/integrations/openai/models_config.py +224 -0
- airtrain/integrations/openai/skills.py +342 -0
- airtrain/integrations/perplexity/__init__.py +49 -0
- airtrain/integrations/perplexity/credentials.py +43 -0
- airtrain/integrations/perplexity/list_models.py +112 -0
- airtrain/integrations/perplexity/models_config.py +128 -0
- airtrain/integrations/perplexity/skills.py +279 -0
- airtrain/integrations/sambanova/__init__.py +6 -0
- airtrain/integrations/sambanova/credentials.py +20 -0
- airtrain/integrations/sambanova/skills.py +129 -0
- airtrain/integrations/search/__init__.py +21 -0
- airtrain/integrations/search/exa/__init__.py +23 -0
- airtrain/integrations/search/exa/credentials.py +30 -0
- airtrain/integrations/search/exa/schemas.py +114 -0
- airtrain/integrations/search/exa/skills.py +115 -0
- airtrain/integrations/together/__init__.py +33 -0
- airtrain/integrations/together/audio_models_config.py +34 -0
- airtrain/integrations/together/credentials.py +22 -0
- airtrain/integrations/together/embedding_models_config.py +92 -0
- airtrain/integrations/together/image_models_config.py +69 -0
- airtrain/integrations/together/image_skill.py +143 -0
- airtrain/integrations/together/list_models.py +76 -0
- airtrain/integrations/together/models.py +95 -0
- airtrain/integrations/together/models_config.py +399 -0
- airtrain/integrations/together/rerank_models_config.py +43 -0
- airtrain/integrations/together/rerank_skill.py +49 -0
- airtrain/integrations/together/schemas.py +33 -0
- airtrain/integrations/together/skills.py +305 -0
- airtrain/integrations/together/vision_models_config.py +49 -0
- airtrain/telemetry/__init__.py +38 -0
- airtrain/telemetry/service.py +167 -0
- airtrain/telemetry/views.py +237 -0
- airtrain/tools/__init__.py +45 -0
- airtrain/tools/command.py +398 -0
- airtrain/tools/filesystem.py +166 -0
- airtrain/tools/network.py +111 -0
- airtrain/tools/registry.py +320 -0
- airtrain/tools/search.py +450 -0
- airtrain/tools/testing.py +135 -0
- airtrain-0.1.4.dist-info/METADATA +222 -0
- airtrain-0.1.4.dist-info/RECORD +108 -0
- {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
- airtrain-0.1.4.dist-info/entry_points.txt +2 -0
- airtrain-0.1.2.dist-info/METADATA +0 -106
- airtrain-0.1.2.dist-info/RECORD +0 -5
- {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,663 @@
|
|
1
|
+
"""
|
2
|
+
Memory components for AirTrain agents.
|
3
|
+
|
4
|
+
This module provides memory systems for agents, including short-term,
|
5
|
+
long-term, and shared memory implementations.
|
6
|
+
"""
|
7
|
+
|
8
|
+
from typing import Dict, List, Any, Optional
|
9
|
+
import json
|
10
|
+
import os
|
11
|
+
import uuid
|
12
|
+
from datetime import datetime
|
13
|
+
from pathlib import Path
|
14
|
+
|
15
|
+
|
16
|
+
class BaseMemory:
    """Common storage and serialization logic shared by every memory type.

    A memory keeps an ordered list of message dictionaries (oldest first)
    together with a human-readable ``name``.
    """

    def __init__(self, name: Optional[str] = None):
        """
        Set up an empty memory.

        Args:
            name: Label for this memory; the class name is used when it is
                omitted or empty
        """
        self.name = name if name else type(self).__name__
        self.messages: List[Dict[str, Any]] = []

    def add(self, message: Dict[str, Any]):
        """
        Append a message, stamping it with the current time if needed.

        The caller's dict is modified in place: a ``"timestamp"`` key (ISO
        8601, local time) is inserted when the message has none.

        Args:
            message: Message dictionary to store

        Returns:
            self for method chaining
        """
        message.setdefault("timestamp", datetime.now().isoformat())
        self.messages.append(message)
        return self

    def clear(self):
        """
        Drop every stored message.

        A fresh list is bound, so lists previously handed out by
        ``get_messages`` are left as they were.

        Returns:
            self for method chaining
        """
        self.messages = []
        return self

    def get_messages(self, limit: Optional[int] = None):
        """
        Return stored messages, optionally only the most recent ones.

        Args:
            limit: How many of the newest messages to return; None or a
                value <= 0 means all of them

        Returns:
            The internal message list when unlimited, otherwise a new list
            holding the last ``limit`` messages
        """
        if limit is not None and limit > 0:
            return self.messages[-limit:]
        return self.messages

    def to_dict(self):
        """
        Build a JSON-friendly snapshot of this memory.

        Returns:
            Dict with ``name``, ``type`` (the concrete class name) and
            ``messages``
        """
        return dict(
            name=self.name,
            type=type(self).__name__,
            messages=self.messages,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]):
        """
        Rebuild a memory from a snapshot produced by ``to_dict``.

        Args:
            data: Snapshot dictionary; missing keys fall back to defaults

        Returns:
            New instance of ``cls``
        """
        memory = cls(name=data.get("name"))
        memory.messages = data.get("messages", [])
        return memory
|
97
|
+
|
98
|
+
|
99
|
+
class ShortTermMemory(BaseMemory):
    """Short-term memory with automatic summarization capability.

    Holds at most ``max_messages`` entries. It never holds fewer than two
    once summarization starts: one summary plus the newest message.

    When an ``add`` pushes the list past that bound, the oldest entries are
    folded into a single summary message kept at the front of ``messages``.
    Any earlier summary is folded in along with them. Every summary produced
    is also appended to ``summaries``.
    """

    def __init__(self, name: Optional[str] = None, max_messages: int = 10):
        """
        Initialize short-term memory.

        Args:
            name: Optional name for the memory instance
            max_messages: Maximum number of messages to keep before summarizing
        """
        super().__init__(name)
        self.max_messages = max_messages
        self.summaries: List[Dict[str, Any]] = []

    def add(self, message: Dict[str, Any]):
        """
        Add message and manage memory size.

        If the number of messages exceeds max_messages, the oldest
        messages are folded into one summary message so the list is
        brought back to max_messages entries.

        Args:
            message: Message dictionary to add to memory

        Returns:
            self for method chaining
        """
        super().add(message)
        overflow = len(self.messages) - self.max_messages
        if overflow > 0:
            # Replacing k entries with one summary shrinks the list by k - 1,
            # so overflow + 1 entries must be folded to get back to the bound.
            # Folding a single entry would leave the length unchanged and let
            # memory grow forever. The clamp keeps at least one real message
            # beside the summary, because summarize_oldest refuses to fold
            # the entire list.
            self.summarize_oldest(min(overflow + 1, len(self.messages) - 1))
        return self

    def summarize_oldest(self, count: int = 1):
        """
        Replace the ``count`` oldest messages with one summary message.

        Summarization is a placeholder: it keeps the first 100 characters of
        the newline-joined message contents. Messages whose content is
        missing or None are skipped. Non-string content is converted with
        ``str``.

        Args:
            count: Number of oldest messages to fold. Values <= 0, or values
                >= the current number of messages, leave memory unchanged.

        Returns:
            self for method chaining
        """
        if count <= 0 or count >= len(self.messages):
            return self

        oldest = self.messages[:count]

        # Simple concatenation for now - would use an LLM in a real implementation.
        # Content can legitimately be None (e.g. tool-call messages); joining
        # None or non-str values would raise TypeError.
        contents = [str(m["content"]) for m in oldest if m.get("content") is not None]
        summary_content = "\n".join(contents)

        # A previously produced summary already stands for several original
        # messages; count those so original_count reflects real messages.
        original_count = sum(
            m.get("original_count", 1) if m.get("type") == "summary" else 1
            for m in oldest
        )

        summary = {
            "role": "system",
            "content": f"Summary of previous messages: {summary_content[:100]}...",
            "timestamp": datetime.now().isoformat(),
            "type": "summary",
            "original_count": original_count,
        }

        self.summaries.append(summary)
        # The summary takes the place of the folded messages at the front.
        self.messages = [summary] + self.messages[count:]
        return self

    def to_dict(self):
        """
        Convert memory to a dictionary for serialization.

        Returns:
            Dictionary representation of memory, including max_messages and
            summaries
        """
        data = super().to_dict()
        data.update({
            "max_messages": self.max_messages,
            "summaries": self.summaries,
        })
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]):
        """
        Create a short-term memory instance from a dictionary.

        Args:
            data: Dictionary representation of memory

        Returns:
            ShortTermMemory instance
        """
        instance = cls(
            name=data.get("name"),
            max_messages=data.get("max_messages", 10),
        )
        instance.messages = data.get("messages", [])
        instance.summaries = data.get("summaries", [])
        return instance
|
200
|
+
|
201
|
+
|
202
|
+
class LongTermMemory(BaseMemory):
    """Long-term persistent memory with simple keyword retrieval.

    ``keywords`` maps each indexed word to the positions in ``messages`` of
    the messages containing it, oldest first. Indexed words are lower-cased,
    punctuation-trimmed, and at least 3 characters long.

    The memory can be written to and read back from a JSON file with
    ``persist`` / ``load``.
    """

    # Characters trimmed from both ends of each token before it is indexed or
    # searched, so "memory," and "memory" share one index entry.
    _STRIP_CHARS = '.,!?():;-"\''

    def __init__(self, name: Optional[str] = None):
        """
        Initialize long-term memory.

        Args:
            name: Optional name for the memory instance
        """
        super().__init__(name)
        self.summaries: List[Dict[str, Any]] = []
        self.keywords: Dict[str, List[int]] = {}
        # Reserved for future embedding-based search; nothing in this class
        # populates or serializes it yet.
        self.embeddings: Dict[Any, Any] = {}
        # Used to build a unique filename in get_standard_storage_path.
        self.uuid = str(uuid.uuid4())

    def add(self, message: Dict[str, Any]):
        """
        Add message and update indices.

        Args:
            message: Message dictionary to add to memory

        Returns:
            self for method chaining
        """
        super().add(message)
        self._extract_keywords(message)
        # Embeddings would be created here in a real implementation
        return self

    def clear(self):
        """
        Clear all messages together with the indices derived from them.

        Returns:
            self for method chaining
        """
        super().clear()
        # Keyword positions point into ``messages``. Keeping them after a
        # clear would make later searches return whichever new message
        # happens to land at a stale position.
        self.keywords = {}
        self.embeddings = {}
        return self

    def _extract_keywords(self, message: Dict[str, Any]):
        """
        Index the words of a just-appended message for later retrieval.

        Must be called right after the message was appended to ``messages``:
        the stored position is the index of the last message.

        Args:
            message: Message to extract keywords from; messages whose content
                is missing or not a string are not indexed
        """
        content = message.get("content")
        if not isinstance(content, str):
            return

        # Simple keyword extraction (would use NLP in real implementation)
        position = len(self.messages) - 1
        seen = set()
        for token in content.lower().split():
            word = token.strip(self._STRIP_CHARS)
            # Skip short words, and words already indexed for this message so
            # a repeated word does not make searches return duplicates.
            if len(word) < 3 or word in seen:
                continue
            seen.add(word)
            self.keywords.setdefault(word, []).append(position)

    def search_by_keyword(self, keyword: str, limit: Optional[int] = 5):
        """
        Search conversations by keyword.

        The keyword is normalized the same way indexed words are
        (lower-cased, punctuation trimmed).

        Args:
            keyword: Keyword to search for
            limit: Maximum number of results to return, keeping the most
                recent matches; None or a value <= 0 returns every match

        Returns:
            List of matching messages, oldest first
        """
        normalized = keyword.lower().strip(self._STRIP_CHARS)
        positions = self.keywords.get(normalized)
        if not positions:
            return []

        if limit is not None and limit > 0:
            positions = positions[-limit:]

        # Ignore positions beyond the current message list (e.g. from
        # inconsistent persisted data).
        return [self.messages[i] for i in positions if i < len(self.messages)]

    def search_by_semantic(self, query: str, limit: int = 5):
        """
        Search conversations by semantic similarity.

        This is a placeholder. In a real implementation, this would create
        an embedding for the query and find similar messages. For now it
        delegates to ``search_by_keyword`` with the whole query.

        Args:
            query: Search query
            limit: Maximum number of results to return

        Returns:
            List of semantically similar messages
        """
        # Placeholder - would use embeddings in real implementation
        return self.search_by_keyword(query, limit)

    def get_standard_storage_path(self, agent_name: Optional[str] = None):
        """
        Get the standard path for storing memory data.

        The path is ``~/.trmx/agents/<agent>/<memory name>/<uuid>.json``.
        ``"default_agent"`` / ``"default_memory"`` are used when the agent
        name or memory name is empty.

        Args:
            agent_name: Name of the agent

        Returns:
            Path string for storing memory data
        """
        home_dir = str(Path.home())
        trmx_dir = os.path.join(home_dir, ".trmx", "agents")

        agent_part = agent_name or "default_agent"
        memory_part = self.name or "default_memory"

        # Use UUID to create a unique filename
        filename = f"{self.uuid}.json"

        storage_path = os.path.join(trmx_dir, agent_part, memory_part)
        return os.path.join(storage_path, filename)

    def persist(self, storage_path: Optional[str] = None, agent_name: Optional[str] = None):
        """
        Save long-term memory to disk as JSON.

        Args:
            storage_path: Path to save memory data (if None, use standard path)
            agent_name: Name of the agent for standard path

        Returns:
            self for method chaining
        """
        if storage_path is None:
            storage_path = self.get_standard_storage_path(agent_name)

        os.makedirs(os.path.dirname(storage_path), exist_ok=True)

        with open(storage_path, "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=2)

        return self

    @staticmethod
    def _latest_snapshot(directory: str) -> Optional[str]:
        """Return the most recently modified ``*.json`` file in *directory*, or None."""
        folder = Path(directory)
        if not folder.is_dir():
            return None
        candidates = [p for p in folder.glob("*.json") if p.is_file()]
        if not candidates:
            return None
        return str(max(candidates, key=lambda p: p.stat().st_mtime))

    def load(self, storage_path: Optional[str] = None, agent_name: Optional[str] = None):
        """
        Load long-term memory from disk.

        With no explicit path, the standard path for this instance's uuid is
        tried first. If that file does not exist, the most recently modified
        snapshot in the same standard directory is loaded instead. Nothing
        happens when no file is found.

        Args:
            storage_path: Path to load memory data from (if None, use standard path)
            agent_name: Name of the agent for standard path

        Returns:
            self for method chaining
        """
        if storage_path is None:
            storage_path = self.get_standard_storage_path(agent_name)
            if not os.path.exists(storage_path):
                # The standard filename embeds this instance's uuid, which is
                # generated fresh for every new instance. A new process could
                # therefore never find the file an earlier instance persisted,
                # so fall back to the newest snapshot in that directory.
                storage_path = (
                    self._latest_snapshot(os.path.dirname(storage_path)) or storage_path
                )

        if not os.path.exists(storage_path):
            return self

        with open(storage_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        self.name = data.get("name", self.name)
        self.messages = data.get("messages", [])
        self.summaries = data.get("summaries", [])
        self.keywords = data.get("keywords", {})
        if "uuid" in data:
            # Adopt the loaded uuid so a later persist() rewrites the same file.
            self.uuid = data["uuid"]

        return self

    def to_dict(self):
        """
        Convert memory to a dictionary for serialization.

        Returns:
            Dictionary representation of memory, including summaries,
            keywords and uuid
        """
        data = super().to_dict()
        data.update({
            "summaries": self.summaries,
            "keywords": self.keywords,
            "uuid": self.uuid,
        })
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]):
        """
        Create a long-term memory instance from a dictionary.

        Args:
            data: Dictionary representation of memory

        Returns:
            LongTermMemory instance
        """
        instance = cls(name=data.get("name"))
        instance.messages = data.get("messages", [])
        instance.summaries = data.get("summaries", [])
        instance.keywords = data.get("keywords", {})
        instance.uuid = data.get("uuid", str(uuid.uuid4()))
        return instance
|
404
|
+
|
405
|
+
|
406
|
+
class SharedMemory(BaseMemory):
    """Memory meant to be referenced by several agents at the same time.

    Storage behaves exactly like the base memory; the only difference is
    that serialized snapshots carry an extra ``"shared": True`` flag.
    """

    def __init__(self, name: Optional[str] = None):
        """
        Create an empty shared memory.

        Args:
            name: Optional label for this memory
        """
        super().__init__(name=name)

    def to_dict(self):
        """
        Build a JSON-friendly snapshot of this memory.

        Returns:
            The base snapshot (name, type, messages) plus ``shared=True``
        """
        serialized = super().to_dict()
        serialized.update(shared=True)
        return serialized
|
428
|
+
|
429
|
+
|
430
|
+
class AgentMemoryManager:
    """Manages multiple memory instances for an agent.

    Owns one long-term memory named ``"primary_ltm"``, any number of named
    short-term memories, and references to shared memories registered
    through ``add_shared_memory``.
    """

    def __init__(self):
        """Initialize memory manager."""
        self.long_term_memory = LongTermMemory("primary_ltm")
        self.short_term_memories: Dict[str, ShortTermMemory] = {}
        self.shared_memories: Dict[str, SharedMemory] = {}

    def create_short_term_memory(
        self, name: str = "default", max_messages: int = 10
    ) -> ShortTermMemory:
        """
        Create a new short-term memory instance.

        An existing short-term memory with the same name is replaced.

        Args:
            name: Name for the memory instance
            max_messages: Maximum number of messages before summarization

        Returns:
            The created ShortTermMemory instance
        """
        self.short_term_memories[name] = ShortTermMemory(name, max_messages)
        return self.short_term_memories[name]

    def get_short_term_memory(self, name: str = "default") -> ShortTermMemory:
        """
        Get or create a short-term memory by name.

        Args:
            name: Name of the short-term memory

        Returns:
            The requested ShortTermMemory instance (created with default
            settings if it did not exist)
        """
        if name not in self.short_term_memories:
            return self.create_short_term_memory(name)
        return self.short_term_memories[name]

    def reset_short_term_memory(self, name: str = "default"):
        """
        Reset a specific short-term memory.

        Unknown names are ignored.

        Args:
            name: Name of the short-term memory to reset

        Returns:
            self for method chaining
        """
        if name in self.short_term_memories:
            self.short_term_memories[name].clear()
        return self

    def add_shared_memory(self, shared_memory: SharedMemory):
        """
        Add a reference to a shared memory, keyed by its name.

        Args:
            shared_memory: SharedMemory instance to add

        Returns:
            self for method chaining
        """
        self.shared_memories[shared_memory.name] = shared_memory
        return self

    def add_to_all(self, message: Dict[str, Any]):
        """
        Add message to the long-term memory and every short-term memory.

        Shared memories are not written to; use ``add_to_memory`` for those.
        The same dict object is stored in each memory.

        Args:
            message: Message to add

        Returns:
            self for method chaining
        """
        self.long_term_memory.add(message)
        for stm in self.short_term_memories.values():
            stm.add(message)
        return self

    def add_to_memory(self, memory_name: str, message: Dict[str, Any]):
        """
        Add message to a specific memory.

        ``"long_term"`` selects the long-term memory. Otherwise short-term
        memories are checked first, then shared memories. An unknown name is
        silently ignored.

        Args:
            memory_name: Name of the memory to add to
            message: Message to add

        Returns:
            self for method chaining
        """
        if memory_name == "long_term":
            self.long_term_memory.add(message)
            return self

        if memory_name in self.short_term_memories:
            self.short_term_memories[memory_name].add(message)
            return self

        if memory_name in self.shared_memories:
            self.shared_memories[memory_name].add(message)

        return self

    def get_context(
        self, stm_name: str = "default", include_shared: bool = True
    ) -> List[Dict[str, Any]]:
        """
        Get context from memories for agent processing.

        The requested short-term memory is created if it does not exist.

        Args:
            stm_name: Name of the short-term memory to use
            include_shared: Whether to append messages from shared memories

        Returns:
            Short-term messages followed by shared-memory messages
        """
        context = []

        stm = self.get_short_term_memory(stm_name)
        context.extend(stm.get_messages())

        if include_shared:
            for shared_mem in self.shared_memories.values():
                context.extend(shared_mem.get_messages())

        # Would add relevant long-term memory here based on query/context

        return context

    @staticmethod
    def _short_term_dir(storage_dir: Optional[str], agent_name: Optional[str]) -> str:
        """Return the directory holding one JSON file per short-term memory."""
        if storage_dir:
            return os.path.join(storage_dir, "short_term")
        return os.path.join(str(Path.home()), ".trmx", "agents", agent_name, "short_term")

    def _write_short_term(self, stm_dir: str) -> None:
        """Write every short-term memory to ``<stm_dir>/<name>.json``."""
        os.makedirs(stm_dir, exist_ok=True)
        for name, memory in self.short_term_memories.items():
            memory_path = os.path.join(stm_dir, f"{name}.json")
            with open(memory_path, "w", encoding="utf-8") as f:
                json.dump(memory.to_dict(), f, indent=2)

    def _read_short_term(self, stm_dir: str) -> None:
        """Load every ``*.json`` file in *stm_dir* as a short-term memory."""
        if not os.path.isdir(stm_dir):
            return
        # sorted() makes the outcome deterministic if two files share a name.
        for filename in sorted(os.listdir(stm_dir)):
            if not filename.endswith(".json"):
                continue

            memory_path = os.path.join(stm_dir, filename)
            with open(memory_path, "r", encoding="utf-8") as f:
                data = json.load(f)

            name = data.get("name", filename[: -len(".json")])
            self.short_term_memories[name] = ShortTermMemory.from_dict(data)

    def persist(self, storage_dir: Optional[str] = None, agent_name: Optional[str] = None):
        """
        Save all memories to disk.

        With ``storage_dir`` the layout is:
        - ``<storage_dir>/long_term.json``
        - ``<storage_dir>/short_term/<name>.json``

        Otherwise the long-term memory uses its standard path, and short-term
        memories go to ``~/.trmx/agents/<agent_name>/short_term/``.

        Args:
            storage_dir: Directory to save memory data, if None use standard path
            agent_name: Name of the agent for standard paths

        Returns:
            self for method chaining

        Raises:
            ValueError: if neither storage_dir nor agent_name is given
        """
        if storage_dir is None and agent_name is None:
            raise ValueError("Either storage_dir or agent_name must be provided")

        if storage_dir:
            os.makedirs(storage_dir, exist_ok=True)
            self.long_term_memory.persist(os.path.join(storage_dir, "long_term.json"))
        else:
            self.long_term_memory.persist(agent_name=agent_name)

        self._write_short_term(self._short_term_dir(storage_dir, agent_name))
        return self

    def load(self, storage_dir: Optional[str] = None, agent_name: Optional[str] = None):
        """
        Load all memories from disk, using the layout described in ``persist``.

        Loaded short-term memories replace existing ones with the same name.
        Missing files or directories are skipped.

        Args:
            storage_dir: Directory to load memory data from
            agent_name: Name of the agent for standard paths

        Returns:
            self for method chaining

        Raises:
            ValueError: if neither storage_dir nor agent_name is given
        """
        if storage_dir is None and agent_name is None:
            raise ValueError("Either storage_dir or agent_name must be provided")

        if storage_dir:
            ltm_path = os.path.join(storage_dir, "long_term.json")
            if os.path.exists(ltm_path):
                self.long_term_memory.load(ltm_path)
        else:
            self.long_term_memory.load(agent_name=agent_name)

        self._read_short_term(self._short_term_dir(storage_dir, agent_name))
        return self
|