astraagent 2.25.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.env.template +22 -0
- package/LICENSE +21 -0
- package/README.md +333 -0
- package/astra/__init__.py +15 -0
- package/astra/__pycache__/__init__.cpython-314.pyc +0 -0
- package/astra/__pycache__/chat.cpython-314.pyc +0 -0
- package/astra/__pycache__/cli.cpython-314.pyc +0 -0
- package/astra/__pycache__/prompts.cpython-314.pyc +0 -0
- package/astra/__pycache__/updater.cpython-314.pyc +0 -0
- package/astra/chat.py +763 -0
- package/astra/cli.py +913 -0
- package/astra/core/__init__.py +8 -0
- package/astra/core/__pycache__/__init__.cpython-314.pyc +0 -0
- package/astra/core/__pycache__/agent.cpython-314.pyc +0 -0
- package/astra/core/__pycache__/config.cpython-314.pyc +0 -0
- package/astra/core/__pycache__/memory.cpython-314.pyc +0 -0
- package/astra/core/__pycache__/reasoning.cpython-314.pyc +0 -0
- package/astra/core/__pycache__/state.cpython-314.pyc +0 -0
- package/astra/core/agent.py +515 -0
- package/astra/core/config.py +247 -0
- package/astra/core/memory.py +782 -0
- package/astra/core/reasoning.py +423 -0
- package/astra/core/state.py +366 -0
- package/astra/core/voice.py +144 -0
- package/astra/llm/__init__.py +32 -0
- package/astra/llm/__pycache__/__init__.cpython-314.pyc +0 -0
- package/astra/llm/__pycache__/providers.cpython-314.pyc +0 -0
- package/astra/llm/providers.py +530 -0
- package/astra/planning/__init__.py +117 -0
- package/astra/prompts.py +289 -0
- package/astra/reflection/__init__.py +181 -0
- package/astra/search.py +469 -0
- package/astra/tasks.py +466 -0
- package/astra/tools/__init__.py +17 -0
- package/astra/tools/__pycache__/__init__.cpython-314.pyc +0 -0
- package/astra/tools/__pycache__/advanced.cpython-314.pyc +0 -0
- package/astra/tools/__pycache__/base.cpython-314.pyc +0 -0
- package/astra/tools/__pycache__/browser.cpython-314.pyc +0 -0
- package/astra/tools/__pycache__/file.cpython-314.pyc +0 -0
- package/astra/tools/__pycache__/git.cpython-314.pyc +0 -0
- package/astra/tools/__pycache__/memory_tool.cpython-314.pyc +0 -0
- package/astra/tools/__pycache__/python.cpython-314.pyc +0 -0
- package/astra/tools/__pycache__/shell.cpython-314.pyc +0 -0
- package/astra/tools/__pycache__/web.cpython-314.pyc +0 -0
- package/astra/tools/__pycache__/windows.cpython-314.pyc +0 -0
- package/astra/tools/advanced.py +251 -0
- package/astra/tools/base.py +344 -0
- package/astra/tools/browser.py +93 -0
- package/astra/tools/file.py +476 -0
- package/astra/tools/git.py +74 -0
- package/astra/tools/memory_tool.py +89 -0
- package/astra/tools/python.py +238 -0
- package/astra/tools/shell.py +183 -0
- package/astra/tools/web.py +804 -0
- package/astra/tools/windows.py +542 -0
- package/astra/updater.py +450 -0
- package/astra/utils/__init__.py +230 -0
- package/bin/astraagent.js +73 -0
- package/bin/postinstall.js +25 -0
- package/config.json.template +52 -0
- package/main.py +16 -0
- package/package.json +51 -0
- package/pyproject.toml +72 -0
|
@@ -0,0 +1,366 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Agent State Management for AstraAgent.
|
|
3
|
+
Tracks the current state, history, and context of the agent.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import uuid
|
|
7
|
+
import json
|
|
8
|
+
from datetime import datetime
|
|
9
|
+
from dataclasses import dataclass, field, asdict
|
|
10
|
+
from typing import Optional, Dict, Any, List
|
|
11
|
+
from enum import Enum
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class TaskStatus(Enum):
    """Lifecycle status of a task.

    Values are lowercase strings so they serialize cleanly to JSON via
    ``status.value`` in the ``to_dict`` helpers elsewhere in this module.
    """
    PENDING = "pending"          # not yet started
    IN_PROGRESS = "in_progress"  # currently being worked on
    COMPLETED = "completed"      # finished successfully
    FAILED = "failed"            # finished with an error
    BLOCKED = "blocked"          # waiting on something else
    CANCELLED = "cancelled"      # abandoned before completion
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ActionStatus(Enum):
    """Lifecycle status of a single agent action.

    Values are lowercase strings so they serialize cleanly via
    ``status.value`` in :meth:`Action.to_dict`.
    """
    PENDING = "pending"
    EXECUTING = "executing"
    SUCCESS = "success"
    FAILED = "failed"
    SKIPPED = "skipped"
    REQUIRES_APPROVAL = "requires_approval"


@dataclass
class Action:
    """Represents a single action taken by the agent.

    Records which tool ran, with what arguments and reasoning, plus the
    outcome (result or error) and wall-clock timing. The ``mark_*`` methods
    drive the status transitions; callers should not set timestamps directly.
    """
    # Short random id — first 8 hex chars of a UUID4, enough to reference
    # the action from reflections/logs within one session.
    id: str = field(default_factory=lambda: str(uuid.uuid4())[:8])
    tool: str = ""                                      # name of the tool invoked
    args: Dict[str, Any] = field(default_factory=dict)  # tool arguments
    thought: str = ""                                   # agent reasoning for this action
    result: Optional[Any] = None                        # tool output on success
    error: Optional[str] = None                         # error message on failure
    status: ActionStatus = ActionStatus.PENDING
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    execution_time_ms: Optional[float] = None           # derived; set on completion

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict (enum/datetimes become strings)."""
        data = asdict(self)
        data['status'] = self.status.value
        data['started_at'] = self.started_at.isoformat() if self.started_at else None
        data['completed_at'] = self.completed_at.isoformat() if self.completed_at else None
        return data

    def mark_executing(self):
        """Transition to EXECUTING and record the start time."""
        self.status = ActionStatus.EXECUTING
        self.started_at = datetime.now()

    def mark_success(self, result: Any):
        """Record a successful result and finalize timing."""
        self.status = ActionStatus.SUCCESS
        self.result = result
        self._record_completion()

    def mark_failed(self, error: str):
        """Record a failure message and finalize timing."""
        self.status = ActionStatus.FAILED
        self.error = error
        self._record_completion()

    def _record_completion(self) -> None:
        """Stamp ``completed_at`` and compute elapsed milliseconds.

        Shared by :meth:`mark_success` and :meth:`mark_failed` (the logic was
        previously duplicated in both). Elapsed time is only computed when
        :meth:`mark_executing` was called first.
        """
        self.completed_at = datetime.now()
        if self.started_at:
            self.execution_time_ms = (self.completed_at - self.started_at).total_seconds() * 1000
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@dataclass
class Task:
    """A task or sub-task within the agent's plan.

    Tasks form a tree via ``parent_id``/``subtasks`` and accumulate the
    :class:`Action` objects executed on their behalf.
    """
    id: str = field(default_factory=lambda: str(uuid.uuid4())[:8])
    description: str = ""
    parent_id: Optional[str] = None
    subtasks: List['Task'] = field(default_factory=list)
    actions: List[Action] = field(default_factory=list)
    status: TaskStatus = TaskStatus.PENDING
    priority: int = 5  # 1-10, 1 being highest
    created_at: datetime = field(default_factory=datetime.now)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    result: Optional[str] = None
    error: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize this task (recursively, including subtasks) to a dict."""
        started = self.started_at.isoformat() if self.started_at else None
        finished = self.completed_at.isoformat() if self.completed_at else None
        return {
            'id': self.id,
            'description': self.description,
            'parent_id': self.parent_id,
            'subtasks': [child.to_dict() for child in self.subtasks],
            'actions': [act.to_dict() for act in self.actions],
            'status': self.status.value,
            'priority': self.priority,
            'created_at': self.created_at.isoformat(),
            'started_at': started,
            'completed_at': finished,
            'result': self.result,
            'error': self.error,
            'metadata': self.metadata,
        }

    def add_subtask(self, description: str, priority: int = 5) -> 'Task':
        """Create a child task linked to this one and return it."""
        child = Task(description=description, parent_id=self.id, priority=priority)
        self.subtasks.append(child)
        return child

    def add_action(self, action: Action):
        """Attach an action to this task's execution history."""
        self.actions.append(action)

    @property
    def is_complete(self) -> bool:
        """True once the task has reached a terminal state."""
        return self.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED)

    @property
    def all_subtasks_complete(self) -> bool:
        """True when every child task is terminal (vacuously true if none)."""
        return all(child.is_complete for child in self.subtasks)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
@dataclass
class Plan:
    """The agent's plan for achieving a goal: an ordered list of tasks
    plus bookkeeping (version, timestamps, free-form notes)."""
    id: str = field(default_factory=lambda: str(uuid.uuid4())[:8])
    goal: str = ""
    tasks: List[Task] = field(default_factory=list)
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)
    version: int = 1
    notes: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the plan (including every task) to a JSON-friendly dict."""
        return {
            'id': self.id,
            'goal': self.goal,
            'tasks': [task.to_dict() for task in self.tasks],
            'created_at': self.created_at.isoformat(),
            'updated_at': self.updated_at.isoformat(),
            'version': self.version,
            'notes': self.notes,
        }

    def add_task(self, description: str, priority: int = 5) -> Task:
        """Append a new top-level task and bump the modification timestamp."""
        task = Task(description=description, priority=priority)
        self.tasks.append(task)
        self.updated_at = datetime.now()
        return task

    def get_next_task(self) -> Optional[Task]:
        """Get the next pending task with highest priority.

        Lower numbers mean higher priority; on ties the earliest-added
        pending task wins. Returns None when nothing is pending.
        """
        best: Optional[Task] = None
        for candidate in self.tasks:
            if candidate.status != TaskStatus.PENDING:
                continue
            if best is None or candidate.priority < best.priority:
                best = candidate
        return best

    @property
    def progress(self) -> float:
        """Completion percentage (0.0 for an empty plan)."""
        if not self.tasks:
            return 0.0
        done = sum(1 for task in self.tasks if task.status == TaskStatus.COMPLETED)
        return (done / len(self.tasks)) * 100

    @property
    def is_complete(self) -> bool:
        """True when every task is terminal (vacuously true with no tasks)."""
        return all(task.is_complete for task in self.tasks)
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
@dataclass
class Observation:
    """An observation or insight recorded by the agent, with an importance
    weight and free-form tags for later retrieval."""
    id: str = field(default_factory=lambda: str(uuid.uuid4())[:8])
    content: str = ""
    source: str = ""
    timestamp: datetime = field(default_factory=datetime.now)
    importance: int = 5  # 1-10
    tags: List[str] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict (timestamp as ISO-8601)."""
        payload = {
            'id': self.id,
            'content': self.content,
            'source': self.source,
            'timestamp': self.timestamp.isoformat(),
            'importance': self.importance,
            'tags': self.tags,
        }
        return payload
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
@dataclass
class Reflection:
    """The agent's post-hoc reflection on a single action, linking back to
    it by ``action_id`` and capturing lessons and suggested improvements."""
    id: str = field(default_factory=lambda: str(uuid.uuid4())[:8])
    action_id: str = ""
    thought: str = ""
    success: bool = True
    lessons_learned: List[str] = field(default_factory=list)
    improvements: List[str] = field(default_factory=list)
    timestamp: datetime = field(default_factory=datetime.now)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict (timestamp as ISO-8601)."""
        payload = {
            'id': self.id,
            'action_id': self.action_id,
            'thought': self.thought,
            'success': self.success,
            'lessons_learned': self.lessons_learned,
            'improvements': self.improvements,
            'timestamp': self.timestamp.isoformat(),
        }
        return payload
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
@dataclass
class AgentState:
    """
    Complete state of the AstraAgent.

    Tracks everything the agent needs to operate: the active goal/plan,
    execution history (actions, observations, reflections), aggregate
    metrics, and runtime flags. The whole object can be serialized with
    to_dict()/save() for persistence or inspection.
    """
    # Session info
    session_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    started_at: datetime = field(default_factory=datetime.now)

    # Current goal and plan
    current_goal: Optional[str] = None
    current_plan: Optional[Plan] = None
    current_task: Optional[Task] = None
    current_action: Optional[Action] = None

    # History
    action_history: List[Action] = field(default_factory=list)
    observations: List[Observation] = field(default_factory=list)
    reflections: List[Reflection] = field(default_factory=list)

    # Metrics
    total_actions: int = 0
    successful_actions: int = 0
    failed_actions: int = 0
    total_tokens_used: int = 0
    iteration_count: int = 0

    # Context
    context: Dict[str, Any] = field(default_factory=dict)
    artifacts_created: List[str] = field(default_factory=list)
    errors_encountered: List[str] = field(default_factory=list)

    # Flags
    is_running: bool = False
    is_paused: bool = False
    requires_input: bool = False
    last_error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the full state to a JSON-friendly dict."""
        return {
            'session_id': self.session_id,
            'started_at': self.started_at.isoformat(),
            'current_goal': self.current_goal,
            'current_plan': self.current_plan.to_dict() if self.current_plan else None,
            'current_task': self.current_task.to_dict() if self.current_task else None,
            'current_action': self.current_action.to_dict() if self.current_action else None,
            'action_history': [a.to_dict() for a in self.action_history],
            'observations': [o.to_dict() for o in self.observations],
            'reflections': [r.to_dict() for r in self.reflections],
            'metrics': {
                'total_actions': self.total_actions,
                'successful_actions': self.successful_actions,
                'failed_actions': self.failed_actions,
                'total_tokens_used': self.total_tokens_used,
                'iteration_count': self.iteration_count,
                'success_rate': self.success_rate
            },
            'artifacts_created': self.artifacts_created,
            'errors_encountered': self.errors_encountered,
            'is_running': self.is_running,
            'is_paused': self.is_paused,
            'requires_input': self.requires_input,
            'last_error': self.last_error
        }

    @property
    def success_rate(self) -> float:
        """Percentage of recorded actions that succeeded (0.0 when none)."""
        if self.total_actions == 0:
            return 0.0
        return (self.successful_actions / self.total_actions) * 100

    @property
    def duration(self) -> float:
        """Session duration in seconds."""
        return (datetime.now() - self.started_at).total_seconds()

    def set_goal(self, goal: str):
        """Set the current goal and create a fresh (empty) plan for it."""
        self.current_goal = goal
        self.current_plan = Plan(goal=goal)

    def record_action(self, action: Action):
        """Record an action in history and update the success/failure metrics."""
        self.action_history.append(action)
        self.total_actions += 1
        if action.status == ActionStatus.SUCCESS:
            self.successful_actions += 1
        elif action.status == ActionStatus.FAILED:
            self.failed_actions += 1
            if action.error:
                self.errors_encountered.append(action.error)

    def add_observation(self, content: str, source: str = "", importance: int = 5):
        """Create, store, and return a new Observation."""
        obs = Observation(content=content, source=source, importance=importance)
        self.observations.append(obs)
        return obs

    def add_reflection(self, action_id: str, thought: str, success: bool = True):
        """Create, store, and return a new Reflection on an action."""
        ref = Reflection(action_id=action_id, thought=thought, success=success)
        self.reflections.append(ref)
        return ref

    def add_artifact(self, path: str):
        """Record a created artifact path (deduplicated)."""
        if path not in self.artifacts_created:
            self.artifacts_created.append(path)

    def increment_iteration(self):
        """Increment the agent-loop iteration counter."""
        self.iteration_count += 1

    def save(self, path: str):
        """Save state as JSON to *path*.

        Encoding is pinned to UTF-8 so the output does not depend on the
        platform's default locale encoding (previously it did).
        """
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(self.to_dict(), f, indent=2)

    def get_recent_context(self, n: int = 10) -> str:
        """Get the last *n* actions as a compact context string, one per line."""
        lines = []
        for action in self.action_history[-n:]:
            # The marker already carries its own brackets; the old code wrapped
            # it in a second pair, producing "[[OK]]"/"[[X]]".
            marker = "[OK]" if action.status == ActionStatus.SUCCESS else "[X]"
            lines.append(f"{marker} {action.tool}: {action.thought[:100]}...")
        return "\n".join(lines)

    def get_summary(self) -> str:
        """Get a human-readable multi-line summary of the current state."""
        plan_status = ""
        if self.current_plan:
            plan_status = f"Plan Progress: {self.current_plan.progress:.1f}%"

        return f"""
Session: {self.session_id[:8]}
Goal: {self.current_goal or 'None'}
{plan_status}
Iteration: {self.iteration_count}
Actions: {self.total_actions} (OK:{self.successful_actions} FAIL:{self.failed_actions})
Success Rate: {self.success_rate:.1f}%
Duration: {self.duration:.1f}s
Artifacts: {len(self.artifacts_created)}
Status: {'Running' if self.is_running else 'Stopped'}
""".strip()
|
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
"""
|
|
2
|
+
AstraAgent Voice Engine
|
|
3
|
+
Integrates Qwen3-TTS for high-quality speech synthesis.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import os
|
|
7
|
+
import threading
|
|
8
|
+
import logging
|
|
9
|
+
import tempfile
|
|
10
|
+
import time
|
|
11
|
+
from typing import Optional
|
|
12
|
+
|
|
13
|
+
# Configure logging
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
# Global flag for dependencies
|
|
17
|
+
HAS_DEPENDENCIES = True
|
|
18
|
+
MISSING_DEPS = []
|
|
19
|
+
|
|
20
|
+
try:
|
|
21
|
+
import torch
|
|
22
|
+
except ImportError:
|
|
23
|
+
HAS_DEPENDENCIES = False
|
|
24
|
+
MISSING_DEPS.append("torch")
|
|
25
|
+
|
|
26
|
+
try:
|
|
27
|
+
import soundfile as sf
|
|
28
|
+
except ImportError:
|
|
29
|
+
HAS_DEPENDENCIES = False
|
|
30
|
+
MISSING_DEPS.append("soundfile")
|
|
31
|
+
|
|
32
|
+
try:
|
|
33
|
+
import sounddevice as sd
|
|
34
|
+
except ImportError:
|
|
35
|
+
HAS_DEPENDENCIES = False
|
|
36
|
+
MISSING_DEPS.append("sounddevice")
|
|
37
|
+
|
|
38
|
+
try:
|
|
39
|
+
from qwen_tts import Qwen3TTSModel
|
|
40
|
+
except ImportError:
|
|
41
|
+
HAS_DEPENDENCIES = False
|
|
42
|
+
MISSING_DEPS.append("qwen-tts")
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class VoiceEngine:
    """
    Manages Text-to-Speech generation using Qwen3-TTS.

    The model is loaded lazily on first use (loading is expensive) and the
    load is guarded by a lock so concurrent speak() calls trigger exactly
    one load.
    """

    def __init__(self, model_id: str = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice"):
        self.model_id = model_id
        self.model = None                 # populated lazily by ensure_loaded()
        self.is_loading = False
        self.lock = threading.Lock()
        # Only touch torch when every optional dependency imported cleanly.
        self.device = "cuda" if HAS_DEPENDENCIES and torch.cuda.is_available() else "cpu"

    def is_available(self) -> bool:
        """Check if all optional TTS dependencies are installed."""
        return HAS_DEPENDENCIES

    def get_missing_deps(self) -> list:
        """Get list of missing dependencies.

        Returns a copy so callers cannot mutate the module-level list.
        """
        return list(MISSING_DEPS)

    def ensure_loaded(self):
        """Load the model if not already loaded (idempotent, thread-safe).

        Raises:
            ImportError: if required dependencies are missing.
            Exception: re-raises whatever model loading fails with.
        """
        if self.model is not None:
            return

        with self.lock:
            # Double-check inside the lock: another thread may have finished
            # loading while this one waited.
            if self.model is not None:
                return

            if not self.is_available():
                raise ImportError(f"Missing dependencies: {', '.join(self.get_missing_deps())}")

            logger.info("Loading Qwen3-TTS model: %s on %s...", self.model_id, self.device)
            print(f"Loading AI Voice Model ({self.model_id})... This may take a moment.")

            try:
                # float16 halves memory on GPU; CPU inference needs float32.
                dtype = torch.float16 if self.device == "cuda" else torch.float32

                # Default to eager attention for broader compatibility.
                # flash_attention_2 could be enabled on CUDA when installed.
                attn_impl = "eager"

                self.model = Qwen3TTSModel.from_pretrained(
                    self.model_id,
                    device_map=self.device,
                    dtype=dtype,
                    attn_implementation=attn_impl
                )
                logger.info("Voice model loaded successfully.")
                print("✓ AI Voice Model loaded")
            except Exception as e:
                logger.error("Failed to load voice model: %s", e)
                # Bare raise keeps the original traceback intact
                # ("raise e" would re-anchor it at this line).
                raise

    def speak(self, text: str, speaker: str = "Vivian", language: str = "English", block: bool = False):
        """
        Generate and play speech for the given text.

        Args:
            text: Text to vocalize; empty text is a no-op.
            speaker: Voice preset name passed to the model.
            language: Language hint for synthesis.
            block: If True, synthesize and play synchronously; otherwise the
                work runs on a background thread.
        """
        if not text:
            return

        if block:
            self._speak_internal(text, speaker, language)
        else:
            # NOTE(review): a non-daemon thread keeps the process alive until
            # playback finishes — confirm this is the intended shutdown
            # behavior before changing it.
            threading.Thread(target=self._speak_internal, args=(text, speaker, language)).start()

    def _speak_internal(self, text: str, speaker: str, language: str):
        """Worker: ensure the model is loaded, synthesize, then play audio."""
        try:
            self.ensure_loaded()

            # generate_custom_voice returns (wavs, sample_rate), where wavs is
            # a list of numpy arrays — one per generated utterance.
            wavs, sample_rate = self.model.generate_custom_voice(
                text=text,
                language=language,
                speaker=speaker
            )

            if wavs and len(wavs) > 0:
                audio_data = wavs[0]

                # Play and block this (worker) thread until playback is done.
                sd.play(audio_data, sample_rate)
                sd.wait()

        except Exception as e:
            logger.error("Speech generation error: %s", e)
            print(f"Speech Error: {e}")
|
|
136
|
+
|
|
137
|
+
# Module-level singleton, created on first request.
_engine = None

def get_engine():
    """Return the process-wide VoiceEngine, constructing it on first call."""
    global _engine
    engine = _engine
    if engine is None:
        engine = VoiceEngine()
        _engine = engine
    return engine
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
"""LLM Provider for AstraAgent - Multi-provider support."""
|
|
2
|
+
|
|
3
|
+
# Re-export everything from providers module for backward compatibility
|
|
4
|
+
from astra.llm.providers import (
|
|
5
|
+
Message,
|
|
6
|
+
LLMResponse,
|
|
7
|
+
LLMProvider,
|
|
8
|
+
LocalServerProvider,
|
|
9
|
+
OpenAIProvider,
|
|
10
|
+
GeminiProvider,
|
|
11
|
+
AnthropicProvider,
|
|
12
|
+
OpenRouterProvider,
|
|
13
|
+
GroqProvider,
|
|
14
|
+
PROVIDERS,
|
|
15
|
+
create_provider,
|
|
16
|
+
get_provider_info,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
__all__ = [
|
|
20
|
+
"Message",
|
|
21
|
+
"LLMResponse",
|
|
22
|
+
"LLMProvider",
|
|
23
|
+
"LocalServerProvider",
|
|
24
|
+
"OpenAIProvider",
|
|
25
|
+
"GeminiProvider",
|
|
26
|
+
"AnthropicProvider",
|
|
27
|
+
"OpenRouterProvider",
|
|
28
|
+
"GroqProvider",
|
|
29
|
+
"PROVIDERS",
|
|
30
|
+
"create_provider",
|
|
31
|
+
"get_provider_info",
|
|
32
|
+
]
|
|
Binary file
|
|
Binary file
|