solana-agent 31.1.6__py3-none-any.whl → 31.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,6 +7,7 @@ clean separation of concerns.
 """
 
 import logging
+import asyncio
 import re
 import time
 from typing import (
@@ -64,6 +65,11 @@ class QueryService(QueryServiceInterface):
         # Per-user sticky sessions (in-memory)
         # { user_id: { 'agent': str, 'started_at': float, 'last_updated': float, 'required_complete': bool } }
         self._sticky_sessions: Dict[str, Dict[str, Any]] = {}
+        # Optional realtime service (populated by the factory)
+        self.realtime = None  # type: ignore[attr-defined]
+        # Persistent realtime WS per user for push-to-talk reuse
+        self._rt_services = {}
+        self._rt_lock = asyncio.Lock()
 
     def _get_sticky_agent(self, user_id: str) -> Optional[str]:
         sess = self._sticky_sessions.get(user_id)
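
The new `_rt_services` dict and `_rt_lock` exist so that each user's realtime WebSocket can be created once and reused across push-to-talk turns. A minimal sketch of the get-or-create pattern they enable (the actual creation happens inside `process()` further down in this diff; `_make_realtime_service` is a hypothetical factory, not part of the package):

```python
# Sketch only: get-or-create a per-user realtime service under the shared lock.
async def _get_rt_service(self, user_id: str):
    async with self._rt_lock:  # serialize concurrent first requests for a user
        rt = self._rt_services.get(user_id)
        if rt is None:
            rt = await self._make_realtime_service(user_id)  # hypothetical factory
            self._rt_services[user_id] = rt
        return rt
```
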
@@ -88,9 +94,192 @@ class QueryService(QueryServiceInterface):
         self._sticky_sessions[user_id]["required_complete"] = required_complete
         self._sticky_sessions[user_id]["last_updated"] = time.time()
 
-    def _clear_sticky_agent(self, user_id: str) -> None:
-        if user_id in self._sticky_sessions:
-            del self._sticky_sessions[user_id]
+    async def _build_combined_context(
+        self,
+        user_id: str,
+        user_text: str,
+        agent_name: str,
+        capture_name: Optional[str] = None,
+        capture_schema: Optional[Dict[str, Any]] = None,
+        prev_assistant: str = "",
+    ) -> Tuple[str, bool]:
+        """Build combined context string and return required_complete flag."""
+        # Memory context
+        memory_context = ""
+        if self.memory_provider:
+            try:
+                memory_context = await self.memory_provider.retrieve(user_id)
+            except Exception:
+                memory_context = ""
+
+        # KB context
+        kb_context = ""
+        if self.knowledge_base:
+            try:
+                kb_results = await self.knowledge_base.query(
+                    query_text=user_text,
+                    top_k=self.kb_results_count,
+                    include_content=True,
+                    include_metadata=False,
+                )
+                if kb_results:
+                    kb_lines = [
+                        "**KNOWLEDGE BASE (CRITICAL: MAKE THIS INFORMATION THE TOP PRIORITY):**"
+                    ]
+                    for i, r in enumerate(kb_results, 1):
+                        kb_lines.append(f"[{i}] {r.get('content', '').strip()}\n")
+                    kb_context = "\n".join(kb_lines)
+            except Exception:
+                kb_context = ""
+
+        # Capture context
+        capture_context = ""
+        required_complete = False
+
+        active_capture_name = capture_name
+        active_capture_schema = capture_schema
+        if not active_capture_name or not active_capture_schema:
+            try:
+                cap_cfg = self.agent_service.get_agent_capture(agent_name)
+                if cap_cfg:
+                    active_capture_name = active_capture_name or cap_cfg.get("name")
+                    active_capture_schema = active_capture_schema or cap_cfg.get(
+                        "schema"
+                    )
+            except Exception:
+                pass
+
+        if active_capture_name and isinstance(active_capture_schema, dict):
+            latest_by_name: Dict[str, Dict[str, Any]] = {}
+            if self.memory_provider:
+                try:
+                    docs = self.memory_provider.find(
+                        collection="captures",
+                        query={"user_id": user_id},
+                        sort=[("timestamp", -1)],
+                        limit=100,
+                    )
+                    for d in docs or []:
+                        name = (d or {}).get("capture_name")
+                        if not name or name in latest_by_name:
+                            continue
+                        latest_by_name[name] = {
+                            "data": (d or {}).get("data", {}) or {},
+                            "mode": (d or {}).get("mode", "once"),
+                            "agent": (d or {}).get("agent_name"),
+                        }
+                except Exception:
+                    pass
+
+            active_data = (latest_by_name.get(active_capture_name, {}) or {}).get(
+                "data", {}
+            ) or {}
+
+            def _non_empty(v: Any) -> bool:
+                if v is None:
+                    return False
+                if isinstance(v, str):
+                    s = v.strip().lower()
+                    return s not in {"", "null", "none", "n/a", "na", "undefined", "."}
+                if isinstance(v, (list, dict, tuple, set)):
+                    return len(v) > 0
+                return True
+
+            props = (active_capture_schema or {}).get("properties", {})
+            required_fields = list(
+                (active_capture_schema or {}).get("required", []) or []
+            )
+            all_fields = list(props.keys())
+            optional_fields = [f for f in all_fields if f not in set(required_fields)]
+
+            def _missing_from(data: Dict[str, Any], fields: List[str]) -> List[str]:
+                return [f for f in fields if not _non_empty(data.get(f))]
+
+            missing_required = _missing_from(active_data, required_fields)
+            missing_optional = _missing_from(active_data, optional_fields)
+
+            required_complete = len(missing_required) == 0 and len(required_fields) > 0
+
+            lines: List[str] = []
+            lines.append(
+                "CAPTURED FORM STATE (Authoritative; do not re-ask filled values):"
+            )
+            lines.append(f"- form_name: {active_capture_name}")
+
+            if active_data:
+                pairs = [f"{k}: {v}" for k, v in active_data.items() if _non_empty(v)]
+                lines.append(
+                    f"- filled_fields: {', '.join(pairs) if pairs else '(none)'}"
+                )
+            else:
+                lines.append("- filled_fields: (none)")
+
+            lines.append(
+                f"- missing_required_fields: {', '.join(missing_required) if missing_required else '(none)'}"
+            )
+            lines.append(
+                f"- missing_optional_fields: {', '.join(missing_optional) if missing_optional else '(none)'}"
+            )
+            lines.append("")
+
+            if latest_by_name:
+                lines.append("OTHER CAPTURED USER DATA (for reference):")
+                for cname, info in latest_by_name.items():
+                    if cname == active_capture_name:
+                        continue
+                    data = info.get("data", {}) or {}
+                    if data:
+                        pairs = [f"{k}: {v}" for k, v in data.items() if _non_empty(v)]
+                        lines.append(
+                            f"- {cname}: {', '.join(pairs) if pairs else '(none)'}"
+                        )
+                    else:
+                        lines.append(f"- {cname}: (none)")
+
+            if lines:
+                capture_context = "\n".join(lines) + "\n\n"
+
+        # Merge contexts
+        combined_context = ""
+        if capture_context:
+            combined_context += capture_context
+        if memory_context:
+            combined_context += f"CONVERSATION HISTORY (Use for continuity; not authoritative for facts):\n{memory_context}\n\n"
+        if kb_context:
+            combined_context += kb_context + "\n"
+
+        guide = (
+            "PRIORITIZATION GUIDE:\n"
+            "- Prefer Captured User Data for user-specific fields.\n"
+            "- Prefer KB/tools for facts.\n"
+            "- History is for tone and continuity.\n\n"
+            "FORM FLOW RULES:\n"
+            "- Ask exactly one field per turn.\n"
+            "- If any required fields are missing, ask the next missing required field.\n"
+            "- If all required fields are filled but optional fields are missing, ask the next missing optional field.\n"
+            "- Do NOT re-ask or verify values present in Captured User Data (auto-saved, authoritative).\n"
+            "- Do NOT provide summaries until no required or optional fields are missing.\n\n"
+        )
+
+        if combined_context:
+            combined_context += guide
+        else:
+            # Diagnostics for why the context is empty
+            try:
+                logger.debug(
+                    "_build_combined_context: empty sources — memory_provider=%s, knowledge_base=%s, active_capture=%s",
+                    bool(self.memory_provider),
+                    bool(self.knowledge_base),
+                    bool(
+                        active_capture_name and isinstance(active_capture_schema, dict)
+                    ),
+                )
+            except Exception:
+                pass
+            # Provide minimal guide so realtime instructions are not blank
+            combined_context = guide
+
+        return combined_context, required_complete
 
     # LLM-backed switch intent detection (gpt-4.1-mini)
     class _SwitchIntentModel(BaseModel):
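
The new `_build_combined_context` helper (which supersedes the removed `_clear_sticky_agent`) assembles the instruction context for realtime turns from memory, knowledge base, and capture state. A hedged call sketch; the schema, user values, and stored capture below are illustrative, not from the package:

```python
# Inside an async method of QueryService; schema and stored data are made up.
ctx, required_complete = await self._build_combined_context(
    user_id="u-123",
    user_text="I'd like a demo",
    agent_name="support",
    capture_name="demo_request",
    capture_schema={
        "properties": {"email": {"type": "string"}, "company": {"type": "string"}},
        "required": ["email"],
    },
)
# If the latest "captures" doc held {"email": "a@b.co"}, ctx would start roughly:
#   CAPTURED FORM STATE (Authoritative; do not re-ask filled values):
#   - form_name: demo_request
#   - filled_fields: email: a@b.co
#   - missing_required_fields: (none)
#   - missing_optional_fields: company
# and required_complete would be True (no required field missing, and at least
# one required field declared).
```
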
@@ -213,6 +402,23 @@ class QueryService(QueryServiceInterface):
         query: Union[str, bytes],
         images: Optional[List[Union[str, bytes]]] = None,
         output_format: Literal["text", "audio"] = "text",
+        realtime: bool = False,
+        # Realtime minimal controls (voice/format come from audio_* args)
+        vad: Optional[bool] = None,
+        rt_encode_input: bool = False,
+        rt_encode_output: bool = False,
+        rt_voice: Literal[
+            "alloy",
+            "ash",
+            "ballad",
+            "cedar",
+            "coral",
+            "echo",
+            "marin",
+            "sage",
+            "shimmer",
+            "verse",
+        ] = "marin",
         audio_voice: Literal[
             "alloy",
             "ash",
@@ -225,7 +431,6 @@ class QueryService(QueryServiceInterface):
             "sage",
             "shimmer",
         ] = "nova",
-        audio_instructions: str = "You speak in a friendly and helpful manner.",
         audio_output_format: Literal[
             "mp3", "opus", "aac", "flac", "wav", "pcm"
         ] = "aac",
@@ -240,6 +445,321 @@ class QueryService(QueryServiceInterface):
     ) -> AsyncGenerator[Union[str, bytes, BaseModel], None]:  # pragma: no cover
         """Process the user request and generate a response."""
         try:
+            # Realtime request: HTTP STT for user + single WS for assistant audio
+            if realtime:
+                # 1) Launch HTTP STT in background when input is audio; don't block WS
+                is_audio_bytes = isinstance(query, (bytes, bytearray))
+                user_text = ""
+                stt_task = None
+                if is_audio_bytes:
+
+                    async def _stt_consume():
+                        txt = ""
+                        try:
+                            logger.info(
+                                f"Realtime(HTTP STT): transcribing format: {audio_input_format}"
+                            )
+                            async for (
+                                t
+                            ) in self.agent_service.llm_provider.transcribe_audio(  # type: ignore[attr-defined]
+                                query, audio_input_format
+                            ):
+                                txt += t
+                        except Exception as e:
+                            logger.error(f"HTTP STT error: {e}")
+                        return txt
+
+                    stt_task = asyncio.create_task(_stt_consume())
+                else:
+                    user_text = str(query)
+
+                # 2) Single agent selection (no multi-agent routing in realtime path)
+                agent_name = self._get_sticky_agent(user_id)
+                if not agent_name:
+                    try:
+                        agents = self.agent_service.get_all_ai_agents() or {}
+                        agent_name = next(iter(agents.keys())) if agents else "default"
+                    except Exception:
+                        agent_name = "default"
+                prev_assistant = ""
+                if self.memory_provider:
+                    try:
+                        prev_docs = self.memory_provider.find(
+                            collection="conversations",
+                            query={"user_id": user_id},
+                            sort=[("timestamp", -1)],
+                            limit=1,
+                        )
+                        if prev_docs:
+                            prev_assistant = (prev_docs[0] or {}).get(
+                                "assistant_message", ""
+                            ) or ""
+                    except Exception:
+                        pass
+
+                # 3) Build context + tools
+                combined_ctx = ""
+                required_complete = False
+                try:
+                    (
+                        combined_ctx,
+                        required_complete,
+                    ) = await self._build_combined_context(
+                        user_id=user_id,
+                        user_text=(user_text if not is_audio_bytes else ""),
+                        agent_name=agent_name,
+                        capture_name=capture_name,
+                        capture_schema=capture_schema,
+                        prev_assistant=prev_assistant,
+                    )
+                    try:
+                        self._update_sticky_required_complete(
+                            user_id, required_complete
+                        )
+                    except Exception:
+                        pass
+                except Exception:
+                    combined_ctx = ""
+                try:
+                    # GA Realtime expects flattened tool definitions (no nested "function" object)
+                    initial_tools = [
+                        {
+                            "type": "function",
+                            "name": t["name"],
+                            "description": t.get("description", ""),
+                            "parameters": t.get("parameters", {}),
+                            "strict": True,
+                        }
+                        for t in self.agent_service.get_agent_tools(agent_name)
+                    ]
+                except Exception:
+                    initial_tools = []
+
+                # Build realtime instructions: full agent system prompt, context, and optional prompt (no user_text)
+                system_prompt = ""
+                try:
+                    system_prompt = self.agent_service.get_agent_system_prompt(
+                        agent_name
+                    )
+                except Exception:
+                    system_prompt = ""
+
+                parts: List[str] = []
+                if system_prompt:
+                    parts.append(system_prompt)
+                if combined_ctx:
+                    parts.append(combined_ctx)
+                if prompt:
+                    parts.append(str(prompt))
+                final_instructions = "\n\n".join([p for p in parts if p])
+
+                # 4) Open a single WS session for assistant audio
+                from solana_agent.adapters.openai_realtime_ws import (
+                    OpenAIRealtimeWebSocketSession,
+                )
+                from solana_agent.interfaces.providers.realtime import (
+                    RealtimeSessionOptions,
+                )
+                from solana_agent.services.realtime import RealtimeService
+                from solana_agent.adapters.ffmpeg_transcoder import FFmpegTranscoder
+
+                api_key = None
+                try:
+                    api_key = self.agent_service.llm_provider.get_api_key()  # type: ignore[attr-defined]
+                except Exception:
+                    pass
+                if not api_key:
+                    raise ValueError("OpenAI API key is required for realtime")
+
+                # Per-user persistent WS (single session)
+                def _mime_from(fmt: str) -> str:
+                    f = (fmt or "").lower()
+                    return {
+                        "aac": "audio/aac",
+                        "mp3": "audio/mpeg",
+                        "mp4": "audio/mp4",
+                        "m4a": "audio/mp4",
+                        "mpeg": "audio/mpeg",
+                        "mpga": "audio/mpeg",
+                        "wav": "audio/wav",
+                        "flac": "audio/flac",
+                        "opus": "audio/opus",
+                        "ogg": "audio/ogg",
+                        "webm": "audio/webm",
+                        "pcm": "audio/pcm",
+                    }.get(f, "audio/pcm")
+
+                # Choose output encoding automatically when non-PCM output is requested
+                encode_out = bool(
+                    rt_encode_output or (audio_output_format.lower() != "pcm")
+                )
+                # Choose input transcoding when compressed input is provided (or explicitly requested)
+                is_audio_bytes = isinstance(query, (bytes, bytearray))
+                encode_in = bool(
+                    rt_encode_input
+                    or (is_audio_bytes and audio_input_format.lower() != "pcm")
+                )
+
+                async with self._rt_lock:
+                    rt = self._rt_services.get(user_id)
+                    if not rt or not isinstance(rt, RealtimeService):
+                        opts = RealtimeSessionOptions(
+                            model="gpt-realtime",
+                            voice=rt_voice,
+                            vad_enabled=False,  # no input audio
+                            input_rate_hz=24000,
+                            output_rate_hz=24000,
+                            input_mime="audio/pcm",
+                            output_mime="audio/pcm",
+                            tools=initial_tools or None,
+                            tool_choice="auto",
+                        )
+                        # Ensure initial session.update carries instructions/voice
+                        try:
+                            opts.instructions = final_instructions
+                            opts.voice = rt_voice
+                        except Exception:
+                            pass
+                        conv_session = OpenAIRealtimeWebSocketSession(
+                            api_key=api_key, options=opts
+                        )
+                        transcoder = (
+                            FFmpegTranscoder() if (encode_in or encode_out) else None
+                        )
+                        rt = RealtimeService(
+                            session=conv_session,
+                            options=opts,
+                            transcoder=transcoder,
+                            accept_compressed_input=encode_in,
+                            client_input_mime=_mime_from(audio_input_format),
+                            encode_output=encode_out,
+                            client_output_mime=_mime_from(audio_output_format),
+                        )
+                        self._rt_services[user_id] = rt
+
+                # Tool executor
+                async def _exec(tool_name: str, args: Dict[str, Any]) -> Dict[str, Any]:
+                    try:
+                        return await self.agent_service.execute_tool(
+                            agent_name, tool_name, args or {}
+                        )
+                    except Exception as e:
+                        return {"status": "error", "message": str(e)}
+
+                # If possible, set on underlying session
+                try:
+                    if hasattr(rt, "_session"):
+                        getattr(rt, "_session").set_tool_executor(_exec)  # type: ignore[attr-defined]
+                except Exception:
+                    pass
+
+                # Connect/configure
+                if not getattr(rt, "_connected", False):
+                    await rt.start()
+                await rt.configure(
+                    voice=rt_voice,
+                    vad_enabled=bool(vad) if vad is not None else False,
+                    instructions=final_instructions,
+                    tools=initial_tools or None,
+                    tool_choice="auto",
+                )
+
+                # Ensure clean input buffers for this turn
+                try:
+                    await rt.clear_input()
+                except Exception:
+                    pass
+
+                # Persist once per turn
+                turn_id = await self.realtime_begin_turn(user_id)
+                if turn_id and user_text:
+                    try:
+                        await self.realtime_update_user(user_id, turn_id, user_text)
+                    except Exception:
+                        pass
+
+                # Feed audio into WS if audio bytes provided; else use input_text
+                if is_audio_bytes:
+                    bq = bytes(query)
+                    logger.info(
+                        "Realtime: appending input audio to WS via FFmpeg, len=%d, fmt=%s",
+                        len(bq),
+                        audio_input_format,
+                    )
+                    await rt.append_audio(bq)
+                    vad_enabled_value = bool(vad) if vad is not None else False
+                    if not vad_enabled_value:
+                        await rt.commit_input()
+                        # Manually trigger response when VAD is disabled
+                        await rt.create_response({})
+                    else:
+                        # With server VAD enabled, the model will auto-create a response at end of speech
+                        logger.debug(
+                            "Realtime: VAD enabled — skipping manual response.create"
+                        )
+                else:
+                    # Rely on configured session voice; attach input_text only
+                    await rt.create_response(
+                        {
+                            "modalities": ["audio"],
+                            "input": [{"type": "input_text", "text": user_text or ""}],
+                        }
+                    )
+
+                # Collect audio and transcripts
+                user_tr = ""
+                asst_tr = ""
+
+                async def _drain_in_tr():
+                    nonlocal user_tr
+                    async for t in rt.iter_input_transcript():
+                        if t:
+                            user_tr += t
+
+                async def _drain_out_tr():
+                    nonlocal asst_tr
+                    async for t in rt.iter_output_transcript():
+                        if t:
+                            asst_tr += t
+
+                in_task = asyncio.create_task(_drain_in_tr())
+                out_task = asyncio.create_task(_drain_out_tr())
+                try:
+                    async for audio_chunk in rt.iter_output_audio_encoded():
+                        yield audio_chunk
+                finally:
+                    in_task.cancel()
+                    out_task.cancel()
+                    # If no WS input transcript was captured, fall back to HTTP STT result
+                    if not user_tr:
+                        try:
+                            if "stt_task" in locals() and stt_task is not None:
+                                user_tr = await stt_task
+                        except Exception:
+                            pass
+                    if turn_id:
+                        try:
+                            if user_tr:
+                                await self.realtime_update_user(
+                                    user_id, turn_id, user_tr
+                                )
+                            if asst_tr:
+                                await self.realtime_update_assistant(
+                                    user_id, turn_id, asst_tr
+                                )
+                        except Exception:
+                            pass
+                        try:
+                            await self.realtime_finalize_turn(user_id, turn_id)
+                        except Exception:
+                            pass
+                    # Clear input buffer for next turn reuse
+                    try:
+                        await rt.clear_input()
+                    except Exception:
+                        pass
+                return
+
             # 1) Transcribe audio or accept text
             user_text = ""
             if not isinstance(query, str):
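
Note the comment in this hunk that GA Realtime expects flattened tool definitions. For contrast, a sketch of the two shapes (tool name and parameters are illustrative; the nested form is the familiar Chat Completions layout, the flattened form matches what this hunk builds):

```python
# Illustrative only: the same tool in both layouts.
params = {
    "type": "object",
    "properties": {"symbol": {"type": "string"}},
    "required": ["symbol"],
}
chat_completions_style = {
    "type": "function",
    "function": {"name": "get_price", "description": "Quote a token", "parameters": params},
}
realtime_style = {
    "type": "function",
    "name": "get_price",          # flattened: no nested "function" object
    "description": "Quote a token",
    "parameters": params,
    "strict": True,
}
```
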
@@ -283,7 +803,6 @@ class QueryService(QueryServiceInterface):
                     text=greeting,
                     voice=audio_voice,
                     response_format=audio_output_format,
-                    instructions=audio_instructions,
                 ):
                     yield chunk
             else:
@@ -771,7 +1290,6 @@ class QueryService(QueryServiceInterface):
                 output_format="audio",
                 audio_voice=audio_voice,
                 audio_output_format=audio_output_format,
-                audio_instructions=audio_instructions,
                 prompt=prompt,
             ):
                 yield audio_chunk
@@ -984,6 +1502,43 @@ class QueryService(QueryServiceInterface):
         except Exception as e:
             logger.error(f"Store conversation error for {user_id}: {e}")
 
+    # --- Realtime persistence helpers (used by client/server code driving the realtime service) ---
+    async def realtime_begin_turn(
+        self, user_id: str
+    ) -> Optional[str]:  # pragma: no cover
+        if not self.memory_provider:
+            return None
+        if not hasattr(self.memory_provider, "begin_stream_turn"):
+            return None
+        return await self.memory_provider.begin_stream_turn(user_id)  # type: ignore[attr-defined]
+
+    async def realtime_update_user(
+        self, user_id: str, turn_id: str, delta: str
+    ) -> None:  # pragma: no cover
+        if not self.memory_provider:
+            return
+        if not hasattr(self.memory_provider, "update_stream_user"):
+            return
+        await self.memory_provider.update_stream_user(user_id, turn_id, delta)  # type: ignore[attr-defined]
+
+    async def realtime_update_assistant(
+        self, user_id: str, turn_id: str, delta: str
+    ) -> None:  # pragma: no cover
+        if not self.memory_provider:
+            return
+        if not hasattr(self.memory_provider, "update_stream_assistant"):
+            return
+        await self.memory_provider.update_stream_assistant(user_id, turn_id, delta)  # type: ignore[attr-defined]
+
+    async def realtime_finalize_turn(
+        self, user_id: str, turn_id: str
+    ) -> None:  # pragma: no cover
+        if not self.memory_provider:
+            return
+        if not hasattr(self.memory_provider, "finalize_stream_turn"):
+            return
+        await self.memory_provider.finalize_stream_turn(user_id, turn_id)  # type: ignore[attr-defined]
+
     def _build_model_from_json_schema(
         self, name: str, schema: Dict[str, Any]
     ) -> Type[BaseModel]:
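
The four helpers above probe the memory provider with `hasattr` rather than widening the provider interface, so a provider opts in simply by defining the matching coroutines. A minimal in-memory sketch of a provider implementing the hooks (illustrative only; a real provider would persist finished turns durably):

```python
import uuid
from typing import Any, Dict


class StreamTurnMemory:
    """Illustrative provider implementing only the duck-typed stream hooks."""

    def __init__(self) -> None:
        self._turns: Dict[str, Dict[str, Any]] = {}

    async def begin_stream_turn(self, user_id: str) -> str:
        turn_id = uuid.uuid4().hex
        self._turns[turn_id] = {"user_id": user_id, "user": "", "assistant": ""}
        return turn_id

    async def update_stream_user(self, user_id: str, turn_id: str, delta: str) -> None:
        self._turns[turn_id]["user"] += delta

    async def update_stream_assistant(self, user_id: str, turn_id: str, delta: str) -> None:
        self._turns[turn_id]["assistant"] += delta

    async def finalize_stream_turn(self, user_id: str, turn_id: str) -> None:
        # A real provider would write the finished turn to its conversations store.
        self._turns.pop(turn_id, None)
```
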