busbot-memory 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- busbot_memory/__init__.py +15 -0
- busbot_memory/core/__init__.py +21 -0
- busbot_memory/core/config.py +60 -0
- busbot_memory/core/manager.py +244 -0
- busbot_memory/core/models.py +193 -0
- busbot_memory/domains/__init__.py +1 -0
- busbot_memory/extractors/__init__.py +7 -0
- busbot_memory/extractors/base.py +27 -0
- busbot_memory/extractors/llm.py +197 -0
- busbot_memory/extractors/regex.py +128 -0
- busbot_memory/memory/__init__.py +1 -0
- busbot_memory/state/__init__.py +5 -0
- busbot_memory/state/manager.py +90 -0
- busbot_memory/storage/__init__.py +1 -0
- busbot_memory/utils/__init__.py +1 -0
- busbot_memory/version.py +2 -0
- busbot_memory-0.1.0.dist-info/METADATA +121 -0
- busbot_memory-0.1.0.dist-info/RECORD +20 -0
- busbot_memory-0.1.0.dist-info/WHEEL +5 -0
- busbot_memory-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""BusBot Memory SDK - LLM-powered working memory for bus booking bots"""
|
|
2
|
+
|
|
3
|
+
from busbot_memory.core.manager import BusBotMemory
|
|
4
|
+
from busbot_memory.core.models import BookingState, MemoryItem, ProcessResult
|
|
5
|
+
from busbot_memory.core.config import BusBotConfig
|
|
6
|
+
from busbot_memory.version import __version__
|
|
7
|
+
|
|
8
|
+
__all__ = [
|
|
9
|
+
"BusBotMemory",
|
|
10
|
+
"BookingState",
|
|
11
|
+
"MemoryItem",
|
|
12
|
+
"ProcessResult",
|
|
13
|
+
"BusBotConfig",
|
|
14
|
+
"__version__",
|
|
15
|
+
]
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""Core module exports"""
|
|
2
|
+
|
|
3
|
+
from busbot_memory.core.models import (
|
|
4
|
+
BookingState,
|
|
5
|
+
MemoryItem,
|
|
6
|
+
MemoryMetadata,
|
|
7
|
+
ExtractionResult,
|
|
8
|
+
ProcessResult,
|
|
9
|
+
Intent,
|
|
10
|
+
)
|
|
11
|
+
from busbot_memory.core.config import BusBotConfig
|
|
12
|
+
|
|
13
|
+
__all__ = [
|
|
14
|
+
"BookingState",
|
|
15
|
+
"MemoryItem",
|
|
16
|
+
"MemoryMetadata",
|
|
17
|
+
"ExtractionResult",
|
|
18
|
+
"ProcessResult",
|
|
19
|
+
"Intent",
|
|
20
|
+
"BusBotConfig",
|
|
21
|
+
]
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
"""Configuration for BusBot Memory SDK"""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class BusBotConfig:
    """
    Settings container for the BusBot Memory SDK.

    API keys and the Redis URL that are not passed explicitly are read from
    the GROQ_API_KEY, OPENAI_API_KEY and REDIS_URL environment variables at
    the moment the instance is created.

    Example:
        config = BusBotConfig(
            groq_api_key="gsk_xxx",
            redis_url="redis://localhost:6379",
            domain="bus_booking"
        )
    """

    # --- Groq (LLM provider) -------------------------------------------
    groq_api_key: Optional[str] = field(
        default_factory=lambda: os.environ.get("GROQ_API_KEY")
    )
    groq_model: str = "llama-3.3-70b-versatile"
    groq_fallback_model: str = "llama-3.1-8b-instant"

    # --- OpenAI (optional, for higher quality) -------------------------
    openai_api_key: Optional[str] = field(
        default_factory=lambda: os.environ.get("OPENAI_API_KEY")
    )
    openai_model: str = "gpt-4o-mini"

    # --- Storage -------------------------------------------------------
    redis_url: Optional[str] = field(
        default_factory=lambda: os.environ.get("REDIS_URL")
    )
    session_ttl_seconds: int = 3600  # one hour
    user_memory_ttl_days: int = 30  # thirty days

    # --- Domain --------------------------------------------------------
    domain: str = "bus_booking"

    # --- Memory sizing -------------------------------------------------
    max_working_items: int = 20
    max_context_window: int = 5

    # --- Performance ---------------------------------------------------
    latency_target_ms: int = 250
    enable_fallback: bool = True  # use the regex extractor when the LLM fails
    enable_metrics: bool = True  # track latency metrics

    # --- Logging -------------------------------------------------------
    log_level: str = "INFO"
    log_extractions: bool = False  # log LLM extraction results

    def validate(self) -> bool:
        """
        Check that at least one LLM API key is configured.

        Returns:
            True when a Groq or an OpenAI key is present.

        Raises:
            ValueError: when both keys are missing or empty.
        """
        has_llm_key = bool(self.groq_api_key) or bool(self.openai_api_key)
        if has_llm_key:
            return True
        raise ValueError("At least one of groq_api_key or openai_api_key must be set")
|
|
@@ -0,0 +1,244 @@
|
|
|
1
|
+
"""
|
|
2
|
+
BusBotMemory - Main SDK Entry Point
|
|
3
|
+
|
|
4
|
+
This is the primary class users will interact with.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import time
|
|
8
|
+
import logging
|
|
9
|
+
from typing import Optional, List
|
|
10
|
+
from collections import deque
|
|
11
|
+
|
|
12
|
+
from busbot_memory.core.config import BusBotConfig
|
|
13
|
+
from busbot_memory.core.models import (
|
|
14
|
+
BookingState,
|
|
15
|
+
MemoryItem,
|
|
16
|
+
MemoryMetadata,
|
|
17
|
+
ProcessResult,
|
|
18
|
+
ExtractionResult,
|
|
19
|
+
)
|
|
20
|
+
from busbot_memory.extractors.llm import LLMExtractor
|
|
21
|
+
from busbot_memory.extractors.regex import RegexExtractor
|
|
22
|
+
from busbot_memory.state.manager import StateManager
|
|
23
|
+
|
|
24
|
+
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class BusBotMemory:
    """
    LLM-powered working memory for bus booking bots.

    Holds, for one conversation session, a bounded window of recent
    messages, the current booking state, and a small dict of customer
    details picked up from the conversation.

    Example:
        from busbot_memory import BusBotMemory, BusBotConfig

        config = BusBotConfig(groq_api_key="gsk_xxx")
        memory = BusBotMemory(
            session_id="call_001",
            customer_id="0987654321",
            config=config
        )

        result = await memory.process("đặt 2 vé đi đà nẵng ngày mai")
        print(result.state.slots)  # {"destination": "Đà Nẵng", "quantity": 2, ...}
    """

    def __init__(
        self,
        session_id: str,
        customer_id: Optional[str] = None,
        config: Optional[BusBotConfig] = None,
    ):
        """
        Create the memory for one session.

        Args:
            session_id: Identifier of the conversation session.
            customer_id: Optional identifier of the customer.
            config: SDK configuration; a default ``BusBotConfig`` is built
                when omitted.

        Raises:
            ValueError: propagated from ``config.validate()``.
        """
        self.session_id = session_id
        self.customer_id = customer_id
        self.config = config or BusBotConfig()

        # Fail early on an unusable configuration.
        self.config.validate()

        # Extraction and state components.
        self._llm_extractor = LLMExtractor(self.config)
        self._regex_extractor = RegexExtractor()
        self._state_manager = StateManager()

        # Working memory: bounded, oldest items are dropped automatically.
        self._memory: deque = deque(maxlen=self.config.max_working_items)

        # Running count of stored messages. Keys are derived from this
        # rather than len(self._memory), which stops growing once the deque
        # is full and would then hand out the same key over and over.
        self._message_count: int = 0

        # Booking state.
        self._state: BookingState = self._state_manager.create_initial_state()

        # User memory (persistent info).
        self._user_memory: dict = {}

        # Per-call latencies in milliseconds.
        self._latencies: List[int] = []

        logger.info("BusBotMemory initialized: session=%s", session_id)

    async def process(self, message: str, role: str = "user") -> ProcessResult:
        """
        Process a message and update memory + state.

        This is the main entry point for the SDK. The LLM extractor is tried
        first; if it raises and ``config.enable_fallback`` is set, the regex
        extractor is used instead. Messages flagged as noise do not touch the
        booking state but are still stored in working memory.

        Args:
            message: The message to process.
            role: "user" or "assistant".

        Returns:
            ProcessResult with entities, state, changes, and latency.

        Raises:
            Exception: whatever the LLM extractor raised, when fallback is
                disabled.
        """
        start_time = time.perf_counter()

        # Recent turns + current slots, handed to the extractor as context.
        context = self._build_context()

        try:
            extraction = await self._llm_extractor.extract(message, context)
        except Exception as e:
            logger.warning("LLM extraction failed: %s", e)
            if not self.config.enable_fallback:
                raise
            extraction = await self._regex_extractor.extract(message, context)

        # Noise messages leave the booking state untouched.
        changes = []
        if not extraction.is_noise:
            self._state, changes = self._state_manager.update(
                self._state,
                extraction,
            )

        self._add_to_memory(message, role, extraction)
        self._extract_user_info(extraction)

        latency_ms = int((time.perf_counter() - start_time) * 1000)
        self._latencies.append(latency_ms)

        if self.config.enable_metrics:
            logger.debug("Process latency: %sms", latency_ms)

        return ProcessResult(
            entities=extraction.entities,
            state=self._state,
            is_noise=extraction.is_noise,
            is_change=extraction.is_change,
            changes=changes,
            intent=extraction.intent,
            confidence=extraction.confidence,
            latency_ms=latency_ms,
        )

    def _build_context(self) -> str:
        """
        Build the context string from recent memory.

        Returns a fixed first-message sentence when nothing is stored yet;
        otherwise the last ``config.max_context_window`` turns, followed by
        a summary of the filled slots when there are any.
        """
        if not self._memory:
            return "Đây là tin nhắn đầu tiên trong cuộc hội thoại."

        # A window of 0 must yield no turns; a bare [-0:] slice would
        # return the entire history instead.
        window = self.config.max_context_window
        recent = list(self._memory)[-window:] if window > 0 else []

        lines = []
        for item in recent:
            role_label = "User" if item.role == "user" else "Bot"
            lines.append(f"{role_label}: {item.content}")

        if self._state.slots:
            state_str = ", ".join(f"{k}={v}" for k, v in self._state.slots.items())
            lines.append(f"Current booking: {state_str}")

        return "\n".join(lines)

    def _add_to_memory(
        self,
        message: str,
        role: str,
        extraction: ExtractionResult
    ):
        """
        Append a message to working memory.

        The item key is ``"<role>_<n>"`` where ``n`` is the running message
        index, so keys stay unique after the bounded deque starts evicting.
        """
        item = MemoryItem(
            content=message,
            key=f"{role}_{self._message_count}",
            role=role,
            metadata=MemoryMetadata(
                confidence=extraction.confidence,
                tags=list(extraction.entities.keys()),
            ),
        )
        self._memory.append(item)
        self._message_count += 1

    def _extract_user_info(self, extraction: ExtractionResult):
        """Copy customer name/phone entities, when present, into user memory."""
        for name in ("customer_name", "customer_phone"):
            if name in extraction.entities:
                self._user_memory[name] = extraction.entities[name]

    @staticmethod
    def _next_key_index(items) -> int:
        """Return the first message index not used by any key in *items*."""
        highest = -1
        count = 0
        for item in items:
            count += 1
            suffix = item.key.rpartition("_")[2]
            if suffix.isascii() and suffix.isdigit():
                highest = max(highest, int(suffix))
        return max(highest + 1, count)

    # ========================================================================
    # State Access
    # ========================================================================

    @property
    def state(self) -> BookingState:
        """Current booking state (the live object, not a copy)."""
        return self._state

    @property
    def memory(self) -> List[MemoryItem]:
        """All working-memory items, oldest first, as a new list."""
        return list(self._memory)

    @property
    def user_memory(self) -> dict:
        """Shallow copy of the persistent user info."""
        return self._user_memory.copy()

    # ========================================================================
    # Metrics
    # ========================================================================

    def get_metrics(self) -> dict:
        """
        Summarise recorded ``process()`` latencies.

        Returns:
            ``{"count": 0}`` when nothing was recorded, otherwise a dict with
            ``count``, ``avg_ms`` (integer division), ``p50_ms``, ``p95_ms``
            and ``max_ms``.
        """
        if not self._latencies:
            return {"count": 0}

        sorted_latencies = sorted(self._latencies)
        total = len(sorted_latencies)

        return {
            "count": total,
            "avg_ms": sum(sorted_latencies) // total,
            "p50_ms": sorted_latencies[total // 2],
            "p95_ms": sorted_latencies[int(total * 0.95)],
            "max_ms": sorted_latencies[-1],
        }

    # ========================================================================
    # Serialization
    # ========================================================================

    def to_dict(self) -> dict:
        """
        Export memory state for persistence.

        The user-memory entry is a copy, so later updates to this object do
        not alter an already exported snapshot.
        """
        return {
            "session_id": self.session_id,
            "customer_id": self.customer_id,
            "state": self._state.to_dict(),
            "memory": [item.to_dict() for item in self._memory],
            "user_memory": self._user_memory.copy(),
            "metrics": self.get_metrics(),
        }

    def load_state(self, state_dict: dict):
        """
        Load state from dict (e.g., from Redis).

        Only the keys present in *state_dict* ("state", "user_memory",
        "memory") are restored; everything else is left as it is.
        """
        if "state" in state_dict:
            self._state = BookingState.from_dict(state_dict["state"])
        if "user_memory" in state_dict:
            # Copy so the caller's dict and our internal one stay independent.
            self._user_memory = dict(state_dict["user_memory"] or {})
        if "memory" in state_dict:
            self._memory.clear()
            for item_dict in state_dict["memory"]:
                self._memory.append(MemoryItem.from_dict(item_dict))
            # Continue numbering after the restored items.
            self._message_count = self._next_key_index(self._memory)
|
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
"""Core data models for BusBot Memory SDK"""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field, asdict
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from typing import Any, Dict, List, Optional
|
|
6
|
+
from enum import Enum
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Intent(str, Enum):
    """
    String-valued enumeration of the intent labels used by the SDK.

    Members subclass ``str``, so each one compares equal to its raw value.
    """

    # Booking lifecycle
    BOOK_TICKET = "book_ticket"
    CANCEL = "cancel"
    RESCHEDULE = "reschedule"

    # Questions and feedback
    INQUIRY = "inquiry"
    COMPLAINT = "complaint"

    # Conversational
    CONFIRM = "confirm"
    GREETING = "greeting"

    # Default when no other label applies
    UNCLEAR = "unclear"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
class BookingState:
    """
    Structured booking state - tracks all slot values.

    Example:
        state = BookingState(
            intent=Intent.BOOK_TICKET,
            slots={"destination": "Đà Nẵng", "time": "08:00"},
            missing_slots=["customer_name", "phone"]
        )
    """
    intent: str = Intent.UNCLEAR.value
    slots: Dict[str, Any] = field(default_factory=dict)
    missing_slots: List[str] = field(default_factory=list)
    confidence: float = 0.0
    last_updated: str = field(default_factory=lambda: datetime.now().isoformat())

    # Slot definitions for the bus booking domain. This is a plain class
    # attribute (no annotation), so it is not a dataclass field.
    SLOT_TYPES = [
        "departure",       # departure point
        "destination",     # destination
        "date",            # travel date
        "time",            # departure time
        "quantity",        # number of tickets
        "seat_type",       # seat type
        "customer_name",   # customer name
        "customer_phone",  # customer phone number
    ]

    def to_dict(self) -> Dict[str, Any]:
        """
        Return the state as a plain dict.

        ``slots`` and ``missing_slots`` are shallow copies, so a snapshot
        taken now is not altered by later ``update_slots`` calls.
        """
        return {
            "intent": self.intent,
            "slots": dict(self.slots),
            "missing_slots": list(self.missing_slots),
            "confidence": self.confidence,
            "last_updated": self.last_updated,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "BookingState":
        """
        Rebuild a state from a dict produced by ``to_dict``.

        Missing keys fall back to the field defaults. Containers are copied
        (and an explicit ``None`` is treated as empty) so the new state
        never shares mutable objects with *data*.
        """
        return cls(
            intent=data.get("intent", Intent.UNCLEAR.value),
            slots=dict(data.get("slots") or {}),
            missing_slots=list(data.get("missing_slots") or []),
            confidence=data.get("confidence", 0.0),
            last_updated=data.get("last_updated", datetime.now().isoformat()),
        )

    def update_slots(self, new_slots: Dict[str, Any]) -> List[str]:
        """
        Update slots and return list of changes made.

        ``None`` values in *new_slots* are ignored. A slot counts as
        previously set unless its old value is ``None`` or an empty string,
        so a real previous value such as ``0`` is reported as a change.

        Returns:
            List of change descriptions, e.g. ["destination: Hải Phòng → Đà Nẵng"]
        """
        changes = []
        for key, new_value in new_slots.items():
            if new_value is None:
                continue

            old_value = self.slots.get(key)
            if old_value != new_value:
                previously_set = old_value is not None and old_value != ""
                if previously_set:
                    changes.append(f"{key}: {old_value} → {new_value}")
                else:
                    changes.append(f"{key}: {new_value}")
                self.slots[key] = new_value

            # A slot that was supplied is no longer missing.
            if key in self.missing_slots:
                self.missing_slots.remove(key)

        self.last_updated = datetime.now().isoformat()
        self._update_confidence()
        return changes

    def _update_confidence(self):
        """Set confidence to filled / (filled + missing), or 0.0 when both are empty."""
        filled = len(self.slots)
        total = filled + len(self.missing_slots)
        self.confidence = filled / total if total > 0 else 0.0
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
@dataclass
class MemoryMetadata:
    """Bookkeeping attached to each memory item."""

    # One of: working, user, long_term
    memory_type: str = "working"

    # ISO-8601 timestamps taken from datetime.now() at construction time.
    created_at: str = field(default_factory=lambda: datetime.now().isoformat())
    updated_at: str = field(default_factory=lambda: datetime.now().isoformat())

    confidence: float = 1.0
    source: str = "conversation"
    tags: List[str] = field(default_factory=list)
    usage_count: int = 0
    importance: float = 0.5
    is_obsolete: bool = False
    superseded_by: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a plain dict (built with ``dataclasses.asdict``)."""
        serialized = asdict(self)
        return serialized
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
@dataclass
class MemoryItem:
    """Single memory item in working memory."""

    # Defaults to a timestamp string with microsecond resolution.
    id: str = field(default_factory=lambda: datetime.now().strftime("%Y%m%d%H%M%S%f"))
    content: str = ""
    key: str = ""
    role: str = "user"  # user, assistant
    metadata: MemoryMetadata = field(default_factory=MemoryMetadata)

    def to_dict(self) -> Dict[str, Any]:
        """Return the item, including its metadata, as a plain dict."""
        return {
            "id": self.id,
            "content": self.content,
            "key": self.key,
            "role": self.role,
            "metadata": self.metadata.to_dict(),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MemoryItem":
        """
        Rebuild an item from a dict produced by ``to_dict``.

        Tolerates imperfect stored data: a missing or ``None`` metadata
        entry yields default metadata, metadata keys that are not
        ``MemoryMetadata`` fields are ignored instead of raising
        ``TypeError``, and a missing or empty id is replaced by a freshly
        generated one rather than an empty string.
        """
        raw_metadata = data.get("metadata") or {}
        known_fields = MemoryMetadata.__dataclass_fields__
        metadata = MemoryMetadata(
            **{name: value for name, value in raw_metadata.items() if name in known_fields}
        )

        kwargs: Dict[str, Any] = {
            "content": data.get("content", ""),
            "key": data.get("key", ""),
            "role": data.get("role", "user"),
            "metadata": metadata,
        }
        # Only pass an id when one was actually stored; otherwise let the
        # field's default factory generate it.
        if data.get("id"):
            kwargs["id"] = data["id"]
        return cls(**kwargs)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
@dataclass
class ExtractionResult:
    """Outcome of one entity-extraction pass over a single message."""

    # Extracted values keyed by entity name.
    entities: Dict[str, Any] = field(default_factory=dict)
    intent: str = Intent.UNCLEAR.value

    # Flags reported by the extractor.
    is_noise: bool = False
    is_change: bool = False
    changed_fields: List[str] = field(default_factory=list)

    confidence: float = 0.0
    raw_response: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a plain dict (built with ``dataclasses.asdict``)."""
        serialized = asdict(self)
        return serialized
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
@dataclass
class ProcessResult:
    """
    Value returned by ``memory.process()``.

    Bundles what was extracted from the message with the booking state
    after the call and the time the call took.
    """

    entities: Dict[str, Any] = field(default_factory=dict)
    state: BookingState = field(default_factory=BookingState)
    is_noise: bool = False
    is_change: bool = False
    changes: List[str] = field(default_factory=list)
    intent: str = Intent.UNCLEAR.value
    confidence: float = 0.0
    latency_ms: int = 0

    def to_dict(self) -> Dict[str, Any]:
        """Return the result as a plain dict; the state is serialised via its own ``to_dict``."""
        state_payload = self.state.to_dict()
        return dict(
            entities=self.entities,
            state=state_payload,
            is_noise=self.is_noise,
            is_change=self.is_change,
            changes=self.changes,
            intent=self.intent,
            confidence=self.confidence,
            latency_ms=self.latency_ms,
        )
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Domains module - placeholder for Phase 2"""
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
"""Extractors module exports"""
|
|
2
|
+
|
|
3
|
+
from busbot_memory.extractors.base import BaseExtractor
|
|
4
|
+
from busbot_memory.extractors.llm import LLMExtractor
|
|
5
|
+
from busbot_memory.extractors.regex import RegexExtractor
|
|
6
|
+
|
|
7
|
+
__all__ = ["BaseExtractor", "LLMExtractor", "RegexExtractor"]
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""Base extractor interface"""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import Optional
|
|
5
|
+
from busbot_memory.core.models import ExtractionResult
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class BaseExtractor(ABC):
    """Interface that every entity extractor implements."""

    @abstractmethod
    async def extract(self, message: str, context: Optional[str] = None) -> ExtractionResult:
        """
        Pull entities out of a message.

        Args:
            message: User message to extract from.
            context: Optional conversation context.

        Returns:
            ExtractionResult with entities, intent, and flags.
        """
        ...
|
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
"""LLM-based entity extractor using Groq"""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
from busbot_memory.extractors.base import BaseExtractor
|
|
8
|
+
from busbot_memory.core.models import ExtractionResult
|
|
9
|
+
from busbot_memory.core.config import BusBotConfig
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# Optimized prompt for Vietnamese bus booking - compact for low latency.
#
# This template is rendered with str.format(message=..., context=...), so
# the literal braces of the JSON skeleton are doubled ({{ and }}). Left
# single, str.format treats them as replacement fields and raises before any
# request is sent. Only {context} and {message} are real placeholders.
EXTRACTION_PROMPT = """Bạn là NLU system cho đặt vé xe Việt Nam. Extract thông tin từ tin nhắn.

OUTPUT JSON:
{{
"entities": {{
"departure": "điểm đi nếu có",
"destination": "điểm đến nếu có",
"date": "YYYY-MM-DD hoặc 'ngày mai'/'hôm nay' nếu có",
"time": "HH:MM nếu có (8 giờ sáng = 08:00, 2 giờ chiều = 14:00)",
"quantity": số nguyên nếu có,
"seat_type": "ghế ngồi|giường nằm|limousine nếu có",
"customer_name": "họ tên nếu có",
"customer_phone": "số điện thoại nếu có"
}},
"intent": "book_ticket|cancel|reschedule|inquiry|complaint|confirm|greeting|unclear",
"is_noise": true nếu chỉ là filler ("ừ","ok","dạ","chờ xíu"),
"is_change": true nếu đổi thông tin đã cung cấp trước đó,
"changed_fields": ["field1"] nếu is_change=true,
"confidence": 0.0-1.0
}}

RULES:
- Chỉ điền field nếu có info, bỏ null
- is_noise=true với câu filler vô nghĩa
- Chỉ trả JSON, không giải thích

CONTEXT: {context}
MESSAGE: {message}"""
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class LLMExtractor(BaseExtractor):
    """
    LLM-based entity extractor using Groq (primary) or OpenAI (fallback).

    Example:
        extractor = LLMExtractor(config)
        result = await extractor.extract("đặt 2 vé đi đà nẵng ngày mai")
    """

    def __init__(self, config: BusBotConfig):
        """Store the configuration; API clients are created on first use."""
        self.config = config
        self._groq_client = None
        self._openai_client = None

    @property
    def groq_client(self):
        """
        Lazily created Groq client.

        Returns None when no Groq API key is configured or the ``groq``
        package cannot be imported (a warning is logged in that case).
        """
        if self._groq_client is None and self.config.groq_api_key:
            try:
                from groq import AsyncGroq
                self._groq_client = AsyncGroq(api_key=self.config.groq_api_key)
            except ImportError:
                logger.warning("groq package not installed")
        return self._groq_client

    @property
    def openai_client(self):
        """
        Lazily created OpenAI client.

        Returns None when no OpenAI API key is configured or the ``openai``
        package cannot be imported (a warning is logged in that case).
        """
        if self._openai_client is None and self.config.openai_api_key:
            try:
                from openai import AsyncOpenAI
                self._openai_client = AsyncOpenAI(api_key=self.config.openai_api_key)
            except ImportError:
                logger.warning("openai package not installed")
        return self._openai_client

    async def extract(
        self,
        message: str,
        context: Optional[str] = None
    ) -> ExtractionResult:
        """
        Extract entities using Groq LLM.

        Falls back to OpenAI if Groq fails and an OpenAI client is
        available; otherwise the Groq error is re-raised.
        """
        try:
            return await self._extract_with_groq(message, context)
        except Exception as e:
            logger.warning("Groq extraction failed: %s", e)

            if self.openai_client:
                logger.info("Falling back to OpenAI")
                return await self._extract_with_openai(message, context)

            raise

    @staticmethod
    def _build_prompt(message: str, context: Optional[str]) -> str:
        """Fill the extraction prompt; a missing context gets a fixed placeholder."""
        return EXTRACTION_PROMPT.format(
            message=message,
            context=context or "Không có context"
        )

    async def _extract_with_groq(
        self,
        message: str,
        context: Optional[str] = None
    ) -> ExtractionResult:
        """
        Extract using Groq.

        Tries ``config.groq_model`` in JSON mode first; on any error retries
        once with ``config.groq_fallback_model``.

        Raises:
            RuntimeError: when no Groq client is available.
        """
        if not self.groq_client:
            raise RuntimeError("Groq client not available")

        prompt = self._build_prompt(message, context)

        try:
            response = await self.groq_client.chat.completions.create(
                model=self.config.groq_model,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.1,
                max_tokens=500,
                response_format={"type": "json_object"},
            )
            return self._parse_response(response.choices[0].message.content)

        except Exception as e:
            logger.warning("Primary model failed, trying fallback: %s", e)

            # NOTE(review): the retry is sent without response_format, as in
            # the original code - confirm whether that is intentional.
            response = await self.groq_client.chat.completions.create(
                model=self.config.groq_fallback_model,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.1,
                max_tokens=500,
            )
            return self._parse_response(response.choices[0].message.content)

    async def _extract_with_openai(
        self,
        message: str,
        context: Optional[str] = None
    ) -> ExtractionResult:
        """
        Extract using OpenAI.

        Raises:
            RuntimeError: when no OpenAI client is available.
        """
        if not self.openai_client:
            raise RuntimeError("OpenAI client not available")

        prompt = self._build_prompt(message, context)

        response = await self.openai_client.chat.completions.create(
            model=self.config.openai_model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.1,
            max_tokens=500,
            response_format={"type": "json_object"},
        )
        return self._parse_response(response.choices[0].message.content)

    def _parse_response(self, raw: str) -> ExtractionResult:
        """
        Parse LLM response into ExtractionResult.

        A leading Markdown code fence (optionally tagged ``json``) is
        stripped first. Anything that cannot be used - ``None`` content,
        invalid JSON, or a JSON value that is not an object - is logged and
        returned as an empty result with confidence 0.0 instead of raising.
        Null or malformed members of a valid object fall back to their
        defaults.
        """
        text = (raw or "").strip()
        if text.startswith("```"):
            text = text.split("```")[1]
            if text.startswith("json"):
                text = text[4:]
        cleaned = text if raw is not None else None

        try:
            data = json.loads(text)
        except json.JSONDecodeError as e:
            logger.error("Failed to parse LLM response: %s\nRaw: %s", e, cleaned)
            return ExtractionResult(confidence=0.0, raw_response=cleaned)

        if not isinstance(data, dict):
            logger.error("LLM response is not a JSON object\nRaw: %s", cleaned)
            return ExtractionResult(confidence=0.0, raw_response=cleaned)

        # Keep only entities that carry a value; tolerate "entities": null.
        raw_entities = data.get("entities")
        if not isinstance(raw_entities, dict):
            raw_entities = {}
        entities = {k: v for k, v in raw_entities.items() if v is not None}

        confidence = data.get("confidence", 0.5)
        if isinstance(confidence, bool) or not isinstance(confidence, (int, float)):
            confidence = 0.5

        changed_fields = data.get("changed_fields")
        if not isinstance(changed_fields, list):
            changed_fields = []

        return ExtractionResult(
            entities=entities,
            intent=data.get("intent") or "unclear",
            is_noise=data.get("is_noise") or False,
            is_change=data.get("is_change") or False,
            changed_fields=changed_fields,
            confidence=confidence,
            raw_response=cleaned,
        )
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
"""Regex-based entity extractor (fallback)"""
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
from busbot_memory.extractors.base import BaseExtractor
|
|
7
|
+
from busbot_memory.core.models import ExtractionResult
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# Vietnam provinces/cities, as one regex alternation. Order is significant:
# alternatives are tried left to right.
LOCATIONS = (
    r"Hà Nội|Hải Phòng|Hải Dương|Nam Định|Ninh Bình|Thái Bình|Vĩnh Phúc|Bắc Ninh|"
    r"Bắc Giang|Quảng Ninh|Lạng Sơn|Cao Bằng|Hòa Bình|Sơn La|Lai Châu|Điện Biên|"
    r"Lào Cai|Yên Bái|Phú Thọ|Tuyên Quang|Thái Nguyên|Hà Giang|Sài Gòn|TP HCM|"
    r"TP\.?\s*HCM|Hồ Chí Minh|Đà Nẵng|Huế|Quảng Nam|Quảng Ngãi|Bình Định|Phú Yên|"
    r"Khánh Hòa|Nha Trang|Ninh Thuận|Bình Thuận|Đà Lạt|Lâm Đồng|Bình Phước|Tây Ninh|"
    r"Bình Dương|Đồng Nai|Vũng Tàu|Bà Rịa|Long An|Tiền Giang|Mỹ Tho|Bến Tre|Trà Vinh|"
    r"Vĩnh Long|Cần Thơ|Đồng Tháp|An Giang|Kiên Giang|Hậu Giang|Sóc Trăng|Bạc Liêu|"
    r"Cà Mau|Nghệ An|Vinh|Hà Tĩnh|Quảng Bình|Quảng Trị|Kon Tum|Gia Lai|Đắk Lắk|"
    r"Đắk Nông|Buôn Ma Thuột"
)

# Entity patterns. Each one exposes the value in capture group 1 ("time"
# additionally captures the minutes in group 2, "quantity" the unit).
PATTERNS = {
    "location": rf"\b({LOCATIONS})\b",
    "time": r"\b(\d{1,2})\s*(?:h|giờ|:|g)\s*(\d{0,2})\b",
    # The numeric form is fenced with digit look-arounds so it cannot start
    # or stop in the middle of a longer digit run; without them a dashed
    # phone number such as 0987-654-321 produced the bogus date "87-65".
    "date": r"(ngày mai|hôm nay|hôm qua|ngày kia|(?<!\d)\d{1,2}[/-]\d{1,2}(?:[/-]\d{2,4})?(?!\d))",
    "phone": r"\b(0\d{9}|0\d{3}[\s.-]?\d{3}[\s.-]?\d{3,4})\b",
    "quantity": r"\b(\d+)\s*(vé|ghế|chỗ|người|suất)\b",
    "seat_type": r"\b(ghế ngồi|giường nằm|limousine|vip|thường)\b",
}

# Messages that consist of nothing but a filler, greeting or thanks. Every
# pattern is anchored at both ends: an unanchored greeting pattern would
# also match real requests that merely begin with "hi"/"chào"
# (e.g. "hiện tại tôi muốn đặt vé").
NOISE_PATTERNS = [
    r"^(ừ|uh|à|ờ|dạ|vâng|ok|được|rồi|nhé|nha)\s*$",
    r"^(xin chào|chào|hello|hi|alo)\s*$",
    r"^(cảm ơn|thanks|thank you)\s*$",
]

# Keywords associated with each intent label.
INTENT_KEYWORDS = {
    "book_ticket": ["đặt", "mua", "lấy", "book"],
    "cancel": ["hủy", "cancel", "không đi nữa"],
    "reschedule": ["đổi", "chuyển", "thay đổi", "dời"],
    "inquiry": ["hỏi", "giá", "mấy giờ", "còn không", "bao nhiêu"],
    "complaint": ["khiếu nại", "phàn nàn", "tệ", "chán"],
    "confirm": ["xác nhận", "đồng ý", "ok", "được"],
}
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class RegexExtractor(BaseExtractor):
    """
    Regex-based entity extractor.

    Used as fallback when LLM is unavailable or fails.
    Lower accuracy but zero latency and no API costs.
    """

    async def extract(
        self,
        message: str,
        context: Optional[str] = None
    ) -> ExtractionResult:
        """Extract entities using regex patterns.

        Args:
            message: Raw user utterance to scan.
            context: Accepted for interface compatibility; this extractor
                does not read it.

        Returns:
            An ExtractionResult holding the entities found, the keyword-based
            intent, and the noise flag. ``is_change`` is always False and
            ``changed_fields`` always empty; confidence is 0.7 when at least
            one entity was found and 0.3 otherwise.
        """
        entities = {}
        lowered = message.lower()

        # Locations: the first hit is taken as the destination and the second
        # as the departure point (simplified logic, no "from/to" parsing).
        locations = re.findall(PATTERNS["location"], message, re.IGNORECASE)
        if locations:
            entities["destination"] = locations[0]
            if len(locations) > 1:
                entities["departure"] = locations[1]

        # Time
        time_value = self._find_time(message, lowered)
        if time_value is not None:
            entities["time"] = time_value

        # Date (kept as the raw matched text, no normalisation)
        date_match = re.search(PATTERNS["date"], message, re.IGNORECASE)
        if date_match:
            entities["date"] = date_match.group(1)

        # Phone: the pattern allows whitespace, dots and hyphens between digit
        # groups, so strip every non-digit rather than only " " and "-".
        phone_match = re.search(PATTERNS["phone"], message)
        if phone_match:
            entities["customer_phone"] = re.sub(r"\D", "", phone_match.group(1))

        # Quantity: case-insensitive like the other searches, so "2 VÉ" counts.
        qty_match = re.search(PATTERNS["quantity"], message, re.IGNORECASE)
        if qty_match:
            entities["quantity"] = int(qty_match.group(1))

        # Seat type
        seat_match = re.search(PATTERNS["seat_type"], message, re.IGNORECASE)
        if seat_match:
            entities["seat_type"] = seat_match.group(1).lower()

        # A message is noise when it matches any filler pattern.
        stripped = message.strip()
        is_noise = any(
            re.match(pattern, stripped, re.IGNORECASE)
            for pattern in NOISE_PATTERNS
        )

        return ExtractionResult(
            entities=entities,
            intent=self._detect_intent(message),
            is_noise=is_noise,
            is_change=False,  # Regex can't detect change-of-mind reliably
            changed_fields=[],
            confidence=0.7 if entities else 0.3,
        )

    def _find_time(self, message: str, lowered: str) -> Optional[str]:
        """Return the first plausible clock time in *message* as ``HH:MM``.

        Args:
            message: Original message text.
            lowered: Lower-cased copy of the message, used for the
                afternoon/evening keyword check so capitalised input works.

        Returns:
            The formatted time, or None when no match yields a valid
            hour (0-23) and minute (0-59).
        """
        # "chiều" (afternoon) / "tối" (evening) shift hours below 12 by twelve.
        is_pm = "chiều" in lowered or "tối" in lowered

        for match in re.finditer(PATTERNS["time"], message, re.IGNORECASE):
            hour = int(match.group(1))
            minute = int(match.group(2) or 0)
            if is_pm and hour < 12:
                hour += 12
            # Skip impossible values such as "25h" or "8h75" and keep looking.
            if hour <= 23 and minute <= 59:
                return f"{hour:02d}:{minute:02d}"
        return None

    def _detect_intent(self, message: str) -> str:
        """Simple keyword-based intent detection.

        Returns the first intent in INTENT_KEYWORDS that has a keyword occurring
        as a substring of the lower-cased message, or "unclear" if none does.
        """
        message_lower = message.lower()

        for intent, keywords in INTENT_KEYWORDS.items():
            if any(kw in message_lower for kw in keywords):
                return intent

        return "unclear"
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Memory module - placeholder for Phase 2"""
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
"""State management for booking flow"""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from typing import Optional, Tuple, List
|
|
5
|
+
|
|
6
|
+
from busbot_memory.core.models import BookingState, ExtractionResult
|
|
7
|
+
|
|
8
|
+
logger = logging.getLogger(__name__)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class StateManager:
    """
    Manages booking state updates.

    Handles:
    - Merging new entities into state
    - Tracking changes (change-of-mind)
    - Managing missing slots
    """

    # Default slots to track for bus booking
    DEFAULT_SLOTS = [
        "departure",
        "destination",
        "date",
        "time",
        "quantity",
        "seat_type",
        "customer_name",
        "customer_phone",
    ]

    # Slots that must be filled for each intent; any other intent (including
    # an unset one) only requires a destination.
    _REQUIRED_BY_INTENT = {
        "book_ticket": [
            "destination", "date", "time", "quantity",
            "customer_name", "customer_phone",
        ],
        "reschedule": ["date", "time"],
        "cancel": ["customer_phone"],
    }
    _DEFAULT_REQUIRED = ["destination"]

    def __init__(self, slot_types: Optional[List[str]] = None):
        """Create a manager.

        Args:
            slot_types: Slot names to track; DEFAULT_SLOTS when omitted or empty.
        """
        # Copy so that mutating self.slot_types can never alter the shared
        # class-level DEFAULT_SLOTS list (or the caller's own list).
        # NOTE(review): no method in this class reads slot_types yet.
        self.slot_types = list(slot_types or self.DEFAULT_SLOTS)

    def update(
        self,
        current_state: BookingState,
        extraction: ExtractionResult
    ) -> Tuple[BookingState, List[str]]:
        """
        Update state with extracted entities.

        The state object is modified in place and also returned.

        Args:
            current_state: Current booking state
            extraction: Extraction result from LLM/regex

        Returns:
            Tuple of (updated_state, list_of_changes)
        """
        # Only overwrite the intent when the extractor actually recognised one.
        if extraction.intent and extraction.intent != "unclear":
            current_state.intent = extraction.intent

        # Merge the entities; the state reports what it changed.
        changes = current_state.update_slots(extraction.entities)

        # Add change-of-mind fields flagged by the extractor that the slot
        # merge did not already report. Each change entry is "<field>: ...",
        # so the text before the first colon identifies the field.
        if extraction.is_change and extraction.changed_fields:
            reported = {change.split(":")[0] for change in changes}
            for field in extraction.changed_fields:
                if field not in reported:
                    changes.append(f"{field}: changed")
                    reported.add(field)

        # Recompute which required slots are still unfilled.
        current_state.missing_slots = self._calculate_missing_slots(current_state)

        logger.debug(
            "State updated: %d changes, %d missing",
            len(changes),
            len(current_state.missing_slots),
        )

        return current_state, changes

    def _calculate_missing_slots(self, state: BookingState) -> List[str]:
        """Calculate which slots are still missing.

        The required slots depend on ``state.intent``; a slot counts as filled
        as soon as its name is present in ``state.slots``.
        """
        required = self._REQUIRED_BY_INTENT.get(state.intent, self._DEFAULT_REQUIRED)
        return [slot for slot in required if slot not in state.slots]

    def create_initial_state(self) -> BookingState:
        """Create a fresh booking state with its missing slots pre-computed."""
        return BookingState(
            missing_slots=self._calculate_missing_slots(BookingState())
        )
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Storage module - placeholder for Phase 2"""
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Utils module - placeholder for Phase 2"""
|
busbot_memory/version.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: busbot-memory
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: LLM-powered working memory for Vietnamese bus booking bots
|
|
5
|
+
Author-email: QuocAnh <quocanhnguyen.work@gmail.com>
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/biva-ai/busbot-memory
|
|
8
|
+
Project-URL: Documentation, https://github.com/biva-ai/busbot-memory#readme
|
|
9
|
+
Project-URL: Repository, https://github.com/biva-ai/busbot-memory.git
|
|
10
|
+
Project-URL: Issues, https://github.com/biva-ai/busbot-memory/issues
|
|
11
|
+
Keywords: llm,memory,chatbot,bus-booking,vietnamese,groq,voice-bot
|
|
12
|
+
Classifier: Development Status :: 4 - Beta
|
|
13
|
+
Classifier: Intended Audience :: Developers
|
|
14
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
15
|
+
Classifier: Programming Language :: Python :: 3
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
19
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
20
|
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
21
|
+
Requires-Python: >=3.10
|
|
22
|
+
Description-Content-Type: text/markdown
|
|
23
|
+
Requires-Dist: groq>=0.4.0
|
|
24
|
+
Requires-Dist: pydantic>=2.0.0
|
|
25
|
+
Provides-Extra: redis
|
|
26
|
+
Requires-Dist: redis>=5.0.0; extra == "redis"
|
|
27
|
+
Provides-Extra: openai
|
|
28
|
+
Requires-Dist: openai>=1.0.0; extra == "openai"
|
|
29
|
+
Provides-Extra: dev
|
|
30
|
+
Requires-Dist: pytest>=7.0.0; extra == "dev"
|
|
31
|
+
Requires-Dist: pytest-asyncio>=0.21.0; extra == "dev"
|
|
32
|
+
Requires-Dist: ruff>=0.1.0; extra == "dev"
|
|
33
|
+
Provides-Extra: all
|
|
34
|
+
Requires-Dist: redis>=5.0.0; extra == "all"
|
|
35
|
+
Requires-Dist: openai>=1.0.0; extra == "all"
|
|
36
|
+
|
|
37
|
+
# BusBotMemory SDK
|
|
38
|
+
|
|
39
|
+
LLM-powered working memory for Vietnamese bus booking bots.
|
|
40
|
+
|
|
41
|
+
## Installation
|
|
42
|
+
|
|
43
|
+
```bash
|
|
44
|
+
pip install busbot-memory
|
|
45
|
+
```
|
|
46
|
+
|
|
47
|
+
Or install from source:
|
|
48
|
+
```bash
|
|
49
|
+
cd busbot-memory
|
|
50
|
+
pip install -e .
|
|
51
|
+
```
|
|
52
|
+
|
|
53
|
+
## Quick Start
|
|
54
|
+
|
|
55
|
+
```python
|
|
56
|
+
import asyncio
|
|
57
|
+
from busbot_memory import BusBotMemory, BusBotConfig
|
|
58
|
+
|
|
59
|
+
async def main():
|
|
60
|
+
# Configure (set GROQ_API_KEY env var or pass directly)
|
|
61
|
+
config = BusBotConfig(groq_api_key="gsk_xxx")
|
|
62
|
+
|
|
63
|
+
# Initialize memory for a session
|
|
64
|
+
memory = BusBotMemory(
|
|
65
|
+
session_id="call_001",
|
|
66
|
+
customer_id="0987654321",
|
|
67
|
+
config=config
|
|
68
|
+
)
|
|
69
|
+
|
|
70
|
+
# Process messages
|
|
71
|
+
result = await memory.process("đặt 2 vé đi đà nẵng ngày mai 8h sáng")
|
|
72
|
+
|
|
73
|
+
print(result.state.slots)
|
|
74
|
+
# {"destination": "Đà Nẵng", "date": "ngày mai", "time": "08:00", "quantity": 2}
|
|
75
|
+
|
|
76
|
+
asyncio.run(main())
|
|
77
|
+
```
|
|
78
|
+
|
|
79
|
+
## Features
|
|
80
|
+
|
|
81
|
+
- **LLM-powered extraction**: Uses Groq (llama-3.3-70b) for accurate entity extraction
|
|
82
|
+
- **Change-of-mind detection**: Automatically detects when user changes their booking
|
|
83
|
+
- **State tracking**: Maintains structured booking state with missing slot tracking
|
|
84
|
+
- **Low latency**: Optimized for < 250ms processing time
|
|
85
|
+
- **Fallback support**: Falls back to regex when LLM is unavailable
|
|
86
|
+
|
|
87
|
+
## Configuration
|
|
88
|
+
|
|
89
|
+
```python
|
|
90
|
+
config = BusBotConfig(
|
|
91
|
+
# LLM Provider (at least one required)
|
|
92
|
+
groq_api_key="gsk_xxx", # Primary - fast & free
|
|
93
|
+
openai_api_key="sk-xxx", # Optional fallback
|
|
94
|
+
|
|
95
|
+
# Performance
|
|
96
|
+
latency_target_ms=250,
|
|
97
|
+
enable_fallback=True, # Use regex if LLM fails
|
|
98
|
+
|
|
99
|
+
# Memory settings
|
|
100
|
+
max_working_items=20,
|
|
101
|
+
max_context_window=5,
|
|
102
|
+
)
|
|
103
|
+
```
|
|
104
|
+
|
|
105
|
+
## ProcessResult
|
|
106
|
+
|
|
107
|
+
```python
|
|
108
|
+
result = await memory.process(message)
|
|
109
|
+
|
|
110
|
+
result.entities # Extracted entities
|
|
111
|
+
result.state # BookingState object
|
|
112
|
+
result.is_noise # Is filler message ("ừ", "ok")
|
|
113
|
+
result.is_change # Did user change their mind
|
|
114
|
+
result.changes # List of changes made
|
|
115
|
+
result.intent # Detected intent
|
|
116
|
+
result.latency_ms # Processing time
|
|
117
|
+
```
|
|
118
|
+
|
|
119
|
+
## License
|
|
120
|
+
|
|
121
|
+
MIT
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
busbot_memory/__init__.py,sha256=dB6vdBKqfcsq3u_lmHND8eg7M9u1T2OtMPtzvxKVbXQ,436
|
|
2
|
+
busbot_memory/version.py,sha256=FtxHurtf6YcDi4rEA3wEUJiwTKc8mZ9jDZQbq6lQxMg,48
|
|
3
|
+
busbot_memory/core/__init__.py,sha256=feszx5rF9gi4eSP9dlgeN2a-xTfwgdS0ZM3tQNQ9KG8,381
|
|
4
|
+
busbot_memory/core/config.py,sha256=Q-ffA9lM_7zVHep5qMIy8AfhjIGSWiyUW2wE0NGmPCM,1715
|
|
5
|
+
busbot_memory/core/manager.py,sha256=djJJ3C-DbxZqDHiRL-vYnzGyM9xZvSdhfZ8_x1PhxQE,8094
|
|
6
|
+
busbot_memory/core/models.py,sha256=8VM-xSBos839arFjp7vKRJsK3e-rMnZrjNU46nmwX7A,6211
|
|
7
|
+
busbot_memory/domains/__init__.py,sha256=Z7zCVBmCcXi56N5CMnM3KCvYm7TjCvRnv1i0zOrhjlU,47
|
|
8
|
+
busbot_memory/extractors/__init__.py,sha256=Z_lj6BkF9P-6PQPgijQayrkMR7m3jp3Emgx1Pf8L20I,264
|
|
9
|
+
busbot_memory/extractors/base.py,sha256=iBKnOXRo9WW1R-2kqtKkjs2n_YFtD6paoar7mV_AJLM,671
|
|
10
|
+
busbot_memory/extractors/llm.py,sha256=G-qkv6IuVZWfAzo4l6DR5o8fsuSd1UUw5auln-qltdU,6847
|
|
11
|
+
busbot_memory/extractors/regex.py,sha256=AMTzQcrDlkJsYzOSJQNeM1P8Vma6dViRFrsapBgnz_U,4933
|
|
12
|
+
busbot_memory/memory/__init__.py,sha256=Pirbd_SpJaBnt1l7tpXT86QHzrhQ--Qg2gJuP_jHnQU,46
|
|
13
|
+
busbot_memory/state/__init__.py,sha256=O-f2tTE2NhyhFv-GT7Iytb1w_SpPARLbT894QqU0u0c,109
|
|
14
|
+
busbot_memory/state/manager.py,sha256=E_KP3RNCTA4HneYTpuuk-M7DroDIFlJY-UtuEdsRaWM,2915
|
|
15
|
+
busbot_memory/storage/__init__.py,sha256=_iMe3AwItiW2RoimbzphMsSwqccX6GyAZFAqh1UEeuo,47
|
|
16
|
+
busbot_memory/utils/__init__.py,sha256=7wzd5GIIgAfEo9u9lRTpU6FdIKdQslbNHfc_4Cce_9Q,45
|
|
17
|
+
busbot_memory-0.1.0.dist-info/METADATA,sha256=HavIxnEJ4BInBd6Di_uYlS7-lXnHVXKdma7viWLqgAU,3628
|
|
18
|
+
busbot_memory-0.1.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
19
|
+
busbot_memory-0.1.0.dist-info/top_level.txt,sha256=s1GaBVqTS_djIs8JsgAg6-Z2vOIkCY4NRM-avkk-M04,14
|
|
20
|
+
busbot_memory-0.1.0.dist-info/RECORD,,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
busbot_memory
|