agentkit-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentkit/__init__.py +3 -0
- agentkit/cli/__init__.py +0 -0
- agentkit/cli/build.py +54 -0
- agentkit/cli/deploy.py +50 -0
- agentkit/cli/init.py +85 -0
- agentkit/cli/main.py +22 -0
- agentkit/cli/serve.py +35 -0
- agentkit/core/__init__.py +3 -0
- agentkit/core/context.py +75 -0
- agentkit/core/pipeline.py +132 -0
- agentkit/core/turn_detector.py +49 -0
- agentkit/learning/__init__.py +2 -0
- agentkit/learning/correction.py +109 -0
- agentkit/learning/recommender.py +76 -0
- agentkit/memory/__init__.py +5 -0
- agentkit/memory/base.py +114 -0
- agentkit/memory/extractor.py +79 -0
- agentkit/memory/markdown.py +109 -0
- agentkit/memory/user_model.py +1 -0
- agentkit/memory/vector.py +96 -0
- agentkit/providers/llm/__init__.py +5 -0
- agentkit/providers/llm/base.py +36 -0
- agentkit/providers/llm/gemini.py +59 -0
- agentkit/providers/llm/openai.py +59 -0
- agentkit/providers/stt/__init__.py +5 -0
- agentkit/providers/stt/base.py +21 -0
- agentkit/providers/stt/deepgram.py +53 -0
- agentkit/providers/stt/sarvam.py +45 -0
- agentkit/providers/tts/__init__.py +5 -0
- agentkit/providers/tts/base.py +21 -0
- agentkit/providers/tts/elevenlabs.py +40 -0
- agentkit/providers/tts/sarvam.py +43 -0
- agentkit/server/__init__.py +1 -0
- agentkit/server/app.py +189 -0
- agentkit_sdk-0.1.0.dist-info/METADATA +108 -0
- agentkit_sdk-0.1.0.dist-info/RECORD +39 -0
- agentkit_sdk-0.1.0.dist-info/WHEEL +5 -0
- agentkit_sdk-0.1.0.dist-info/entry_points.txt +2 -0
- agentkit_sdk-0.1.0.dist-info/top_level.txt +1 -0
agentkit/__init__.py
ADDED
agentkit/cli/__init__.py
ADDED
|
File without changes
|
agentkit/cli/build.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import shutil
|
|
3
|
+
import subprocess
|
|
4
|
+
import sys
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
import click
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@click.command()
@click.option("--config", default="agent.config.yaml", help="Config file path")
@click.option("--output", default="build/", help="Output directory")
def android(config: str, output: str):
    """Build Android APK from AgentShell template."""
    # The docstring above is click's --help text for this command.
    if not Path(config).exists():
        click.echo(f"Error: Config file '{config}' not found.")
        sys.exit(1)

    click.echo("Building Android APK...")

    import yaml

    with open(config, encoding="utf-8") as f:
        # safe_load returns None for an empty file; treat that as "no settings".
        cfg = yaml.safe_load(f) or {}

    # `or {}` also covers a section key that is present but empty (YAML null).
    agent_cfg = cfg.get("agent") or {}
    deployment_cfg = cfg.get("deployment") or {}
    agent_name = agent_cfg.get("name", "Agent")
    backend_url = deployment_cfg.get("backend_url", "ws://localhost:8000/ws/voice")

    template_dir = Path(__file__).parent.parent / "mobile" / "AgentShell"
    if not template_dir.exists():
        click.echo("Error: AgentShell template not found.")
        sys.exit(1)

    project_dir = Path(output)
    project_dir.mkdir(parents=True, exist_ok=True)
    app_dir = project_dir / "AgentShell"
    shutil.copytree(template_dir, app_dir, dirs_exist_ok=True)

    # Patch the copied template with the configured backend URL and agent name.
    app_js = app_dir / "App.jsx"
    if app_js.exists():
        content = app_js.read_text(encoding="utf-8")
        content = content.replace("ws://localhost:8000/ws/voice", backend_url)
        content = content.replace("Assistant", agent_name)
        app_js.write_text(content, encoding="utf-8")

    # Run the toolchain inside the app directory via `cwd=` instead of
    # os.chdir(), so the caller's working directory is left untouched.
    try:
        click.echo("Installing dependencies...")
        subprocess.run(["npm", "install"], check=True, cwd=app_dir)

        click.echo("Building Android APK...")
        subprocess.run(["npx", "expo", "run:android"], check=True, cwd=app_dir)
    except FileNotFoundError as exc:
        click.echo(f"Error: required tool not found: {exc.filename}")
        sys.exit(1)
    except subprocess.CalledProcessError as exc:
        click.echo(f"Error: command failed: {' '.join(exc.cmd)}")
        sys.exit(exc.returncode or 1)

    # NOTE(review): confirm that `expo run:android` produces a release build;
    # the path reported below assumes the release variant.
    apk_dir = app_dir / "android" / "app" / "build" / "outputs" / "apk" / "release"
    click.echo(f"APK built at: {apk_dir}")
|
agentkit/cli/deploy.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
import subprocess
|
|
2
|
+
import sys
|
|
3
|
+
|
|
4
|
+
import click
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
# Targets accepted by `agentkit deploy --platform`; each has a _deploy_<name> helper.
PLATFORMS = [
    "railway",
    "render",
]


@click.command()
@click.option(
    "--platform",
    type=click.Choice(PLATFORMS),
    required=True,
    help="Deployment platform",
)
@click.option("--config", default="agent.config.yaml", help="Config file path")
def deploy(platform: str, config: str):
    """Deploy backend to Railway or Render."""
    # `config` is accepted on the command line but is not read by this command.
    click.echo(f"Deploying to {platform}...")

    # Dispatch table: platform name -> deployment helper.
    deployer = {"railway": _deploy_railway, "render": _deploy_render}.get(platform)
    if deployer is not None:
        deployer()
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _deploy_railway():
    """Log in to Railway, initialise a project and upload the current directory.

    Exits the process with status 1 if the Railway CLI is missing or broken,
    or if any Railway command fails.
    """
    install_hint = "npm install -g @railway/cli"

    # Probe the CLI first, so that "not installed" is reported separately
    # from "login failed".
    try:
        probe = subprocess.run(["railway", "--version"], capture_output=True)
    except FileNotFoundError:
        click.echo(f"Railway CLI not found. Install with: {install_hint}")
        sys.exit(1)
    if probe.returncode != 0:
        click.echo(f"Please install Railway CLI: {install_hint}")
        sys.exit(1)

    try:
        # `railway login` is interactive (prompt / browser URL), so its output
        # must reach the terminal instead of being captured.
        subprocess.run(["railway", "login"], check=True)
        subprocess.run(["railway", "init"], check=True)
        subprocess.run(["railway", "up"], check=True)
    except subprocess.CalledProcessError as exc:
        click.echo(
            f"Railway command failed: {' '.join(exc.cmd)} (exit code {exc.returncode})"
        )
        sys.exit(1)

    click.echo("Deployed to Railway! Run 'railway link' to connect.")
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _deploy_render():
    """Run the Render CLI's blueprint setup for the current project.

    Exits the process with status 1 if the Render CLI is missing or broken,
    or if the blueprint command fails.
    """
    install_hint = "npm install -g render-cli"

    try:
        probe = subprocess.run(["render", "--version"], capture_output=True)
    except FileNotFoundError:
        click.echo(f"Render CLI not found. Install with: {install_hint}")
        sys.exit(1)
    if probe.returncode != 0:
        click.echo(f"Please install Render CLI: {install_hint}")
        sys.exit(1)

    try:
        subprocess.run(["render", "blueprint", "init"], check=True)
    except subprocess.CalledProcessError as exc:
        click.echo(
            f"Render command failed: {' '.join(exc.cmd)} (exit code {exc.returncode})"
        )
        sys.exit(1)

    # NOTE(review): confirm that `render blueprint init` actually triggers a
    # deploy; the success message below assumes it does.
    click.echo("Deployed to Render! Check your dashboard for the live URL.")
|
agentkit/cli/init.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import shutil
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
import click
|
|
6
|
+
import yaml
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@click.command()
@click.option("--name", default="my-agent", help="Agent name")
@click.option("--output-dir", default=".", help="Output directory")
def init(name: str, output_dir: str):
    """Initialize a new AgentKit project."""
    # The docstring above is click's --help text for this command.
    # Creates <output_dir>/<name>/ with agent.config.yaml, .env.example and src/.
    project_dir = Path(output_dir) / name
    project_dir.mkdir(parents=True, exist_ok=True)

    # Starter configuration. "${...}" values are written verbatim; they are
    # presumably resolved from the environment when the config is loaded
    # (not done here - confirm in the config loader).
    config_content = {
        "agent": {
            "name": name,
            "persona": "You are a helpful personal assistant that remembers conversations and learns from corrections.",
            "language": "hinglish",
        },
        "voice": {
            "enabled": True,
            "stt": {
                "provider": "sarvam",
                "api_key": "${SARVAM_API_KEY}",
            },
            "tts": {
                "provider": "sarvam",
                "voice": "meera",
                "api_key": "${SARVAM_API_KEY}",
            },
        },
        "llm": {
            "provider": "gemini",
            "model": "gemini-2.0-flash",
            "api_key": "${GEMINI_API_KEY}",
            "temperature": 0.7,
        },
        "memory": {
            "type": "markdown",
            "backend": "local",
            "episodic_window": 20,
            "semantic_top_k": 5,
        },
        "learning": {
            "enabled": True,
            "correction_detection": True,
            "implicit_feedback": True,
            "profile_extraction": True,
        },
        "deployment": {
            "type": "self-host",
            "port": 8000,
        },
    }

    # NOTE: existing agent.config.yaml / .env.example files in project_dir are overwritten.
    config_path = project_dir / "agent.config.yaml"
    with open(config_path, "w", encoding="utf-8") as f:
        # sort_keys=False keeps sections in the order declared above; PyYAML's
        # default (sort_keys=True) would alphabetise the generated template.
        yaml.dump(config_content, f, default_flow_style=False, sort_keys=False)

    env_example = """# API Keys
SARVAM_API_KEY=your_sarvam_key
GEMINI_API_KEY=your_gemini_key
DEEPGRAM_API_KEY=your_deepgram_key
ELEVENLABS_API_KEY=your_elevenlabs_key
OPENAI_API_KEY=your_openai_key
"""
    (project_dir / ".env.example").write_text(env_example, encoding="utf-8")

    src_dir = project_dir / "src"
    src_dir.mkdir(exist_ok=True)
    (src_dir / "__init__.py").touch()

    click.echo(f"AgentKit project '{name}' created at {project_dir}")
    click.echo("Next steps:")
    click.echo(f" 1. cd {project_dir}")
    click.echo(" 2. cp .env.example .env and add your API keys")
    click.echo(" 3. agentkit serve")
|
agentkit/cli/main.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
import click
|
|
2
|
+
|
|
3
|
+
from agentkit.cli.init import init
|
|
4
|
+
from agentkit.cli.serve import serve
|
|
5
|
+
from agentkit.cli.build import android
|
|
6
|
+
from agentkit.cli.deploy import deploy
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@click.group()
def cli():
    """AgentKit - Build personalized voice AI assistants"""


# Attach every sub-command to the top-level `agentkit` group.
for _command in (init, serve, android, deploy):
    cli.add_command(_command)
del _command


if __name__ == "__main__":
    cli()
|
agentkit/cli/serve.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
import click
|
|
6
|
+
from dotenv import load_dotenv
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@click.command()
@click.option("--config", default="agent.config.yaml", help="Config file path")
@click.option("--port", default=8000, help="Server port")
@click.option("--reload", is_flag=True, help="Enable auto-reload")
@click.option("--host", default="0.0.0.0", help="Host interface to bind")
def serve(config: str, port: int, reload: bool, host: str = "0.0.0.0"):
    """Start the AgentKit server."""
    # The docstring above is click's --help text for this command.
    if not Path(config).exists():
        click.echo(f"Error: Config file '{config}' not found.")
        click.echo("Run 'agentkit init' first to create a project.")
        sys.exit(1)

    # Load the project's .env from the current working directory. A bare
    # load_dotenv() searches upward from this module's install location
    # (site-packages), not from the cwd, so it would miss the project's .env.
    load_dotenv(Path.cwd() / ".env")

    import uvicorn

    # NOTE(review): `config` is only checked for existence here and is not
    # passed to the app; confirm how agentkit.server.app locates its config.
    display_host = "localhost" if host in ("0.0.0.0", "::") else host
    click.echo(f"Starting AgentKit server on port {port}...")
    click.echo(f"Playground available at http://{display_host}:{port}/playground")

    # The app is passed as an import string (required when reload=True);
    # uvicorn imports it itself, so no eager import is needed here.
    uvicorn.run(
        "agentkit.server.app:app",
        host=host,
        port=port,
        reload=reload,
        env_file=".env",
    )
|
agentkit/core/context.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
@dataclass
class MemoryContext:
    """Memory snippets to inject into the LLM prompt, one string per section."""

    user_profile_summary: str = ""
    semantic_episodes: str = ""
    correction_rules: str = ""
    recent_turns: str = ""

    def assemble(self) -> str:
        """Render non-empty sections as "Title:\\nbody" blocks separated by blank lines."""
        sections = (
            ("User Profile", self.user_profile_summary),
            ("Relevant History", self.semantic_episodes),
            ("Correction Rules", self.correction_rules),
            ("Recent Conversation", self.recent_turns),
        )
        return "\n\n".join(f"{title}:\n{body}" for title, body in sections if body)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ContextAssembler:
    """Gather memory for a query and package it as a ``MemoryContext``.

    ``memory`` must provide the async methods ``get_user_model(user_id)``,
    ``retrieve(query, user_id)`` and ``get_recent_turns(user_id, limit=...)``.

    Items coming back from memory (search hits, correction rules, turns) may be
    plain dicts or attribute-style objects. For example,
    ``CorrectionStore.store_correction`` appends ``Correction`` instances to
    ``user_model.correction_rules``. Fields are therefore read through
    ``_field``, which accepts either form.

    Args:
        memory: The memory backend.
        recent_turn_limit: How many recent turns to request and render.
        semantic_top_k: Maximum number of semantic search hits to render.
        max_corrections: Maximum number of correction rules to render.
    """

    def __init__(
        self,
        memory,
        recent_turn_limit: int = 20,
        semantic_top_k: int = 5,
        max_corrections: int = 10,
    ):
        self.memory = memory
        self.recent_turn_limit = recent_turn_limit
        self.semantic_top_k = semantic_top_k
        self.max_corrections = max_corrections

    async def assemble(self, query: str, user_id: str) -> "MemoryContext":
        """Build a MemoryContext for ``query``; sections with no data stay empty."""
        context = MemoryContext()

        user_model = await self.memory.get_user_model(user_id)
        if user_model:
            context.user_profile_summary = self._summarize_user_model(user_model)

        semantic_results = await self.memory.retrieve(query, user_id)
        if semantic_results:
            context.semantic_episodes = self._format_semantic_results(semantic_results)

        if user_model and user_model.correction_rules:
            context.correction_rules = self._format_corrections(user_model.correction_rules)

        recent_turns = await self.memory.get_recent_turns(user_id, limit=self.recent_turn_limit)
        if recent_turns:
            context.recent_turns = self._format_recent_turns(recent_turns)

        return context

    @staticmethod
    def _field(item, name: str) -> str:
        """Read ``name`` from a dict or an attribute-style object; '' when missing or None."""
        if isinstance(item, dict):
            value = item.get(name)
        else:
            value = getattr(item, name, None)
        return "" if value is None else str(value)

    def _summarize_user_model(self, user_model) -> str:
        """One line per populated profile attribute (interests capped at 5, goals at 3)."""
        parts = [f"Name: {user_model.name}"]
        if user_model.communication_style:
            parts.append(f"Style: {user_model.communication_style}")
        if user_model.inferred_interests:
            parts.append(f"Interests: {', '.join(user_model.inferred_interests[:5])}")
        if user_model.stated_goals:
            parts.append(f"Goals: {', '.join(user_model.stated_goals[:3])}")
        return "\n".join(parts)

    def _format_semantic_results(self, results: list) -> str:
        """Bullet list of the top hits' ``content``, each truncated to 200 characters."""
        return "\n".join(
            f"- {self._field(r, 'content')[:200]}" for r in results[: self.semantic_top_k]
        )

    def _format_corrections(self, corrections: list) -> str:
        """Bullet list of correction ``rule`` texts (dicts or ``Correction`` objects)."""
        return "\n".join(
            f"- {self._field(c, 'rule')}" for c in corrections[: self.max_corrections]
        )

    def _format_recent_turns(self, turns: list) -> str:
        """Render turns as alternating "User:"/"Assistant:" lines, newest turns last."""
        lines = []
        for turn in turns:
            lines.append(f"User: {self._field(turn, 'user_message')}")
            lines.append(f"Assistant: {self._field(turn, 'assistant_message')}")
        # Two lines per turn: keep only the newest `recent_turn_limit` turns.
        start = max(len(lines) - 2 * self.recent_turn_limit, 0)
        return "\n".join(lines[start:])
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from typing import AsyncIterator
|
|
4
|
+
|
|
5
|
+
from agentkit.providers.llm import BaseLLM, Message
|
|
6
|
+
from agentkit.providers.stt import BaseSTT
|
|
7
|
+
from agentkit.providers.tts import BaseTTS
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass
class Turn:
    """One completed exchange: the user's message and the assistant's reply."""

    # What the user said (STT transcription) or typed.
    user_message: str
    # The assistant's reply text.
    assistant_message: str
    # Required, no default. Presumably seconds since the epoch (time.time()) - not enforced.
    timestamp: float
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class PipelineConfig:
    """Static settings for a VoicePipeline."""

    # System prompt passed to the LLM on every request.
    persona: str
    language: str = "hinglish"
    # Characters that end a sentence for TTS chunking; '।' is the Devanagari danda.
    # default_factory gives every instance its own fresh list.
    sentence_endings: list[str] = field(default_factory=lambda: list(".!?।"))
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class VoicePipeline:
    """Streaming voice assistant: STT -> LLM -> TTS with sentence-level chunking.

    Keeps the running chat history (``conversation_history``) and a list of
    completed exchanges (``turns``) for the lifetime of the instance.
    """

    def __init__(
        self,
        stt: BaseSTT,
        llm: BaseLLM,
        tts: BaseTTS,
        config: PipelineConfig,
    ):
        self.stt = stt
        self.llm = llm
        self.tts = tts
        self.config = config
        self.conversation_history: list[Message] = []
        self.turns: list[Turn] = []

    async def process_voice_stream(
        self,
        audio_stream: AsyncIterator[bytes],
        memory_context: str = ""
    ) -> AsyncIterator[bytes]:
        """Main streaming pipeline: STT → LLM → TTS with sentence-level streaming.

        The whole audio stream is transcribed first. Nothing is yielded or
        recorded if the transcription is blank.
        """
        transcription = ""
        async for text in self.stt.transcribe_stream(audio_stream):
            transcription += text

        if not transcription.strip():
            return

        async for audio_chunk in self._respond(transcription, memory_context):
            yield audio_chunk

    async def process_text(
        self,
        user_message: str,
        memory_context: str = ""
    ) -> AsyncIterator[bytes]:
        """Process text input and stream audio response."""
        async for audio_chunk in self._respond(user_message, memory_context):
            yield audio_chunk

    async def _respond(self, user_message: str, memory_context: str) -> AsyncIterator[bytes]:
        """Run one LLM turn for ``user_message``, speak it sentence by sentence, and record it."""
        import time

        self.conversation_history.append(Message(role="user", content=user_message))

        llm_stream = self.llm.chat_stream(
            messages=self.conversation_history,
            system=self.config.persona,
            memory_context=memory_context,
        )

        # Every token is kept, so the stored reply is exactly what this turn
        # generated (not a join of earlier assistant messages).
        response_parts: list[str] = []
        sentence_buffer = ""
        async for token in llm_stream:
            response_parts.append(token)
            sentence_buffer += token

            # Flush to TTS at a sentence boundary, but skip very short
            # fragments (< 10 chars) so they are merged with what follows.
            if self._is_sentence_end(sentence_buffer) and len(sentence_buffer) >= 10:
                async for audio_chunk in self.tts.synthesize_stream(sentence_buffer):
                    yield audio_chunk
                sentence_buffer = ""

        # Speak whatever trailing text did not end with a sentence terminator.
        if sentence_buffer.strip():
            async for audio_chunk in self.tts.synthesize_stream(sentence_buffer):
                yield audio_chunk

        assistant_response = "".join(response_parts)
        self.conversation_history.append(Message(role="assistant", content=assistant_response))
        # Turn.timestamp is a required field.
        self.turns.append(
            Turn(
                user_message=user_message,
                assistant_message=assistant_response,
                timestamp=time.time(),
            )
        )

    def _is_sentence_end(self, text: str) -> bool:
        """True if ``text`` (ignoring trailing whitespace) ends with a configured terminator."""
        for ending in self.config.sentence_endings:
            if text.rstrip().endswith(ending):
                return True
        return False

    def _get_full_response(self) -> str:
        """Return every assistant message in the history joined by spaces.

        Note: this covers all past assistant replies, not just the latest one.
        The pipeline itself no longer uses it.
        """
        return " ".join([
            m.content for m in self.conversation_history
            if m.role == "assistant"
        ])

    def get_recent_turns(self, count: int = 20) -> list[Turn]:
        """Return up to the last ``count`` completed turns, oldest first."""
        return self.turns[-count:]

    async def close(self):
        """Close the STT, LLM and TTS providers, in that order."""
        await self.stt.close()
        await self.llm.close()
        await self.tts.close()
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from typing import AsyncIterator
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class TurnDetector:
    """Silence-based end-of-turn detection for 16-bit PCM audio streams.

    Args:
        silence_threshold: Seconds of continuous low-energy audio that end a turn.
        energy_threshold: Normalised mean absolute amplitude (0..1) below which
            a chunk counts as silence.
    """

    # Bytes per sample: audio is unpacked as signed 16-bit little-endian.
    # Mono audio is assumed - TODO confirm with callers.
    _SAMPLE_WIDTH = 2

    def __init__(
        self,
        silence_threshold: float = 1.5,
        energy_threshold: float = 0.01,
    ):
        self.silence_threshold = silence_threshold
        self.energy_threshold = energy_threshold

    async def detect_turn_end(
        self,
        audio_stream: AsyncIterator[bytes],
        sample_rate: int = 16000
    ) -> AsyncIterator[bytes]:
        """Detect when user has finished speaking based on silence.

        Buffers incoming audio. Once the accumulated silence reaches
        ``silence_threshold`` seconds, it yields the buffer minus its last
        20 ms and stops reading. If the stream ends first, the whole buffer is
        yielded.

        Raises:
            ValueError: if ``sample_rate`` is not positive.
        """
        if sample_rate <= 0:
            raise ValueError(f"sample_rate must be positive, got {sample_rate}")

        buffer = bytearray()
        silence_duration = 0.0
        bytes_per_second = sample_rate * self._SAMPLE_WIDTH
        # Size of the 20 ms tail trimmed from the utterance when a turn ends.
        tail_bytes = bytes_per_second // 50

        async for chunk in audio_stream:
            buffer.extend(chunk)

            if self._calculate_energy(bytes(chunk)) < self.energy_threshold:
                # Count the chunk's real duration instead of assuming every
                # incoming chunk is exactly 20 ms long.
                silence_duration += len(chunk) / bytes_per_second
                if silence_duration >= self.silence_threshold:
                    utterance = bytes(buffer[:-tail_bytes])
                    if utterance:
                        yield utterance
                    buffer.clear()
                    break
            else:
                silence_duration = 0.0

        if buffer:
            yield bytes(buffer)

    def _calculate_energy(self, audio_bytes: bytes) -> float:
        """Mean absolute amplitude of little-endian int16 samples, scaled to [0, 1].

        A trailing odd byte is ignored. Input shorter than one sample returns 0.0.
        """
        import struct

        sample_count = len(audio_bytes) // self._SAMPLE_WIDTH
        if sample_count == 0:
            return 0.0
        samples = struct.unpack(
            f"<{sample_count}h", audio_bytes[: sample_count * self._SAMPLE_WIDTH]
        )
        return sum(abs(s) for s in samples) / sample_count / 32768.0
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
import time
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
EXPLICIT_PATTERNS = [
|
|
8
|
+
"that's wrong", "not correct", "actually", "I said", "no wait",
|
|
9
|
+
"wrong", "incorrect", "that's not right", "I meant", "forget that",
|
|
10
|
+
]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
|
|
14
|
+
class DetectedCorrection:
|
|
15
|
+
type: str
|
|
16
|
+
rule: str
|
|
17
|
+
original_intent: str
|
|
18
|
+
corrected_intent: str
|
|
19
|
+
confidence: float
|
|
20
|
+
turn_id: str
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class CorrectionDetector:
|
|
24
|
+
def __init__(self):
|
|
25
|
+
self.last_user_query = ""
|
|
26
|
+
self.last_agent_response_time = 0.0
|
|
27
|
+
|
|
28
|
+
def detect_explicit(self, user_message: str) -> DetectedCorrection | None:
|
|
29
|
+
"""Detect explicit corrections like 'that's wrong', 'not correct'."""
|
|
30
|
+
user_lower = user_message.lower()
|
|
31
|
+
|
|
32
|
+
for pattern in EXPLICIT_PATTERNS:
|
|
33
|
+
if pattern in user_lower:
|
|
34
|
+
return DetectedCorrection(
|
|
35
|
+
type="explicit",
|
|
36
|
+
rule=f"Never: {pattern}",
|
|
37
|
+
original_intent="",
|
|
38
|
+
corrected_intent=user_message,
|
|
39
|
+
confidence=0.9,
|
|
40
|
+
turn_id=str(int(time.time())),
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
return None
|
|
44
|
+
|
|
45
|
+
def detect_implicit(
|
|
46
|
+
self,
|
|
47
|
+
user_message: str,
|
|
48
|
+
response_time: float,
|
|
49
|
+
last_user_query: str | None = None
|
|
50
|
+
) -> DetectedCorrection | None:
|
|
51
|
+
"""Detect implicit corrections - user rephrases after agent response."""
|
|
52
|
+
|
|
53
|
+
if last_user_query is None:
|
|
54
|
+
last_user_query = self.last_user_query
|
|
55
|
+
|
|
56
|
+
if response_time < 4.0 and last_user_query:
|
|
57
|
+
similarity = self._cosine_similarity(user_message, last_user_query)
|
|
58
|
+
|
|
59
|
+
if similarity > 0.8:
|
|
60
|
+
return DetectedCorrection(
|
|
61
|
+
type="implicit",
|
|
62
|
+
rule="User rephrased - original intent was X not Y",
|
|
63
|
+
original_intent=last_user_query,
|
|
64
|
+
corrected_intent=user_message,
|
|
65
|
+
confidence=0.7,
|
|
66
|
+
turn_id=str(int(time.time())),
|
|
67
|
+
)
|
|
68
|
+
|
|
69
|
+
return None
|
|
70
|
+
|
|
71
|
+
def _cosine_similarity(self, text1: str, text2: str) -> float:
|
|
72
|
+
words1 = set(text1.lower().split())
|
|
73
|
+
words2 = set(text2.lower().split())
|
|
74
|
+
|
|
75
|
+
if not words1 or not words2:
|
|
76
|
+
return 0.0
|
|
77
|
+
|
|
78
|
+
intersection = words1 & words2
|
|
79
|
+
union = words1 | words2
|
|
80
|
+
|
|
81
|
+
return len(intersection) / len(union) if union else 0.0
|
|
82
|
+
|
|
83
|
+
def update_context(self, user_message: str):
|
|
84
|
+
self.last_user_query = user_message
|
|
85
|
+
self.last_agent_response_time = time.time()
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class CorrectionStore:
    """Persist detected corrections onto the user's model in memory."""

    def __init__(self, memory):
        self.memory = memory

    async def store_correction(self, correction: DetectedCorrection, user_id: str):
        """Convert ``correction`` to a ``Correction`` and append it to the user's rules.

        Nothing is saved when the memory has no user model for ``user_id``.
        """
        from agentkit.memory.base import Correction

        new_rule = Correction(
            rule=correction.rule,
            original_intent=correction.original_intent,
            corrected_intent=correction.corrected_intent,
            source_turn_id=correction.turn_id,
            created_at=datetime.now(),
        )

        user_model = await self.memory.get_user_model(user_id)
        if not user_model:
            # NOTE(review): the correction is dropped silently when no user model exists yet.
            return

        rules = user_model.correction_rules
        rules.append(new_rule)
        await self.memory.update_user_model(user_id, correction_rules=rules)
|