memu-py 1.3.0__cp313-abi3-win_amd64.whl → 1.4.0__cp313-abi3-win_amd64.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
memu/_core.pyd CHANGED
Binary file
memu/app/crud.py CHANGED
@@ -654,7 +654,7 @@ class CRUDMixin:
             prompt = self._build_category_patch_prompt(
                 category=cat, content_before=content_before, content_after=content_after
             )
-            tasks.append(client.summarize(prompt, system_prompt=None))
+            tasks.append(client.chat(prompt))
             target_ids.append(cid)
         if not tasks:
             return
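Note: this hunk and many below replace every `client.summarize(prompt, system_prompt=None)` call with `client.chat(...)`. A minimal sketch of the client protocol these call sites imply — the names and shapes here are inferred from the diff, not the actual memu API:

from typing import Protocol


class LLMClient(Protocol):
    # Inferred: chat() takes a user prompt plus an optional system prompt and
    # returns the raw completion text; used for instruction/structured prompts.
    async def chat(self, prompt: str, system_prompt: str | None = None) -> str: ...

    # Inferred from the persist step below: batch-embeds texts, one vector each.
    async def embed(self, texts: list[str]) -> list[list[float]]: ...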
memu/app/memorize.py CHANGED
@@ -240,7 +240,7 @@ class MemorizeMixin:
         resources: list[Resource] = []
         items: list[MemoryItem] = []
         relations: list[CategoryItem] = []
-        category_updates: dict[str, list[str]] = {}
+        category_updates: dict[str, list[tuple[str, str]]] = {}
         user_scope = state.get("user", {})
 
         for plan in state.get("resource_plans", []):
@@ -282,12 +282,18 @@ class MemorizeMixin:
 
     async def _memorize_persist_and_index(self, state: WorkflowState, step_context: Any) -> WorkflowState:
         llm_client = self._get_step_llm_client(step_context)
-        await self._update_category_summaries(
+        updated_summaries = await self._update_category_summaries(
             state.get("category_updates", {}),
             ctx=state["ctx"],
             store=state["store"],
             llm_client=llm_client,
         )
+        if self.memorize_config.enable_item_references:
+            await self._persist_item_references(
+                updated_summaries=updated_summaries,
+                category_updates=state.get("category_updates", {}),
+                store=state["store"],
+            )
         return state
 
     def _memorize_build_response(self, state: WorkflowState, step_context: Any) -> WorkflowState:
@@ -522,7 +528,8 @@ class MemorizeMixin:
             for mtype in memory_types
         ]
         valid_prompts = [prompt for prompt in prompts if prompt.strip()]
-        tasks = [client.summarize(prompt_text) for prompt_text in valid_prompts]
+        # These prompts are instructions that request structured output, not text summaries.
+        tasks = [client.chat(prompt_text) for prompt_text in valid_prompts]
         responses = await asyncio.gather(*tasks)
         return self._parse_structured_entries(memory_types, responses)
 
@@ -577,14 +584,23 @@ class MemorizeMixin:
         store: Database,
         embed_client: Any | None = None,
         user: Mapping[str, Any] | None = None,
-    ) -> tuple[list[MemoryItem], list[CategoryItem], dict[str, list[str]]]:
+    ) -> tuple[list[MemoryItem], list[CategoryItem], dict[str, list[tuple[str, str]]]]:
+        """
+        Persist memory items and track category updates.
+
+        Returns:
+            Tuple of (items, relations, category_updates)
+            where category_updates maps category_id -> list of (item_id, summary) tuples
+        """
         summary_payloads = [content for _, content, _ in structured_entries]
         client = embed_client or self._get_llm_client()
         item_embeddings = await client.embed(summary_payloads) if summary_payloads else []
         items: list[MemoryItem] = []
         rels: list[CategoryItem] = []
-        category_memory_updates: dict[str, list[str]] = {}
+        # Changed: now stores (item_id, summary) tuples for reference support
+        category_memory_updates: dict[str, list[tuple[str, str]]] = {}
 
+        reinforce = self.memorize_config.enable_item_reinforcement
         for (memory_type, summary_text, cat_names), emb in zip(structured_entries, item_embeddings, strict=True):
             item = store.memory_item_repo.create_item(
                 resource_id=resource_id,
@@ -592,12 +608,17 @@ class MemorizeMixin:
                 summary=summary_text,
                 embedding=emb,
                 user_data=dict(user or {}),
+                reinforce=reinforce,
             )
             items.append(item)
+            if reinforce and item.extra.get("reinforcement_count", 1) > 1:
+                # existing item
+                continue
             mapped_cat_ids = self._map_category_names_to_ids(cat_names, ctx)
             for cid in mapped_cat_ids:
                 rels.append(store.category_item_repo.link_item_category(item.id, cid, user_data=dict(user or {})))
-                category_memory_updates.setdefault(cid, []).append(summary_text)
+                # Store (item_id, summary) tuple for reference support
+                category_memory_updates.setdefault(cid, []).append((item.id, summary_text))
 
         return items, rels, category_memory_updates
 
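The `reinforce=` flag depends on `create_item` deduplicating against existing items; that logic is not visible in this diff (the only other change, memu/_core.pyd, is binary). A sketch of the assumed contract, inferred from the `reinforcement_count` check above:

# Hypothetical contract of memory_item_repo.create_item(reinforce=True), not shown in this diff:
# if an equivalent item already exists, the repo is assumed to increment a counter in
# item.extra and return the existing row instead of inserting a new one.
#
#     item = store.memory_item_repo.create_item(..., reinforce=True)
#     item.extra.get("reinforcement_count", 1)  # 1 on first insert, >1 on repeats
#
# The caller above uses count > 1 to skip re-linking categories for repeated items.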
@@ -779,7 +800,7 @@ class MemorizeMixin:
         preprocessed_text = format_conversation_for_preprocess(text)
         prompt = template.format(conversation=self._escape_prompt_value(preprocessed_text))
         client = llm_client or self._get_llm_client()
-        processed = await client.summarize(prompt, system_prompt=None)
+        processed = await client.chat(prompt)
         _conv, segments = self._parse_conversation_preprocess_with_segments(processed, preprocessed_text)
 
         # Important: always use the original JSON-derived, indexed conversation text for downstream
@@ -809,16 +830,13 @@ class MemorizeMixin:
 
     async def _summarize_segment(self, segment_text: str, llm_client: Any | None = None) -> str | None:
         """Summarize a single conversation segment."""
-        prompt = f"""Summarize the following conversation segment in 1-2 concise sentences.
-Focus on the main topic or theme discussed.
-
-Conversation:
-{segment_text}
-
-Summary:"""
+        system_prompt = (
+            "Summarize the given conversation segment in 1-2 concise sentences. "
+            "Focus on the main topic or theme discussed."
+        )
         try:
             client = llm_client or self._get_llm_client()
-            response = await client.summarize(prompt, system_prompt=None)
+            response = await client.chat(segment_text, system_prompt=system_prompt)
             return response.strip() if response else None
         except Exception:
             logger.exception("Failed to summarize segment")
@@ -895,7 +913,7 @@ Summary:"""
         """Preprocess document data - condense and extract caption"""
         prompt = template.format(document_text=self._escape_prompt_value(text))
         client = llm_client or self._get_llm_client()
-        processed = await client.summarize(prompt, system_prompt=None)
+        processed = await client.chat(prompt)
         processed_content, caption = self._parse_multimodal_response(processed, "processed_content", "caption")
         return [{"text": processed_content or text, "caption": caption}]
 
@@ -905,7 +923,7 @@ Summary:"""
         """Preprocess audio data - format transcription and extract caption"""
         prompt = template.format(transcription=self._escape_prompt_value(text))
         client = llm_client or self._get_llm_client()
-        processed = await client.summarize(prompt, system_prompt=None)
+        processed = await client.chat(prompt)
         processed_content, caption = self._parse_multimodal_response(processed, "processed_content", "caption")
         return [{"text": processed_content or text, "caption": caption}]
 
@@ -960,19 +978,115 @@ Summary:"""
         safe_categories = self._escape_prompt_value(categories_str)
         return template.format(resource=safe_resource, categories_str=safe_categories)
 
-    def _build_category_summary_prompt(self, *, category: MemoryCategory, new_memories: list[str]) -> str:
-        new_items_text = "\n".join(f"- {m}" for m in new_memories if m.strip())
+    def _build_item_ref_id(self, item_id: str) -> str:
+        return item_id.replace("-", "")[:6]
+
+    def _extract_refs_from_summaries(self, summaries: dict[str, str]) -> set[str]:
+        """
+        Extract all [ref:xxx] references from summary texts.
+
+        Args:
+            summaries: dict mapping category_id -> summary text
+
+        Returns:
+            Set of all referenced short IDs (the xxx part from [ref:xxx])
+        """
+        from memu.utils.references import extract_references
+
+        refs: set[str] = set()
+        for summary in summaries.values():
+            refs.update(extract_references(summary))
+        return refs
+
+    async def _persist_item_references(
+        self,
+        *,
+        updated_summaries: dict[str, str],
+        category_updates: dict[str, list[tuple[str, str]]],
+        store: Database,
+    ) -> None:
+        """
+        Persist ref_id to items that are referenced in category summaries.
+
+        This function:
+        1. Extracts all [ref:xxx] patterns from updated summaries
+        2. Builds a mapping of short_id -> full item_id for all items in category_updates
+        3. For items whose short_id appears in the references, updates their extra column
+           with {"ref_id": short_id}
+        """
+        # Extract all referenced short IDs from summaries
+        referenced_short_ids = self._extract_refs_from_summaries(updated_summaries)
+        if not referenced_short_ids:
+            return
+
+        # Build mapping of short_id -> full item_id for all items in category_updates
+        short_id_to_item_id: dict[str, str] = {}
+        for item_tuples in category_updates.values():
+            for item_id, _ in item_tuples:
+                short_id = self._build_item_ref_id(item_id)
+                short_id_to_item_id[short_id] = item_id
+
+        # Update extra column for referenced items
+        for short_id in referenced_short_ids:
+            matched_item_id = short_id_to_item_id.get(short_id)
+            if matched_item_id:
+                store.memory_item_repo.update_item(
+                    item_id=matched_item_id,
+                    extra={"ref_id": short_id},
+                )
+
+    def _build_category_summary_prompt(
+        self,
+        *,
+        category: MemoryCategory,
+        new_memories: list[str] | list[tuple[str, str]],
+    ) -> str:
+        """
+        Build the prompt for updating a category summary.
+
+        Args:
+            category: The category to update
+            new_memories: Either a list of summary strings (legacy) or a list of (item_id, summary) tuples (with refs)
+        """
+        # Check if references are enabled and we have (id, summary) tuples
+        enable_refs = getattr(self.memorize_config, "enable_item_references", False)
+
+        if enable_refs:
+            from memu.prompts.category_summary import (
+                CUSTOM_PROMPT_WITH_REFS as category_summary_custom_prompt,
+            )
+            from memu.prompts.category_summary import (
+                PROMPT_WITH_REFS as category_summary_prompt,
+            )
+
+            tuple_memories = cast(list[tuple[str, str]], new_memories)
+            new_items_text = "\n".join(
+                f"- [{self._build_item_ref_id(item_id)}] {summary}"
+                for item_id, summary in tuple_memories
+                if summary.strip()
+            )
+        else:
+            category_summary_prompt = CATEGORY_SUMMARY_PROMPT
+            category_summary_custom_prompt = CATEGORY_SUMMARY_CUSTOM_PROMPT
+
+            if new_memories and isinstance(new_memories[0], tuple):
+                tuple_memories = cast(list[tuple[str, str]], new_memories)
+                new_items_text = "\n".join(f"- {summary}" for item_id, summary in tuple_memories if summary.strip())
+            else:
+                str_memories = cast(list[str], new_memories)
+                new_items_text = "\n".join(f"- {m}" for m in str_memories if m.strip())
+
         original = category.summary or ""
         category_config = self.category_config_map.get(category.name)
         configured_prompt = (
             category_config and category_config.summary_prompt
         ) or self.memorize_config.default_category_summary_prompt
         if configured_prompt is None:
-            prompt = CATEGORY_SUMMARY_PROMPT
+            prompt = category_summary_prompt
         elif isinstance(configured_prompt, str):
             prompt = configured_prompt
         else:
-            prompt = self._resolve_custom_prompt(configured_prompt, CATEGORY_SUMMARY_CUSTOM_PROMPT)
+            prompt = self._resolve_custom_prompt(configured_prompt, category_summary_custom_prompt)
         target_length = (
             category_config and category_config.target_length
         ) or self.memorize_config.default_category_summary_target_length
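`extract_references` is imported from `memu.utils.references` but its implementation is not part of this diff. A plausible minimal sketch, assuming the `[ref:ITEM_ID]` citation syntax described in the new settings fields and the short ids produced by `_build_item_ref_id` above:

import re

# Hypothetical sketch of memu/utils/references.py (not shown in this diff).
# Matches inline citations such as "[ref:a1b2c3]" and captures the short id.
_REF_PATTERN = re.compile(r"\[ref:([A-Za-z0-9]+)\]")


def extract_references(text: str) -> list[str]:
    """Return every short id cited as [ref:xxx] in a summary, in order."""
    return _REF_PATTERN.findall(text)

For example, extract_references("Prefers oat milk [ref:a1b2c3].") would return ["a1b2c3"].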
@@ -985,13 +1099,20 @@ Summary:"""
 
     async def _update_category_summaries(
         self,
-        updates: dict[str, list[str]],
+        updates: dict[str, list[tuple[str, str]]] | dict[str, list[str]],
         ctx: Context,
         store: Database,
         llm_client: Any | None = None,
-    ) -> None:
+    ) -> dict[str, str]:
+        """
+        Update category summaries based on new memory items.
+
+        Returns:
+            dict mapping category_id -> updated summary text
+        """
+        updated_summaries: dict[str, str] = {}
         if not updates:
-            return
+            return updated_summaries
         tasks = []
         target_ids: list[str] = []
         client = llm_client or self._get_llm_client()
@@ -1000,19 +1121,22 @@ Summary:"""
             if not cat or not memories:
                 continue
             prompt = self._build_category_summary_prompt(category=cat, new_memories=memories)
-            tasks.append(client.summarize(prompt, system_prompt=None))
+            tasks.append(client.chat(prompt))
             target_ids.append(cid)
         if not tasks:
-            return
+            return updated_summaries
         summaries = await asyncio.gather(*tasks)
         for cid, summary in zip(target_ids, summaries, strict=True):
            cat = store.memory_category_repo.categories.get(cid)
            if not cat:
                continue
+            cleaned_summary = summary.replace("```markdown", "").replace("```", "").strip()
            store.memory_category_repo.update_category(
                category_id=cid,
-                summary=summary.replace("```markdown", "").replace("```", "").strip(),
+                summary=cleaned_summary,
            )
+            updated_summaries[cid] = cleaned_summary
+        return updated_summaries
 
     def _parse_conversation_preprocess(self, raw: str) -> tuple[str | None, str | None]:
         conversation = self._extract_tag_content(raw, "conversation")
memu/app/patch.py CHANGED
@@ -407,7 +407,7 @@ class PatchMixin:
             prompt = self._build_category_patch_prompt(
                 category=cat, content_before=content_before, content_after=content_after
             )
-            tasks.append(client.summarize(prompt, system_prompt=None))
+            tasks.append(client.chat(prompt))
             target_ids.append(cid)
         if not tasks:
             return
memu/app/retrieve.py CHANGED
@@ -321,6 +321,28 @@ class RetrieveMixin:
         state["query_vector"] = (await embed_client.embed([state["active_query"]]))[0]
         return state
 
+    def _extract_referenced_item_ids(self, state: WorkflowState) -> set[str]:
+        """Extract item IDs from category summary references."""
+        from memu.utils.references import extract_references
+
+        category_hits = state.get("category_hits") or []
+        summary_lookup = state.get("category_summary_lookup", {})
+        category_pool = state.get("category_pool") or {}
+        referenced_item_ids: set[str] = set()
+
+        for cid, _score in category_hits:
+            # Get summary from lookup or category
+            summary = summary_lookup.get(cid)
+            if not summary:
+                cat = category_pool.get(cid)
+                if cat:
+                    summary = cat.summary
+            if summary:
+                refs = extract_references(summary)
+                referenced_item_ids.update(refs)
+
+        return referenced_item_ids
+
     async def _rag_recall_items(self, state: WorkflowState, step_context: Any) -> WorkflowState:
         if not state.get("retrieve_item") or not state.get("needs_retrieval") or not state.get("proceed_to_items"):
             state["item_hits"] = []
@@ -338,6 +360,8 @@ class RetrieveMixin:
             qvec,
             self.retrieve_config.item.top_k,
             where=where_filters,
+            ranking=self.retrieve_config.item.ranking,
+            recency_decay_days=self.retrieve_config.item.recency_decay_days,
         )
         state["item_pool"] = items_pool
         return state
@@ -594,10 +618,26 @@ class RetrieveMixin:
             return state
 
         where_filters = state.get("where") or {}
-        category_ids = [cat["id"] for cat in state.get("category_hits", [])]
+        category_hits = state.get("category_hits", [])
+        category_ids = [cat["id"] for cat in category_hits]
         llm_client = self._get_step_llm_client(step_context)
         store = state["store"]
-        items_pool = store.memory_item_repo.list_items(where_filters)
+
+        use_refs = getattr(self.retrieve_config.item, "use_category_references", False)
+        ref_ids: list[str] = []
+        if use_refs and category_hits:
+            # Extract all ref_ids from category summaries
+            from memu.utils.references import extract_references
+
+            for cat in category_hits:
+                summary = cat.get("summary") or ""
+                ref_ids.extend(extract_references(summary))
+        if ref_ids:
+            # Query items by ref_ids
+            items_pool = store.memory_item_repo.list_items_by_ref_ids(ref_ids, where_filters)
+        else:
+            items_pool = store.memory_item_repo.list_items(where_filters)
+
         relations = store.category_item_repo.list_relations(where_filters)
         category_pool = state.get("category_pool") or store.memory_category_repo.list_categories(where_filters)
         state["item_hits"] = await self._llm_rank_items(
@@ -737,7 +777,7 @@ class RetrieveMixin:
 
         sys_prompt = system_prompt or PRE_RETRIEVAL_SYSTEM_PROMPT
         client = llm_client or self._get_llm_client()
-        response = await client.summarize(user_prompt, system_prompt=sys_prompt)
+        response = await client.chat(user_prompt, system_prompt=sys_prompt)
         decision = self._extract_decision(response)
         rewritten = self._extract_rewritten_query(response) or query
 
@@ -1195,7 +1235,7 @@ class RetrieveMixin:
         )
 
         client = llm_client or self._get_llm_client()
-        llm_response = await client.summarize(prompt, system_prompt=None)
+        llm_response = await client.chat(prompt)
         return self._parse_llm_category_response(llm_response, store, categories=category_pool)
 
     async def _llm_rank_items(
@@ -1234,7 +1274,7 @@ class RetrieveMixin:
         )
 
         client = llm_client or self._get_llm_client()
-        llm_response = await client.summarize(prompt, system_prompt=None)
+        llm_response = await client.chat(prompt)
         return self._parse_llm_item_response(llm_response, store, items=item_pool)
 
     async def _llm_rank_resources(
@@ -1279,7 +1319,7 @@ class RetrieveMixin:
         )
 
         client = llm_client or self._get_llm_client()
-        llm_response = await client.summarize(prompt, system_prompt=None)
+        llm_response = await client.chat(prompt)
         return self._parse_llm_resource_response(llm_response, store, resources=resource_pool)
 
     def _parse_llm_category_response(
memu/app/settings.py CHANGED
@@ -151,6 +151,20 @@ class RetrieveCategoryConfig(BaseModel):
 class RetrieveItemConfig(BaseModel):
     enabled: bool = Field(default=True, description="Whether to enable item retrieval.")
     top_k: int = Field(default=5, description="Total number of items to retrieve.")
+    # Reference-aware retrieval
+    use_category_references: bool = Field(
+        default=False,
+        description="When category retrieval is insufficient, follow [ref:ITEM_ID] citations to fetch referenced items.",
+    )
+    # Salience-aware retrieval settings
+    ranking: Literal["similarity", "salience"] = Field(
+        default="similarity",
+        description="Ranking strategy: 'similarity' (cosine only) or 'salience' (weighted by reinforcement + recency).",
+    )
+    recency_decay_days: float = Field(
+        default=30.0,
+        description="Half-life in days for recency decay in salience scoring. After this many days, recency factor is ~0.5.",
+    )
 
 
 class RetrieveResourceConfig(BaseModel):
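The salience scorer itself is not visible in this diff (presumably it lives in the compiled core, memu/_core.pyd, the only other changed file). A sketch of what the field descriptions imply, with an exponential half-life recency term; the reinforcement weighting is an assumption:

import math
from datetime import datetime, timezone


def salience_score(
    similarity: float,
    reinforcement_count: int,
    last_seen: datetime,
    recency_decay_days: float = 30.0,
) -> float:
    """Hypothetical scorer: cosine similarity, boosted by how often an item was
    reinforced and decayed by age (factor ~0.5 after one half-life, matching the
    recency_decay_days description above)."""
    age_days = (datetime.now(timezone.utc) - last_seen).total_seconds() / 86400.0
    recency = 0.5 ** (age_days / recency_decay_days)
    reinforcement = 1.0 + math.log1p(reinforcement_count - 1)  # assumed: grows slowly with repeats
    return similarity * reinforcement * recency

With the defaults, an item last seen 30 days ago keeps about half its recency weight; at 60 days, about a quarter.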
@@ -217,6 +231,15 @@ class MemorizeConfig(BaseModel):
         description="Target max length for auto-generated category summaries.",
     )
     category_update_llm_profile: str = Field(default="default", description="LLM profile for category summary.")
+    # Reference tracking for category summaries
+    enable_item_references: bool = Field(
+        default=False,
+        description="Enable inline [ref:ITEM_ID] citations in category summaries linking to source memory items.",
+    )
+    enable_item_reinforcement: bool = Field(
+        default=False,
+        description="Enable reinforcement tracking for memory items.",
+    )
 
 
 class PatchConfig(BaseModel):
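All of the new behavior defaults to off, so 1.4.0 preserves 1.3.x behavior unless opted in. A hedged example of enabling it, assuming the remaining fields of both models have defaults as the hunks above suggest:

from memu.app.settings import MemorizeConfig, RetrieveItemConfig

memorize = MemorizeConfig(
    enable_item_references=True,    # emit [ref:ITEM_ID] citations in category summaries
    enable_item_reinforcement=True, # track repeat observations on existing items
)
retrieve_item = RetrieveItemConfig(
    use_category_references=True,   # follow [ref:...] citations when recalling items
    ranking="salience",             # weight similarity by reinforcement + recency
    recency_decay_days=14.0,        # forget faster than the 30-day default
)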
@@ -225,6 +248,9 @@ class PatchConfig(BaseModel):
 
 class DefaultUserModel(BaseModel):
     user_id: str | None = None
+    # Agent/session scoping for multi-agent and multi-session memory filtering
+    # agent_id: str | None = None
+    # session_id: str | None = None
 
 
 class UserConfig(BaseModel):
memu/client/__init__.py ADDED
@@ -0,0 +1,26 @@
+"""
+memU Client Wrapper for Auto-Recall Memory Injection.
+
+This module provides optional wrappers around OpenAI/Anthropic clients
+that automatically inject recalled memories into prompts.
+
+Usage:
+    from memu.client import wrap_openai
+    from openai import OpenAI
+
+    client = OpenAI()
+    service = MemoryService(...)
+
+    # Wrap the client for auto-recall
+    wrapped_client = wrap_openai(client, service, user_id="user123")
+
+    # Now all chat completions automatically include relevant memories
+    response = wrapped_client.chat.completions.create(
+        model="gpt-4",
+        messages=[{"role": "user", "content": "What's my favorite drink?"}]
+    )
+"""
+
+from memu.client.openai_wrapper import MemuOpenAIWrapper, wrap_openai
+
+__all__ = ["MemuOpenAIWrapper", "wrap_openai"]
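The wrapper module `memu/client/openai_wrapper.py` is referenced here but not included in this diff. A minimal sketch of how such an auto-recall proxy could work; every internal name below, including the `service.retrieve` call, is an assumption rather than the shipped implementation:

# Hypothetical sketch; the real memu/client/openai_wrapper.py is not in this diff.
from types import SimpleNamespace


class MemuOpenAIWrapper:
    """Proxies chat.completions.create, prepending recalled memories as a system message."""

    def __init__(self, client, service, user_id: str):
        self._client = client
        self._service = service
        self._user_id = user_id
        # Mirror the OpenAI client surface used in the docstring example above.
        self.chat = SimpleNamespace(completions=SimpleNamespace(create=self._create))

    def _create(self, *, model: str, messages: list, **kwargs):
        # Assumed heuristic: recall against the latest user turn.
        query = next((m["content"] for m in reversed(messages) if m["role"] == "user"), "")
        memories = self._service.retrieve(query=query, user_id=self._user_id)  # assumed API
        if memories:
            context = "Relevant memories:\n" + "\n".join(f"- {m}" for m in memories)
            messages = [{"role": "system", "content": context}, *messages]
        return self._client.chat.completions.create(model=model, messages=messages, **kwargs)


def wrap_openai(client, service, user_id: str) -> MemuOpenAIWrapper:
    return MemuOpenAIWrapper(client, service, user_id)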