airtrain 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. airtrain/__init__.py +148 -2
  2. airtrain/__main__.py +4 -0
  3. airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
  4. airtrain/agents/__init__.py +45 -0
  5. airtrain/agents/example_agent.py +348 -0
  6. airtrain/agents/groq_agent.py +289 -0
  7. airtrain/agents/memory.py +663 -0
  8. airtrain/agents/registry.py +465 -0
  9. airtrain/builder/__init__.py +3 -0
  10. airtrain/builder/agent_builder.py +122 -0
  11. airtrain/cli/__init__.py +0 -0
  12. airtrain/cli/builder.py +23 -0
  13. airtrain/cli/main.py +120 -0
  14. airtrain/contrib/__init__.py +29 -0
  15. airtrain/contrib/travel/__init__.py +35 -0
  16. airtrain/contrib/travel/agents.py +243 -0
  17. airtrain/contrib/travel/models.py +59 -0
  18. airtrain/core/__init__.py +7 -0
  19. airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
  20. airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
  21. airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
  22. airtrain/core/credentials.py +171 -0
  23. airtrain/core/schemas.py +237 -0
  24. airtrain/core/skills.py +269 -0
  25. airtrain/integrations/__init__.py +74 -0
  26. airtrain/integrations/anthropic/__init__.py +33 -0
  27. airtrain/integrations/anthropic/credentials.py +32 -0
  28. airtrain/integrations/anthropic/list_models.py +110 -0
  29. airtrain/integrations/anthropic/models_config.py +100 -0
  30. airtrain/integrations/anthropic/skills.py +155 -0
  31. airtrain/integrations/aws/__init__.py +6 -0
  32. airtrain/integrations/aws/credentials.py +36 -0
  33. airtrain/integrations/aws/skills.py +98 -0
  34. airtrain/integrations/cerebras/__init__.py +6 -0
  35. airtrain/integrations/cerebras/credentials.py +19 -0
  36. airtrain/integrations/cerebras/skills.py +127 -0
  37. airtrain/integrations/combined/__init__.py +21 -0
  38. airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
  39. airtrain/integrations/combined/list_models_factory.py +210 -0
  40. airtrain/integrations/fireworks/__init__.py +21 -0
  41. airtrain/integrations/fireworks/completion_skills.py +147 -0
  42. airtrain/integrations/fireworks/conversation_manager.py +109 -0
  43. airtrain/integrations/fireworks/credentials.py +26 -0
  44. airtrain/integrations/fireworks/list_models.py +128 -0
  45. airtrain/integrations/fireworks/models.py +139 -0
  46. airtrain/integrations/fireworks/requests_skills.py +207 -0
  47. airtrain/integrations/fireworks/skills.py +181 -0
  48. airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
  49. airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
  50. airtrain/integrations/fireworks/structured_skills.py +102 -0
  51. airtrain/integrations/google/__init__.py +7 -0
  52. airtrain/integrations/google/credentials.py +58 -0
  53. airtrain/integrations/google/skills.py +122 -0
  54. airtrain/integrations/groq/__init__.py +23 -0
  55. airtrain/integrations/groq/credentials.py +24 -0
  56. airtrain/integrations/groq/models_config.py +162 -0
  57. airtrain/integrations/groq/skills.py +201 -0
  58. airtrain/integrations/ollama/__init__.py +6 -0
  59. airtrain/integrations/ollama/credentials.py +26 -0
  60. airtrain/integrations/ollama/skills.py +41 -0
  61. airtrain/integrations/openai/__init__.py +37 -0
  62. airtrain/integrations/openai/chinese_assistant.py +42 -0
  63. airtrain/integrations/openai/credentials.py +39 -0
  64. airtrain/integrations/openai/list_models.py +112 -0
  65. airtrain/integrations/openai/models_config.py +224 -0
  66. airtrain/integrations/openai/skills.py +342 -0
  67. airtrain/integrations/perplexity/__init__.py +49 -0
  68. airtrain/integrations/perplexity/credentials.py +43 -0
  69. airtrain/integrations/perplexity/list_models.py +112 -0
  70. airtrain/integrations/perplexity/models_config.py +128 -0
  71. airtrain/integrations/perplexity/skills.py +279 -0
  72. airtrain/integrations/sambanova/__init__.py +6 -0
  73. airtrain/integrations/sambanova/credentials.py +20 -0
  74. airtrain/integrations/sambanova/skills.py +129 -0
  75. airtrain/integrations/search/__init__.py +21 -0
  76. airtrain/integrations/search/exa/__init__.py +23 -0
  77. airtrain/integrations/search/exa/credentials.py +30 -0
  78. airtrain/integrations/search/exa/schemas.py +114 -0
  79. airtrain/integrations/search/exa/skills.py +115 -0
  80. airtrain/integrations/together/__init__.py +33 -0
  81. airtrain/integrations/together/audio_models_config.py +34 -0
  82. airtrain/integrations/together/credentials.py +22 -0
  83. airtrain/integrations/together/embedding_models_config.py +92 -0
  84. airtrain/integrations/together/image_models_config.py +69 -0
  85. airtrain/integrations/together/image_skill.py +143 -0
  86. airtrain/integrations/together/list_models.py +76 -0
  87. airtrain/integrations/together/models.py +95 -0
  88. airtrain/integrations/together/models_config.py +399 -0
  89. airtrain/integrations/together/rerank_models_config.py +43 -0
  90. airtrain/integrations/together/rerank_skill.py +49 -0
  91. airtrain/integrations/together/schemas.py +33 -0
  92. airtrain/integrations/together/skills.py +305 -0
  93. airtrain/integrations/together/vision_models_config.py +49 -0
  94. airtrain/telemetry/__init__.py +38 -0
  95. airtrain/telemetry/service.py +167 -0
  96. airtrain/telemetry/views.py +237 -0
  97. airtrain/tools/__init__.py +45 -0
  98. airtrain/tools/command.py +398 -0
  99. airtrain/tools/filesystem.py +166 -0
  100. airtrain/tools/network.py +111 -0
  101. airtrain/tools/registry.py +320 -0
  102. airtrain/tools/search.py +450 -0
  103. airtrain/tools/testing.py +135 -0
  104. airtrain-0.1.4.dist-info/METADATA +222 -0
  105. airtrain-0.1.4.dist-info/RECORD +108 -0
  106. {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
  107. airtrain-0.1.4.dist-info/entry_points.txt +2 -0
  108. airtrain-0.1.2.dist-info/METADATA +0 -106
  109. airtrain-0.1.2.dist-info/RECORD +0 -5
  110. {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,663 @@
1
+ """
2
+ Memory components for AirTrain agents.
3
+
4
+ This module provides memory systems for agents, including short-term,
5
+ long-term, and shared memory implementations.
6
+ """
7
+
8
+ from typing import Dict, List, Any, Optional
9
+ import json
10
+ import os
11
+ import uuid
12
+ from datetime import datetime
13
+ from pathlib import Path
14
+
15
+
16
class BaseMemory:
    """Base class for all memory types.

    A memory is a named, ordered list of message dictionaries (oldest
    first). Subclasses add summarization, indexing or sharing on top of
    this container.
    """

    def __init__(self, name: Optional[str] = None):
        """
        Initialize a memory instance.

        Args:
            name: Optional name for the memory instance. When omitted (or
                empty) the class name is used instead.
        """
        self.name = name or self.__class__.__name__
        # Messages in insertion order, oldest first.
        self.messages: List[Dict[str, Any]] = []

    def add(self, message: Dict[str, Any]):
        """
        Add a message to memory.

        The dictionary is stored by reference. If it has no ``"timestamp"``
        key, one is written into it (local time, ISO 8601), so the caller's
        dictionary is modified in that case.

        Args:
            message: Message dictionary to add to memory

        Returns:
            self for method chaining
        """
        if "timestamp" not in message:
            message["timestamp"] = datetime.now().isoformat()

        self.messages.append(message)
        return self

    def clear(self):
        """
        Clear all messages from memory.

        Returns:
            self for method chaining
        """
        self.messages = []
        return self

    def get_messages(self, limit: Optional[int] = None):
        """
        Get messages with optional limit.

        A new list is always returned, so callers may append to or reorder
        the result without changing what the memory holds. The message
        dictionaries themselves are not copied.

        Args:
            limit: Maximum number of messages to return (from most recent).
                ``None`` or a non-positive value returns every message.

        Returns:
            List of message dictionaries
        """
        if limit is None or limit <= 0:
            # Copy so both branches hand out a detached list; the slice
            # below already does.
            return list(self.messages)
        return self.messages[-limit:]

    def to_dict(self):
        """
        Convert memory to a dictionary for serialization.

        Returns:
            Dictionary with the keys ``"name"``, ``"type"`` (the class name)
            and ``"messages"`` (a snapshot list of the stored messages)
        """
        return {
            "name": self.name,
            "type": self.__class__.__name__,
            "messages": list(self.messages),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]):
        """
        Create a memory instance from a dictionary.

        Args:
            data: Dictionary representation of memory, as produced by
                ``to_dict``. A missing or null ``"messages"`` entry yields
                an empty memory.

        Returns:
            Memory instance
        """
        instance = cls(name=data.get("name"))
        # Copy the list so later add() calls do not grow the caller's data.
        instance.messages = list(data.get("messages") or [])
        return instance
97
+
98
+
99
class ShortTermMemory(BaseMemory):
    """Short-term memory with automatic summarization capability.

    The message list is kept bounded: once it grows past ``max_messages``
    the oldest entries are collapsed into a single summary message that
    takes their place at the front of the list.
    """

    def __init__(self, name: Optional[str] = None, max_messages: int = 10):
        """
        Initialize short-term memory.

        Args:
            name: Optional name for the memory instance
            max_messages: Maximum number of messages to keep before summarizing
        """
        super().__init__(name)
        self.max_messages = max_messages
        # Every summary ever produced, oldest first.
        self.summaries: List[Dict[str, Any]] = []

    def add(self, message: Dict[str, Any]):
        """
        Add message and manage memory size.

        If the number of messages exceeds max_messages, the oldest
        messages are summarized and replaced by the summary so that the
        list returns to ``max_messages`` entries. At least one original
        message is always kept, so with ``max_messages`` below 2 the list
        settles at two entries (summary plus newest message).

        Args:
            message: Message dictionary to add to memory

        Returns:
            self for method chaining
        """
        super().add(message)

        overflow = len(self.messages) - self.max_messages
        if overflow > 0:
            # The summary itself occupies one slot, so one more message
            # than the overflow has to be folded into it. Summarizing a
            # single message would only swap it for its summary and the
            # list would never shrink.
            count = min(overflow + 1, len(self.messages) - 1)
            self.summarize_oldest(count)
        return self

    def summarize_oldest(self, count: int = 1):
        """
        Summarize the oldest messages in memory.

        The ``count`` oldest messages are removed and replaced by one
        summary message inserted at the front. The summary is also appended
        to ``self.summaries``. Nothing happens when ``count`` is not
        positive or would consume every stored message.

        Args:
            count: Number of oldest messages to summarize

        Returns:
            self for method chaining
        """
        if count <= 0 or count >= len(self.messages):
            return self

        oldest = self.messages[:count]

        # Simple concatenation for now - an LLM would be used in a real
        # implementation. str() keeps non-text content (e.g. structured
        # payloads) from breaking the join; null content is skipped.
        contents = [
            str(m["content"]) for m in oldest if m.get("content") is not None
        ]
        summary_content = "\n".join(contents)
        summary = {
            "role": "system",
            "content": f"Summary of previous messages: {summary_content[:100]}...",
            "timestamp": datetime.now().isoformat(),
            "type": "summary",
            "original_count": count,
        }

        self.summaries.append(summary)
        # Drop the summarized messages and put the summary in their place.
        self.messages = [summary] + self.messages[count:]

        return self

    def to_dict(self):
        """
        Convert memory to a dictionary for serialization.

        Returns:
            The base representation extended with ``"max_messages"`` and
            ``"summaries"``
        """
        data = super().to_dict()
        data.update({
            "max_messages": self.max_messages,
            "summaries": self.summaries,
        })
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]):
        """
        Create a short-term memory instance from a dictionary.

        Args:
            data: Dictionary representation of memory. ``"max_messages"``
                defaults to 10 when absent; missing or null ``"messages"``
                and ``"summaries"`` become empty lists.

        Returns:
            ShortTermMemory instance
        """
        instance = cls(
            name=data.get("name"),
            max_messages=data.get("max_messages", 10),
        )
        # Copy the lists so the instance does not share them with ``data``.
        instance.messages = list(data.get("messages") or [])
        instance.summaries = list(data.get("summaries") or [])
        return instance
200
+
201
+
202
class LongTermMemory(BaseMemory):
    """Long-term persistent memory with keyword-based retrieval.

    Every added message is indexed by the words of its ``"content"`` so it
    can be found again with ``search_by_keyword``. The memory can be saved
    to and restored from a JSON file.
    """

    # Characters stripped from both ends of a word before it is indexed.
    _PUNCTUATION = '.,!?():;-"\''
    # Words shorter than this are not indexed.
    _MIN_KEYWORD_LENGTH = 3

    def __init__(self, name: Optional[str] = None):
        """
        Initialize long-term memory.

        Args:
            name: Optional name for the memory instance
        """
        super().__init__(name)
        self.summaries: List[Dict[str, Any]] = []
        # word -> indices into self.messages of the messages containing it
        self.keywords: Dict[str, List[int]] = {}
        # Reserved for semantic search; never filled or serialized here.
        self.embeddings: Dict[str, Any] = {}
        # Identifies this memory on disk (used as the standard file name).
        self.uuid = str(uuid.uuid4())

    def add(self, message: Dict[str, Any]):
        """
        Add message and update indices.

        Args:
            message: Message dictionary to add to memory

        Returns:
            self for method chaining
        """
        super().add(message)
        self._extract_keywords(message)
        # Embeddings would be created here in a real implementation
        return self

    def clear(self):
        """
        Clear all messages together with the indices that point at them.

        Returns:
            self for method chaining
        """
        super().clear()
        # The keyword index stores positions in self.messages; keeping it
        # after the list is emptied would make searches return unrelated
        # messages added later.
        self.keywords = {}
        self.embeddings = {}
        return self

    def _tokenize(self, text: str) -> List[str]:
        """Split text into unique, lower-cased, punctuation-stripped keywords."""
        stripped = (
            raw.strip(self._PUNCTUATION) for raw in text.lower().split()
        )
        # dict.fromkeys de-duplicates while keeping first-seen order.
        return list(dict.fromkeys(
            word for word in stripped if len(word) >= self._MIN_KEYWORD_LENGTH
        ))

    def _messages_at(self, indices, limit: Optional[int]) -> List[Dict[str, Any]]:
        """Return the messages at the given indices, most recent ``limit`` only."""
        total = len(self.messages)
        # De-duplicate and drop positions that no longer exist (for example
        # in an index loaded from an older or edited file).
        valid = sorted({i for i in indices if 0 <= i < total})
        if limit is not None and limit > 0:
            valid = valid[-limit:]
        return [self.messages[i] for i in valid]

    def _extract_keywords(self, message: Dict[str, Any]):
        """
        Extract keywords from message for later retrieval.

        Must be called right after the message was appended, because the
        index of the last stored message is recorded. Messages whose
        ``"content"`` is missing or not a string are not indexed.

        Args:
            message: Message to extract keywords from
        """
        content = message.get("content")
        if not isinstance(content, str):
            return

        # Simple keyword extraction (would use NLP in real implementation)
        index = len(self.messages) - 1
        for word in self._tokenize(content):
            # Each word is recorded once per message, even if it is
            # repeated in the text, so searches never return duplicates.
            self.keywords.setdefault(word, []).append(index)

    def search_by_keyword(self, keyword: str, limit: int = 5):
        """
        Search conversations by keyword.

        Args:
            keyword: Keyword to search for (matched case-insensitively
                against the indexed words)
            limit: Maximum number of results to return, taken from the most
                recent matches. A non-positive value returns every match.

        Returns:
            List of matching messages, oldest first
        """
        hits = self.keywords.get(keyword.lower())
        if not hits:
            return []
        return self._messages_at(hits, limit)

    def search_by_semantic(self, query: str, limit: int = 5):
        """
        Search conversations by semantic similarity.

        This is a placeholder. In a real implementation, this would create
        an embedding for the query and find similar messages. For now the
        query is matched as a whole keyword and, in addition, word by word,
        so multi-word queries return messages containing any of the words.

        Args:
            query: Search query
            limit: Maximum number of results to return. A non-positive
                value returns every match.

        Returns:
            List of matching messages, oldest first
        """
        indices = list(self.keywords.get(query.lower(), ()))
        for word in self._tokenize(query):
            indices.extend(self.keywords.get(word, ()))
        return self._messages_at(indices, limit)

    def get_standard_storage_path(self, agent_name: Optional[str] = None):
        """
        Get the standard path for storing memory data.

        The path has the form
        ``~/.trmx/agents/<agent>/<memory name>/<uuid>.json``.

        Args:
            agent_name: Name of the agent; ``"default_agent"`` when omitted

        Returns:
            The file path as a string
        """
        agent_part = agent_name or "default_agent"
        memory_part = self.name or "default_memory"

        return os.path.join(
            str(Path.home()),
            ".trmx",
            "agents",
            agent_part,
            memory_part,
            f"{self.uuid}.json",
        )

    def persist(self, storage_path: Optional[str] = None, agent_name: Optional[str] = None):
        """
        Save long-term memory to disk.

        Args:
            storage_path: Path to save memory data (if None, use standard path)
            agent_name: Name of the agent for standard path

        Returns:
            self for method chaining
        """
        if storage_path is None:
            storage_path = self.get_standard_storage_path(agent_name)

        # A bare file name has no directory part; os.makedirs("") would raise.
        parent = os.path.dirname(storage_path)
        if parent:
            os.makedirs(parent, exist_ok=True)

        with open(storage_path, "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=2)

        return self

    def load(self, storage_path: Optional[str] = None, agent_name: Optional[str] = None):
        """
        Load long-term memory from disk.

        When no explicit path is given and the file for this instance's
        UUID does not exist, the most recently modified ``.json`` file in
        the standard directory is loaded instead. The UUID is random per
        instance, so a file persisted by an earlier instance would
        otherwise never be found. The loaded UUID replaces the current one,
        so later ``persist`` calls keep writing to the same file.

        Nothing is changed when no file can be found.

        Args:
            storage_path: Path to load memory data from (if None, use standard path)
            agent_name: Name of the agent for standard path

        Returns:
            self for method chaining
        """
        if storage_path is None:
            storage_path = self.get_standard_storage_path(agent_name)
            if not os.path.exists(storage_path):
                storage_path = self._latest_file_in(os.path.dirname(storage_path))

        if storage_path is None or not os.path.exists(storage_path):
            return self

        with open(storage_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        self.name = data.get("name", self.name)
        self.messages = data.get("messages") or []
        self.summaries = data.get("summaries") or []
        self.keywords = data.get("keywords") or {}
        if "uuid" in data:
            self.uuid = data["uuid"]

        return self

    @staticmethod
    def _latest_file_in(directory: str) -> Optional[str]:
        """Return the newest ``.json`` file in ``directory``, or None."""
        if not os.path.isdir(directory):
            return None
        candidates = [
            os.path.join(directory, entry)
            for entry in os.listdir(directory)
            if entry.endswith(".json")
        ]
        if not candidates:
            return None
        return max(candidates, key=os.path.getmtime)

    def to_dict(self):
        """
        Convert memory to a dictionary for serialization.

        Returns:
            The base representation extended with ``"summaries"``,
            ``"keywords"`` and ``"uuid"`` (embeddings are not included)
        """
        data = super().to_dict()
        data.update({
            "summaries": self.summaries,
            "keywords": self.keywords,
            "uuid": self.uuid,
        })
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]):
        """
        Create a long-term memory instance from a dictionary.

        Args:
            data: Dictionary representation of memory. When it carries no
                ``"uuid"``, the instance keeps a freshly generated one.

        Returns:
            LongTermMemory instance
        """
        instance = cls(name=data.get("name"))
        # Copy containers so the instance does not share them with ``data``.
        instance.messages = list(data.get("messages") or [])
        instance.summaries = list(data.get("summaries") or [])
        instance.keywords = {
            word: list(indices)
            for word, indices in (data.get("keywords") or {}).items()
        }
        if data.get("uuid"):
            instance.uuid = data["uuid"]
        return instance
404
+
405
+
406
class SharedMemory(BaseMemory):
    """Memory intended to be referenced by several agents at once."""

    def __init__(self, name: Optional[str] = None):
        """
        Set up a shared memory.

        Args:
            name: Optional name for the memory instance
        """
        super().__init__(name)

    def to_dict(self):
        """
        Serialize the memory, flagging it as shared.

        Returns:
            The base dictionary representation with ``"shared": True`` added
        """
        serialized = super().to_dict()
        return {**serialized, "shared": True}
428
+
429
+
430
class AgentMemoryManager:
    """Manages multiple memory instances for an agent.

    Holds one long-term memory, any number of named short-term memories and
    references to shared memories registered with ``add_shared_memory``.
    """

    def __init__(self):
        """Initialize memory manager."""
        self.long_term_memory = LongTermMemory("primary_ltm")
        # name -> ShortTermMemory
        self.short_term_memories: Dict[str, ShortTermMemory] = {}
        # name -> SharedMemory (held by reference, not owned)
        self.shared_memories: Dict[str, SharedMemory] = {}

    def create_short_term_memory(
        self, name: str = "default", max_messages: int = 10
    ) -> ShortTermMemory:
        """
        Create a new short-term memory instance.

        An existing memory registered under the same name is replaced.

        Args:
            name: Name for the memory instance
            max_messages: Maximum number of messages before summarization

        Returns:
            The created ShortTermMemory instance
        """
        memory = ShortTermMemory(name, max_messages)
        self.short_term_memories[name] = memory
        return memory

    def get_short_term_memory(self, name: str = "default") -> ShortTermMemory:
        """
        Get or create a short-term memory by name.

        Args:
            name: Name of the short-term memory

        Returns:
            The requested ShortTermMemory instance
        """
        if name not in self.short_term_memories:
            return self.create_short_term_memory(name)
        return self.short_term_memories[name]

    def reset_short_term_memory(self, name: str = "default"):
        """
        Reset a specific short-term memory.

        Unknown names are ignored.

        Args:
            name: Name of the short-term memory to reset

        Returns:
            self for method chaining
        """
        if name in self.short_term_memories:
            self.short_term_memories[name].clear()
        return self

    def add_shared_memory(self, shared_memory: SharedMemory):
        """
        Add a reference to a shared memory.

        The memory is registered under its own ``name``.

        Args:
            shared_memory: SharedMemory instance to add

        Returns:
            self for method chaining
        """
        self.shared_memories[shared_memory.name] = shared_memory
        return self

    def add_to_all(self, message: Dict[str, Any]):
        """
        Add message to the long-term memory and every short-term memory.

        Shared memories are not written to by this method. The same
        dictionary object is handed to each memory.

        Args:
            message: Message to add

        Returns:
            self for method chaining
        """
        self.long_term_memory.add(message)
        for stm in self.short_term_memories.values():
            stm.add(message)
        return self

    def add_to_memory(self, memory_name: str, message: Dict[str, Any]):
        """
        Add message to a specific memory.

        The name ``"long_term"`` selects the long-term memory. Any other
        name is looked up among the short-term memories first and the
        shared memories second. An unknown name is silently ignored.

        Args:
            memory_name: Name of the memory to add to
            message: Message to add

        Returns:
            self for method chaining
        """
        if memory_name == "long_term":
            self.long_term_memory.add(message)
        elif memory_name in self.short_term_memories:
            self.short_term_memories[memory_name].add(message)
        elif memory_name in self.shared_memories:
            self.shared_memories[memory_name].add(message)

        return self

    def get_context(
        self, stm_name: str = "default", include_shared: bool = True
    ) -> List[Dict[str, Any]]:
        """
        Get context from memories for agent processing.

        The short-term memory is created if it does not exist yet.

        Args:
            stm_name: Name of the short-term memory to use
            include_shared: Whether to include shared memories

        Returns:
            Messages of the short-term memory followed, if requested, by
            the messages of every shared memory
        """
        context = []

        stm = self.get_short_term_memory(stm_name)
        context.extend(stm.get_messages())

        if include_shared:
            for shared_mem in self.shared_memories.values():
                context.extend(shared_mem.get_messages())

        # Would add relevant long-term memory here based on query/context

        return context

    @staticmethod
    def _require_location(storage_dir: Optional[str], agent_name: Optional[str]):
        """Raise ValueError unless a storage directory or agent name is usable."""
        # An empty storage_dir selects the standard path, which needs an
        # agent name; reject the combination up front instead of failing
        # half-way through with a TypeError from os.path.join.
        if not storage_dir and agent_name is None:
            raise ValueError("Either storage_dir or agent_name must be provided")

    @staticmethod
    def _short_term_dir(storage_dir: Optional[str], agent_name: Optional[str]) -> str:
        """Return the directory that holds the short-term memory files."""
        if storage_dir:
            return os.path.join(storage_dir, "short_term")
        return os.path.join(
            str(Path.home()), ".trmx", "agents", agent_name, "short_term"
        )

    def persist(self, storage_dir: Optional[str] = None, agent_name: Optional[str] = None):
        """
        Save the long-term and short-term memories to disk.

        With ``storage_dir`` the long-term memory goes to
        ``<storage_dir>/long_term.json`` and each short-term memory to
        ``<storage_dir>/short_term/<name>.json``. Without it, the long-term
        memory uses its standard path and the short-term memories go to
        ``~/.trmx/agents/<agent_name>/short_term/<name>.json``. Shared
        memories are not written.

        Args:
            storage_dir: Directory to save memory data, if None use standard path
            agent_name: Name of the agent for standard paths

        Returns:
            self for method chaining

        Raises:
            ValueError: If neither a storage directory nor an agent name
                is provided
        """
        self._require_location(storage_dir, agent_name)

        if storage_dir:
            os.makedirs(storage_dir, exist_ok=True)
            self.long_term_memory.persist(os.path.join(storage_dir, "long_term.json"))
        else:
            self.long_term_memory.persist(agent_name=agent_name)

        stm_dir = self._short_term_dir(storage_dir, agent_name)
        os.makedirs(stm_dir, exist_ok=True)

        for name, memory in self.short_term_memories.items():
            memory_path = os.path.join(stm_dir, f"{name}.json")
            with open(memory_path, "w", encoding="utf-8") as f:
                json.dump(memory.to_dict(), f, indent=2)

        return self

    def load(self, storage_dir: Optional[str] = None, agent_name: Optional[str] = None):
        """
        Load the long-term and short-term memories from disk.

        Uses the same layout as ``persist``. Missing files or directories
        are skipped. Loaded short-term memories are added to, or replace
        entries in, the current set; memories without a file are kept.

        Args:
            storage_dir: Directory to load memory data from
            agent_name: Name of the agent for standard paths

        Returns:
            self for method chaining

        Raises:
            ValueError: If neither a storage directory nor an agent name
                is provided
        """
        self._require_location(storage_dir, agent_name)

        if storage_dir:
            ltm_path = os.path.join(storage_dir, "long_term.json")
            if os.path.exists(ltm_path):
                self.long_term_memory.load(ltm_path)
        else:
            self.long_term_memory.load(agent_name=agent_name)

        stm_dir = self._short_term_dir(storage_dir, agent_name)
        if not os.path.isdir(stm_dir):
            return self

        # Sorted so the result does not depend on directory listing order
        # when two files declare the same memory name.
        for filename in sorted(os.listdir(stm_dir)):
            if not filename.endswith(".json"):
                continue

            memory_path = os.path.join(stm_dir, filename)
            with open(memory_path, "r", encoding="utf-8") as f:
                data = json.load(f)

            # Fall back to the file name (without ".json") and give the
            # memory that same name so its key and its .name agree.
            name = data.get("name", filename[:-5])
            self.short_term_memories[name] = ShortTermMemory.from_dict(
                {**data, "name": name}
            )

        return self