MemoryOS 0.2.2__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff shows the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.

Potentially problematic release: this version of MemoryOS might be problematic.
Files changed (62)
  1. {memoryos-0.2.2.dist-info → memoryos-1.0.0.dist-info}/METADATA +6 -1
  2. {memoryos-0.2.2.dist-info → memoryos-1.0.0.dist-info}/RECORD +61 -55
  3. memos/__init__.py +1 -1
  4. memos/api/config.py +6 -8
  5. memos/api/context/context.py +1 -1
  6. memos/api/context/dependencies.py +11 -0
  7. memos/configs/internet_retriever.py +13 -0
  8. memos/configs/mem_scheduler.py +38 -16
  9. memos/graph_dbs/base.py +30 -3
  10. memos/graph_dbs/nebular.py +442 -194
  11. memos/graph_dbs/neo4j.py +14 -5
  12. memos/log.py +5 -0
  13. memos/mem_os/core.py +19 -9
  14. memos/mem_os/main.py +1 -1
  15. memos/mem_os/product.py +6 -69
  16. memos/mem_os/utils/default_config.py +1 -1
  17. memos/mem_os/utils/format_utils.py +11 -47
  18. memos/mem_os/utils/reference_utils.py +133 -0
  19. memos/mem_scheduler/base_scheduler.py +58 -55
  20. memos/mem_scheduler/{modules → general_modules}/base.py +1 -2
  21. memos/mem_scheduler/{modules → general_modules}/dispatcher.py +54 -15
  22. memos/mem_scheduler/{modules → general_modules}/rabbitmq_service.py +4 -4
  23. memos/mem_scheduler/{modules → general_modules}/redis_service.py +1 -1
  24. memos/mem_scheduler/{modules → general_modules}/retriever.py +19 -5
  25. memos/mem_scheduler/{modules → general_modules}/scheduler_logger.py +10 -4
  26. memos/mem_scheduler/general_scheduler.py +110 -67
  27. memos/mem_scheduler/monitors/__init__.py +0 -0
  28. memos/mem_scheduler/monitors/dispatcher_monitor.py +305 -0
  29. memos/mem_scheduler/{modules/monitor.py → monitors/general_monitor.py} +57 -19
  30. memos/mem_scheduler/mos_for_test_scheduler.py +7 -1
  31. memos/mem_scheduler/schemas/general_schemas.py +3 -2
  32. memos/mem_scheduler/schemas/message_schemas.py +2 -1
  33. memos/mem_scheduler/schemas/monitor_schemas.py +10 -2
  34. memos/mem_scheduler/utils/misc_utils.py +43 -2
  35. memos/memories/activation/item.py +1 -1
  36. memos/memories/activation/kv.py +20 -8
  37. memos/memories/textual/base.py +1 -1
  38. memos/memories/textual/general.py +1 -1
  39. memos/memories/textual/tree_text_memory/organize/{conflict.py → handler.py} +30 -48
  40. memos/memories/textual/tree_text_memory/organize/manager.py +8 -96
  41. memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +2 -0
  42. memos/memories/textual/tree_text_memory/organize/reorganizer.py +102 -140
  43. memos/memories/textual/tree_text_memory/retrieve/bochasearch.py +229 -0
  44. memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py +9 -0
  45. memos/memories/textual/tree_text_memory/retrieve/recall.py +15 -8
  46. memos/memories/textual/tree_text_memory/retrieve/reranker.py +1 -1
  47. memos/memories/textual/tree_text_memory/retrieve/searcher.py +177 -125
  48. memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +7 -2
  49. memos/memories/textual/tree_text_memory/retrieve/utils.py +1 -1
  50. memos/memos_tools/lockfree_dict.py +120 -0
  51. memos/memos_tools/thread_safe_dict.py +288 -0
  52. memos/templates/mem_reader_prompts.py +2 -0
  53. memos/templates/mem_scheduler_prompts.py +23 -10
  54. memos/templates/mos_prompts.py +40 -11
  55. memos/templates/tree_reorganize_prompts.py +24 -17
  56. memos/utils.py +19 -0
  57. memos/memories/textual/tree_text_memory/organize/redundancy.py +0 -193
  58. {memoryos-0.2.2.dist-info → memoryos-1.0.0.dist-info}/LICENSE +0 -0
  59. {memoryos-0.2.2.dist-info → memoryos-1.0.0.dist-info}/WHEEL +0 -0
  60. {memoryos-0.2.2.dist-info → memoryos-1.0.0.dist-info}/entry_points.txt +0 -0
  61. /memos/mem_scheduler/{modules → general_modules}/__init__.py +0 -0
  62. /memos/mem_scheduler/{modules → general_modules}/misc.py +0 -0
memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py:
@@ -73,10 +73,12 @@ class RelationAndReasoningDetector:
             results["sequence_links"].extend(seq)
             """

+            """
             # 4) Aggregate
             agg = self._detect_aggregate_node_for_group(node, nearest, min_group_size=5)
             if agg:
                 results["aggregate_nodes"].append(agg)
+            """

         except Exception as e:
             logger.error(
memos/memories/textual/tree_text_memory/organize/reorganizer.py:
@@ -3,7 +3,7 @@ import threading
 import time
 import traceback

-from collections import Counter, defaultdict
+from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from queue import PriorityQueue
 from typing import Literal
@@ -17,8 +17,7 @@ from memos.graph_dbs.neo4j import Neo4jGraphDB
 from memos.llms.base import BaseLLM
 from memos.log import get_logger
 from memos.memories.textual.item import TreeNodeTextualMemoryMetadata
-from memos.memories.textual.tree_text_memory.organize.conflict import ConflictHandler
-from memos.memories.textual.tree_text_memory.organize.redundancy import RedundancyHandler
+from memos.memories.textual.tree_text_memory.organize.handler import NodeHandler
 from memos.memories.textual.tree_text_memory.organize.relation_reason_detector import (
     RelationAndReasoningDetector,
 )
@@ -63,10 +62,10 @@ class GraphStructureReorganizer:
         self.relation_detector = RelationAndReasoningDetector(
             self.graph_store, self.llm, self.embedder
         )
-        self.conflict = ConflictHandler(graph_store=graph_store, llm=llm, embedder=embedder)
-        self.redundancy = RedundancyHandler(graph_store=graph_store, llm=llm, embedder=embedder)
+        self.resolver = NodeHandler(graph_store=graph_store, llm=llm, embedder=embedder)

         self.is_reorganize = is_reorganize
+        self._reorganize_needed = True
         if self.is_reorganize:
             # ____ 1. For queue message driven thread ___________
             self.thread = threading.Thread(target=self._run_message_consumer_loop)
@@ -125,13 +124,17 @@ class GraphStructureReorganizer:
         """
         import schedule

-        schedule.every(600).seconds.do(self.optimize_structure, scope="LongTermMemory")
-        schedule.every(600).seconds.do(self.optimize_structure, scope="UserMemory")
+        schedule.every(100).seconds.do(self.optimize_structure, scope="LongTermMemory")
+        schedule.every(100).seconds.do(self.optimize_structure, scope="UserMemory")

         logger.info("Structure optimizer schedule started.")
         while not getattr(self, "_stop_scheduler", False):
-            schedule.run_pending()
-            time.sleep(1)
+            if self._reorganize_needed:
+                logger.info("[Reorganizer] Triggering optimize_structure due to new nodes.")
+                self.optimize_structure(scope="LongTermMemory")
+                self.optimize_structure(scope="UserMemory")
+                self._reorganize_needed = False
+            time.sleep(30)

     def stop(self):
         """
@@ -148,45 +151,31 @@ class GraphStructureReorganizer:
         logger.info("Structure optimizer stopped.")

     def handle_message(self, message: QueueMessage):
-        handle_map = {
-            "add": self.handle_add,
-            "remove": self.handle_remove,
-            "merge": self.handle_merge,
-        }
+        handle_map = {"add": self.handle_add, "remove": self.handle_remove}
         handle_map[message.op](message)
         logger.debug(f"message queue size: {self.queue.qsize()}")

     def handle_add(self, message: QueueMessage):
         logger.debug(f"Handling add operation: {str(message)[:500]}")
-        # ———————— 1. check for conflicts ————————
         added_node = message.after_node[0]
-        conflicts = self.conflict.detect(added_node, scope=added_node.metadata.memory_type)
-        if conflicts:
-            for added_node, existing_node in conflicts:
-                self.conflict.resolve(added_node, existing_node)
-                logger.info(f"Resolved conflict between {added_node.id} and {existing_node.id}.")
-
-        # ———————— 2. check for redundancy ————————
-        redundancies = self.redundancy.detect(added_node, scope=added_node.metadata.memory_type)
-        if redundancies:
-            for added_node, existing_node in redundancies:
-                self.redundancy.resolve_two_nodes(added_node, existing_node)
-                logger.info(f"Resolved redundancy between {added_node.id} and {existing_node.id}.")
+        detected_relationships = self.resolver.detect(
+            added_node, scope=added_node.metadata.memory_type
+        )
+        if detected_relationships:
+            for added_node, existing_node, relation in detected_relationships:
+                self.resolver.resolve(added_node, existing_node, relation)
+
+        self._reorganize_needed = True

     def handle_remove(self, message: QueueMessage):
         logger.debug(f"Handling remove operation: {str(message)[:50]}")

-    def handle_merge(self, message: QueueMessage):
-        after_node = message.after_node[0]
-        logger.debug(f"Handling merge operation: <{after_node.memory}>")
-        self.redundancy.resolve_one_node(after_node)
-
     def optimize_structure(
         self,
         scope: str = "LongTermMemory",
         local_tree_threshold: int = 10,
-        min_cluster_size: int = 3,
-        min_group_size: int = 5,
+        min_cluster_size: int = 4,
+        min_group_size: int = 20,
     ):
         """
         Periodically reorganize the graph:
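In this hunk the separate conflict and redundancy passes collapse into a single `NodeHandler` (`self.resolver`): `detect()` yields `(added_node, existing_node, relation)` triples, `resolve()` dispatches on the relation, and `handle_add` finishes by raising the `_reorganize_needed` flag. Note also that `handle_merge` is removed while `handle_map[message.op](message)` still indexes the dict directly, so a leftover "merge" message would raise `KeyError`. A hypothetical defensive variant (not the shipped code) would guard the lookup:

```python
import logging

logger = logging.getLogger(__name__)


def dispatch(message_op: str, handlers: dict) -> None:
    # Hypothetical guarded dispatch; the shipped handle_message indexes the dict
    # directly, so an op it no longer knows (e.g. a legacy "merge") raises KeyError.
    handler = handlers.get(message_op)
    if handler is None:
        logger.warning("Unhandled queue op: %r, dropping message.", message_op)
        return
    handler()


dispatch("merge", {"add": lambda: None, "remove": lambda: None})  # warns instead of raising
```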
@@ -253,7 +242,7 @@ class GraphStructureReorganizer:
             except Exception as e:
                 logger.warning(
                     f"[Reorganize] Cluster processing "
-                    f"failed: {e}, trace: {traceback.format_exc()}"
+                    f"failed: {e}, cluster_nodes: {cluster_nodes}, trace: {traceback.format_exc()}"
                 )
         logger.info("[GraphStructure Reorganize] Structure optimization finished.")
@@ -271,29 +260,23 @@ class GraphStructureReorganizer:
         if len(cluster_nodes) <= min_cluster_size:
             return

-        if len(cluster_nodes) <= local_tree_threshold:
-            # Small cluster ➜ single parent
-            parent_node = self._summarize_cluster(cluster_nodes, scope)
-            self._create_parent_node(parent_node)
-            self._link_cluster_nodes(parent_node, cluster_nodes)
-        else:
-            # Large cluster ➜ local sub-clustering
-            sub_clusters = self._local_subcluster(cluster_nodes)
-            sub_parents = []
-
-            for sub_nodes in sub_clusters:
-                if len(sub_nodes) < min_cluster_size:
-                    continue  # Skip tiny noise
-                sub_parent_node = self._summarize_cluster(sub_nodes, scope)
-                self._create_parent_node(sub_parent_node)
-                self._link_cluster_nodes(sub_parent_node, sub_nodes)
-                sub_parents.append(sub_parent_node)
-
-            if sub_parents:
-                cluster_parent_node = self._summarize_cluster(cluster_nodes, scope)
-                self._create_parent_node(cluster_parent_node)
-                for sub_parent in sub_parents:
-                    self.graph_store.add_edge(cluster_parent_node.id, sub_parent.id, "PARENT")
+        # Large cluster ➜ local sub-clustering
+        sub_clusters = self._local_subcluster(cluster_nodes)
+        sub_parents = []
+
+        for sub_nodes in sub_clusters:
+            if len(sub_nodes) < min_cluster_size:
+                continue  # Skip tiny noise
+            sub_parent_node = self._summarize_cluster(sub_nodes, scope)
+            self._create_parent_node(sub_parent_node)
+            self._link_cluster_nodes(sub_parent_node, sub_nodes)
+            sub_parents.append(sub_parent_node)
+
+        if sub_parents and len(sub_parents) >= min_cluster_size:
+            cluster_parent_node = self._summarize_cluster(cluster_nodes, scope)
+            self._create_parent_node(cluster_parent_node)
+            for sub_parent in sub_parents:
+                self.graph_store.add_edge(cluster_parent_node.id, sub_parent.id, "PARENT")

         logger.info("Adding relations/reasons")
         nodes_to_check = cluster_nodes
@@ -350,7 +333,9 @@ class GraphStructureReorganizer:

         logger.info("[Reorganizer] Cluster relation/reasoning done.")

-    def _local_subcluster(self, cluster_nodes: list[GraphDBNode]) -> list[list[GraphDBNode]]:
+    def _local_subcluster(
+        self, cluster_nodes: list[GraphDBNode], max_length: int = 8000
+    ) -> (list)[list[GraphDBNode]]:
         """
         Use LLM to split a large cluster into semantically coherent sub-clusters.
         """
@@ -364,7 +349,9 @@ class GraphStructureReorganizer:
             scene_lines.append(line)

         joined_scene = "\n".join(scene_lines)
-        prompt = LOCAL_SUBCLUSTER_PROMPT.replace("{joined_scene}", joined_scene)
+        if len(joined_scene) > max_length:
+            logger.warning(f"Sub-cluster too long: {joined_scene}")
+        prompt = LOCAL_SUBCLUSTER_PROMPT.replace("{joined_scene}", joined_scene[:max_length])

         messages = [{"role": "user", "content": prompt}]
         response_text = self.llm.generate(messages)
@@ -389,12 +376,12 @@ class GraphStructureReorganizer:
         install_command="pip install scikit-learn",
         install_link="https://scikit-learn.org/stable/install.html",
     )
-    def _partition(self, nodes, min_cluster_size: int = 3, max_cluster_size: int = 20):
+    def _partition(self, nodes, min_cluster_size: int = 10, max_cluster_size: int = 20):
         """
         Partition nodes by:
-        1) Frequent tags (top N & above threshold)
-        2) Remaining nodes by embedding clustering (MiniBatchKMeans)
-        3) Small clusters merged or assigned to 'Other'
+        - If total nodes <= max_cluster_size -> return all nodes in one cluster.
+        - If total nodes > max_cluster_size -> cluster by embeddings, recursively split.
+        - Only keep clusters with size > min_cluster_size.

         Args:
             nodes: List of GraphDBNode
@@ -405,105 +392,80 @@ class GraphStructureReorganizer:
         """
         from sklearn.cluster import MiniBatchKMeans

-        # 1) Count all tags
-        tag_counter = Counter()
-        for node in nodes:
-            for tag in node.metadata.tags:
-                tag_counter[tag] += 1
-
-        # Select frequent tags
-        top_n_tags = {tag for tag, count in tag_counter.most_common(50)}
-        threshold_tags = {tag for tag, count in tag_counter.items() if count >= 50}
-        frequent_tags = top_n_tags | threshold_tags
-
-        # Group nodes by tags
-        tag_groups = defaultdict(list)
-
-        for node in nodes:
-            for tag in node.metadata.tags:
-                if tag in frequent_tags:
-                    tag_groups[tag].append(node)
-                    break
-
-        filtered_tag_clusters = []
-        assigned_ids = set()
-        for tag, group in tag_groups.items():
-            if len(group) >= min_cluster_size:
-                # Split large groups into chunks of at most max_cluster_size
-                for i in range(0, len(group), max_cluster_size):
-                    sub_group = group[i : i + max_cluster_size]
-                    filtered_tag_clusters.append(sub_group)
-                    assigned_ids.update(n.id for n in sub_group)
-            else:
-                logger.info(f"... dropped tag {tag} due to low size ...")
-
-        logger.info(
-            f"[MixedPartition] Created {len(filtered_tag_clusters)} clusters from tags. "
-            f"Nodes grouped by tags: {len(assigned_ids)} / {len(nodes)}"
-        )
-
-        # Remaining nodes -> embedding clustering
-        remaining_nodes = [n for n in nodes if n.id not in assigned_ids]
-        logger.info(
-            f"[MixedPartition] Remaining nodes for embedding clustering: {len(remaining_nodes)}"
-        )
-
-        embedding_clusters = []
+        if len(nodes) <= max_cluster_size:
+            logger.info(
+                f"[KMeansPartition] Node count {len(nodes)} <= {max_cluster_size}, skipping KMeans."
+            )
+            return [nodes]

-        def recursive_clustering(nodes_list):
+        def recursive_clustering(nodes_list, depth=0):
             """Recursively split clusters until each is <= max_cluster_size."""
+            indent = " " * depth
+            logger.info(
+                f"{indent}[Recursive] Start clustering {len(nodes_list)} nodes at depth {depth}"
+            )
+
             if len(nodes_list) <= max_cluster_size:
+                logger.info(
+                    f"{indent}[Recursive] Node count <= {max_cluster_size}, stop splitting."
+                )
                 return [nodes_list]
-
             # Try kmeans with k = ceil(len(nodes) / max_cluster_size)
-            x = np.array([n.metadata.embedding for n in nodes_list if n.metadata.embedding])
-            if len(x) < 2:
+            x_nodes = [n for n in nodes_list if n.metadata.embedding]
+            x = np.array([n.metadata.embedding for n in x_nodes])
+
+            if len(x) < min_cluster_size:
+                logger.info(
+                    f"{indent}[Recursive] Too few embeddings ({len(x)}), skipping clustering."
+                )
                 return [nodes_list]

             k = min(len(x), (len(nodes_list) + max_cluster_size - 1) // max_cluster_size)
-            k = max(1, min(k, len(x)))
+            k = max(1, k)

             try:
+                logger.info(f"{indent}[Recursive] Clustering with k={k} on {len(x)} points.")
                 kmeans = MiniBatchKMeans(n_clusters=k, batch_size=256, random_state=42)
                 labels = kmeans.fit_predict(x)

                 label_groups = defaultdict(list)
-                for node, label in zip(nodes_list, labels, strict=False):
+                for node, label in zip(x_nodes, labels, strict=False):
                     label_groups[label].append(node)

+                # Map: label -> nodes with no embedding (fallback group)
+                no_embedding_nodes = [n for n in nodes_list if not n.metadata.embedding]
+                if no_embedding_nodes:
+                    logger.warning(
+                        f"{indent}[Recursive] {len(no_embedding_nodes)} nodes have no embedding. Added to largest cluster."
+                    )
+                    # Assign to largest cluster
+                    largest_label = max(label_groups.items(), key=lambda kv: len(kv[1]))[0]
+                    label_groups[largest_label].extend(no_embedding_nodes)
+
                 result = []
-                for sub_group in label_groups.values():
-                    result.extend(recursive_clustering(sub_group))
+                for label, sub_group in label_groups.items():
+                    logger.info(f"{indent} Cluster-{label}: {len(sub_group)} nodes")
+                    result.extend(recursive_clustering(sub_group, depth=depth + 1))
                 return result
+
             except Exception as e:
-                logger.warning(f"Clustering failed: {e}, falling back to single cluster.")
+                logger.warning(
+                    f"{indent}[Recursive] Clustering failed: {e}, fallback to one cluster."
+                )
                 return [nodes_list]

-        if remaining_nodes:
-            clusters = recursive_clustering(remaining_nodes)
-            embedding_clusters.extend(clusters)
-            logger.info(
-                f"[MixedPartition] Created {len(embedding_clusters)} clusters from embeddings."
-            )
-
-        # Merge all clusters
-        all_clusters = filtered_tag_clusters + embedding_clusters
+        raw_clusters = recursive_clustering(nodes)
+        filtered_clusters = [c for c in raw_clusters if len(c) > min_cluster_size]

-        # Handle small clusters (< min_cluster_size)
-        final_clusters = []
-        small_nodes = []
-        for group in all_clusters:
-            if len(group) < min_cluster_size:
-                small_nodes.extend(group)
-            else:
-                final_clusters.append(group)
+        logger.info(f"[KMeansPartition] Total clusters before filtering: {len(raw_clusters)}")
+        for i, cluster in enumerate(raw_clusters):
+            logger.info(f"[KMeansPartition] Cluster-{i}: {len(cluster)} nodes")

-        if small_nodes:
-            final_clusters.append(small_nodes)
-            logger.info(f"[MixedPartition] {len(small_nodes)} nodes assigned to 'Other' cluster.")
+        logger.info(
+            f"[KMeansPartition] Clusters after filtering (>{min_cluster_size}): {len(filtered_clusters)}"
+        )

-        logger.info(f"[MixedPartition] Total final clusters: {len(final_clusters)}")
-        return final_clusters
+        return filtered_clusters

     def _summarize_cluster(self, cluster_nodes: list[GraphDBNode], scope: str) -> GraphDBNode:
         """
@@ -600,7 +562,7 @@ class GraphStructureReorganizer:
         for i, node in enumerate(message.after_node or []):
             if not isinstance(node, str):
                 continue
-            raw_node = self.graph_store.get_node(node)
+            raw_node = self.graph_store.get_node(node, include_embedding=True)
             if raw_node is None:
                 logger.debug(f"Node with ID {node} not found in the graph store.")
                 message.after_node[i] = None
memos/memories/textual/tree_text_memory/retrieve/bochasearch.py (new file):
@@ -0,0 +1,229 @@
+"""BochaAI Search API retriever for tree text memory."""
+
+import json
+
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from datetime import datetime
+
+import requests
+
+from memos.embedders.factory import OllamaEmbedder
+from memos.log import get_logger
+from memos.mem_reader.base import BaseMemReader
+from memos.memories.textual.item import TextualMemoryItem
+
+
+logger = get_logger(__name__)
+
+
+class BochaAISearchAPI:
+    """BochaAI Search API Client"""
+
+    def __init__(self, api_key: str, max_results: int = 20):
+        """
+        Initialize BochaAI Search API client.
+
+        Args:
+            api_key: BochaAI API key
+            max_results: Maximum number of search results to retrieve
+        """
+        self.api_key = api_key
+        self.max_results = max_results
+
+        self.web_url = "https://api.bochaai.com/v1/web-search"
+        self.ai_url = "https://api.bochaai.com/v1/ai-search"
+
+        self.headers = {
+            "Authorization": f"Bearer {api_key}",
+            "Content-Type": "application/json",
+        }
+
+    def search_web(self, query: str, summary: bool = True, freshness="noLimit") -> list[dict]:
+        """
+        Perform a Web Search (equivalent to the first curl).
+
+        Args:
+            query: Search query string
+            summary: Whether to include summary in the results
+            freshness: Freshness filter (e.g. 'noLimit', 'day', 'week')
+
+        Returns:
+            A list of search result dicts
+        """
+        body = {
+            "query": query,
+            "summary": summary,
+            "freshness": freshness,
+            "count": self.max_results,
+        }
+        return self._post(self.web_url, body)
+
+    def search_ai(
+        self, query: str, answer: bool = False, stream: bool = False, freshness="noLimit"
+    ) -> list[dict]:
+        """
+        Perform an AI Search (equivalent to the second curl).
+
+        Args:
+            query: Search query string
+            answer: Whether BochaAI should generate an answer
+            stream: Whether to use streaming response
+            freshness: Freshness filter (e.g. 'noLimit', 'day', 'week')
+
+        Returns:
+            A list of search result dicts
+        """
+        body = {
+            "query": query,
+            "freshness": freshness,
+            "count": self.max_results,
+            "answer": answer,
+            "stream": stream,
+        }
+        return self._post(self.ai_url, body)
+
+    def _post(self, url: str, body: dict) -> list[dict]:
+        """Send POST request and parse BochaAI search results."""
+        try:
+            resp = requests.post(url, headers=self.headers, json=body)
+            resp.raise_for_status()
+            raw_data = resp.json()
+
+            # parse the nested structure correctly
+            # ✅ AI Search
+            if "messages" in raw_data:
+                results = []
+                for msg in raw_data["messages"]:
+                    if msg.get("type") == "source" and msg.get("content_type") == "webpage":
+                        try:
+                            content_json = json.loads(msg["content"])
+                            results.extend(content_json.get("value", []))
+                        except Exception as e:
+                            logger.error(f"Failed to parse message content: {e}")
+                return results
+
+            # ✅ Web Search
+            return raw_data.get("data", {}).get("webPages", {}).get("value", [])
+
+        except Exception:
+            import traceback
+
+            logger.error(f"BochaAI search error: {traceback.format_exc()}")
+            return []
+
+
+class BochaAISearchRetriever:
+    """BochaAI retriever that converts search results into TextualMemoryItem objects"""
+
+    def __init__(
+        self,
+        access_key: str,
+        embedder: OllamaEmbedder,
+        reader: BaseMemReader,
+        max_results: int = 20,
+    ):
+        """
+        Initialize BochaAI Search retriever.
+
+        Args:
+            access_key: BochaAI API key
+            embedder: Embedder instance for generating embeddings
+            reader: MemReader instance for processing internet content
+            max_results: Maximum number of search results to retrieve
+        """
+        self.bocha_api = BochaAISearchAPI(access_key, max_results=max_results)
+        self.embedder = embedder
+        self.reader = reader
+
+    def retrieve_from_internet(
+        self, query: str, top_k: int = 10, parsed_goal=None, info=None
+    ) -> list[TextualMemoryItem]:
+        """
+        Default internet retrieval (Web Search).
+        This keeps consistent API with Xinyu and Google retrievers.
+
+        Args:
+            query: Search query
+            top_k: Number of results to retrieve
+            parsed_goal: Parsed task goal (optional)
+            info (dict): Metadata for memory consumption tracking
+
+        Returns:
+            List of TextualMemoryItem
+        """
+        search_results = self.bocha_api.search_ai(query)  # ✅ default to
+        # web-search
+        return self._convert_to_mem_items(search_results, query, parsed_goal, info)
+
+    def retrieve_from_web(
+        self, query: str, top_k: int = 10, parsed_goal=None, info=None
+    ) -> list[TextualMemoryItem]:
+        """Explicitly retrieve using Bocha Web Search."""
+        search_results = self.bocha_api.search_web(query)
+        return self._convert_to_mem_items(search_results, query, parsed_goal, info)
+
+    def retrieve_from_ai(
+        self, query: str, top_k: int = 10, parsed_goal=None, info=None
+    ) -> list[TextualMemoryItem]:
+        """Explicitly retrieve using Bocha AI Search."""
+        search_results = self.bocha_api.search_ai(query)
+        return self._convert_to_mem_items(search_results, query, parsed_goal, info)
+
+    def _convert_to_mem_items(
+        self, search_results: list[dict], query: str, parsed_goal=None, info=None
+    ):
+        """Convert API search results into TextualMemoryItem objects."""
+        memory_items = []
+        if not info:
+            info = {"user_id": "", "session_id": ""}
+
+        with ThreadPoolExecutor(max_workers=8) as executor:
+            futures = [
+                executor.submit(self._process_result, r, query, parsed_goal, info)
+                for r in search_results
+            ]
+            for future in as_completed(futures):
+                try:
+                    memory_items.extend(future.result())
+                except Exception as e:
+                    logger.error(f"Error processing BochaAI search result: {e}")
+
+        # Deduplicate items by memory text
+        unique_memory_items = {item.memory: item for item in memory_items}
+        return list(unique_memory_items.values())
+
+    def _process_result(
+        self, result: dict, query: str, parsed_goal: str, info: None
+    ) -> list[TextualMemoryItem]:
+        """Process one Bocha search result into TextualMemoryItem."""
+        title = result.get("name", "")
+        content = result.get("summary", "") or result.get("snippet", "")
+        summary = result.get("snippet", "")
+        url = result.get("url", "")
+        publish_time = result.get("datePublished", "")
+
+        if publish_time:
+            try:
+                publish_time = datetime.fromisoformat(publish_time.replace("Z", "+00:00")).strftime(
+                    "%Y-%m-%d"
+                )
+            except Exception:
+                publish_time = datetime.now().strftime("%Y-%m-%d")
+        else:
+            publish_time = datetime.now().strftime("%Y-%m-%d")
+
+        # Use reader to split and process the content into chunks
+        read_items = self.reader.get_memory([content], type="doc", info=info)
+
+        memory_items = []
+        for read_item_i in read_items[0]:
+            read_item_i.memory = (
+                f"Title: {title}\nNewsTime: {publish_time}\nSummary: {summary}\n"
+                f"Content: {read_item_i.memory}"
+            )
+            read_item_i.metadata.source = "web"
+            read_item_i.metadata.memory_type = "OuterMemory"
+            read_item_i.metadata.sources = [url] if url else []
+            read_item_i.metadata.visibility = "public"
+            memory_items.append(read_item_i)
+        return memory_items
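The new `bochasearch.py` module wires BochaAI's Web Search and AI Search endpoints into the tree text memory retriever stack. A hedged usage sketch of the low-level client (the API key is a placeholder; running it needs the installed wheel, a valid BochaAI account, and network access; result fields follow the parsing above, i.e. `name`, `url`, `summary`, `snippet`, `datePublished`):

```python
from memos.memories.textual.tree_text_memory.retrieve.bochasearch import BochaAISearchAPI

# Placeholder key for illustration only.
api = BochaAISearchAPI(api_key="sk-your-bocha-key", max_results=5)

# Web search returns the raw webPages.value list from the BochaAI response.
for page in api.search_web("MemoryOS memory operating system", summary=True):
    print(page.get("name"), page.get("url"))

# AI search flattens "source"/"webpage" messages into the same dict shape.
results = api.search_ai("memory-augmented LLM agents", answer=False)
print(len(results), "AI-search results")
```

As written in this hunk, `retrieve_from_internet` is documented as Web Search but calls `search_ai(query)`, and the `top_k` parameter is accepted without being forwarded to the API; the `max_results` value set at construction controls the result count.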
memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py:
@@ -5,6 +5,7 @@ from typing import Any, ClassVar
 from memos.configs.internet_retriever import InternetRetrieverConfigFactory
 from memos.embedders.base import BaseEmbedder
 from memos.mem_reader.factory import MemReaderFactory
+from memos.memories.textual.tree_text_memory.retrieve.bochasearch import BochaAISearchRetriever
 from memos.memories.textual.tree_text_memory.retrieve.internet_retriever import (
     InternetGoogleRetriever,
 )
@@ -18,6 +19,7 @@ class InternetRetrieverFactory:
         "google": InternetGoogleRetriever,
         "bing": InternetGoogleRetriever,  # TODO: Implement BingRetriever
         "xinyu": XinyuSearchRetriever,
+        "bocha": BochaAISearchRetriever,
     }

     @classmethod
@@ -70,6 +72,13 @@ class InternetRetrieverFactory:
                 reader=MemReaderFactory.from_config(config.reader),
                 max_results=config.max_results,
             )
+        elif backend == "bocha":
+            return retriever_class(
+                access_key=config.api_key,  # Use api_key as access_key for xinyu
+                embedder=embedder,
+                reader=MemReaderFactory.from_config(config.reader),
+                max_results=config.max_results,
+            )
         else:
             raise ValueError(f"Unsupported backend: {backend}")
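Finally, the factory registers the "bocha" backend and builds the retriever from the same config fields the other backends use (`api_key`, `reader`, `max_results`); the inline comment "Use api_key as access_key for xinyu" is carried over from the xinyu branch, but here the key is forwarded to the Bocha retriever. A hypothetical configuration fragment for selecting the new backend; the field names follow the attributes read above, while the authoritative schema lives in `memos/configs/internet_retriever.py`, which this diff does not show:

```python
# Hypothetical config sketch; validate against memos/configs/internet_retriever.py.
internet_retriever_config = {
    "backend": "bocha",
    "config": {
        "api_key": "sk-your-bocha-key",  # passed to BochaAISearchRetriever as access_key
        "max_results": 15,
        "reader": {},  # a MemReader config consumed by MemReaderFactory.from_config
    },
}
```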