dv-pipecat-ai 0.0.85.dev7__py3-none-any.whl → 0.0.85.dev11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dv-pipecat-ai might be problematic. Click here for more details.
- {dv_pipecat_ai-0.0.85.dev7.dist-info → dv_pipecat_ai-0.0.85.dev11.dist-info}/METADATA +4 -1
- {dv_pipecat_ai-0.0.85.dev7.dist-info → dv_pipecat_ai-0.0.85.dev11.dist-info}/RECORD +16 -10
- pipecat/audio/turn/base_turn_analyzer.py +9 -1
- pipecat/audio/turn/smart_turn/base_smart_turn.py +14 -8
- pipecat/audio/turn/smart_turn/data/__init__.py +0 -0
- pipecat/audio/turn/smart_turn/data/smart-turn-v3.0.onnx +0 -0
- pipecat/audio/turn/smart_turn/http_smart_turn.py +6 -2
- pipecat/audio/turn/smart_turn/local_smart_turn.py +1 -1
- pipecat/audio/turn/smart_turn/local_smart_turn_v2.py +1 -1
- pipecat/audio/turn/smart_turn/local_smart_turn_v3.py +124 -0
- pipecat/services/salesforce/__init__.py +9 -0
- pipecat/services/salesforce/llm.py +587 -0
- pipecat/utils/redis.py +58 -0
- {dv_pipecat_ai-0.0.85.dev7.dist-info → dv_pipecat_ai-0.0.85.dev11.dist-info}/WHEEL +0 -0
- {dv_pipecat_ai-0.0.85.dev7.dist-info → dv_pipecat_ai-0.0.85.dev11.dist-info}/licenses/LICENSE +0 -0
- {dv_pipecat_ai-0.0.85.dev7.dist-info → dv_pipecat_ai-0.0.85.dev11.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: dv-pipecat-ai
|
|
3
|
-
Version: 0.0.85.
|
|
3
|
+
Version: 0.0.85.dev11
|
|
4
4
|
Summary: An open source framework for voice (and multimodal) assistants
|
|
5
5
|
License-Expression: BSD-2-Clause
|
|
6
6
|
Project-URL: Source, https://github.com/pipecat-ai/pipecat
|
|
@@ -143,6 +143,9 @@ Requires-Dist: coremltools>=8.0; extra == "local-smart-turn"
|
|
|
143
143
|
Requires-Dist: transformers; extra == "local-smart-turn"
|
|
144
144
|
Requires-Dist: torch<3,>=2.5.0; extra == "local-smart-turn"
|
|
145
145
|
Requires-Dist: torchaudio<3,>=2.5.0; extra == "local-smart-turn"
|
|
146
|
+
Provides-Extra: local-smart-turn-v3
|
|
147
|
+
Requires-Dist: transformers; extra == "local-smart-turn-v3"
|
|
148
|
+
Requires-Dist: onnxruntime<2,>=1.20.1; extra == "local-smart-turn-v3"
|
|
146
149
|
Provides-Extra: remote-smart-turn
|
|
147
150
|
Provides-Extra: silero
|
|
148
151
|
Requires-Dist: onnxruntime~=1.20.1; extra == "silero"
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
dv_pipecat_ai-0.0.85.
|
|
1
|
+
dv_pipecat_ai-0.0.85.dev11.dist-info/licenses/LICENSE,sha256=DWY2QGf2eMCFhuu2ChairtT6CB7BEFffNVhXWc4Od08,1301
|
|
2
2
|
pipecat/__init__.py,sha256=j0Xm6adxHhd7D06dIyyPV_GlBYLlBnTAERVvD_jAARQ,861
|
|
3
3
|
pipecat/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
4
4
|
pipecat/adapters/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -50,14 +50,17 @@ pipecat/audio/resamplers/resampy_resampler.py,sha256=fEZv6opn_9j50xYEOdwQiZOJQ_J
|
|
|
50
50
|
pipecat/audio/resamplers/soxr_resampler.py,sha256=CXze7zf_ExlCcgcZp0oArRSbZ9zFpBzsCt2EQ_woKfM,1747
|
|
51
51
|
pipecat/audio/resamplers/soxr_stream_resampler.py,sha256=lHk1__M1HDGf25abpffuWEyqbd0ckNfyADDV_WmTPcY,3665
|
|
52
52
|
pipecat/audio/turn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
53
|
-
pipecat/audio/turn/base_turn_analyzer.py,sha256=
|
|
53
|
+
pipecat/audio/turn/base_turn_analyzer.py,sha256=UoZ61yto2wecXU6nXk2yjdcgM7jGyfMR5ZfrunOFpOA,3359
|
|
54
54
|
pipecat/audio/turn/smart_turn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
55
|
-
pipecat/audio/turn/smart_turn/base_smart_turn.py,sha256=
|
|
55
|
+
pipecat/audio/turn/smart_turn/base_smart_turn.py,sha256=gE5jrqrU0gQcgjTOvpUbb6LWAhfk8VKZQ-5pyEIZH4E,10037
|
|
56
56
|
pipecat/audio/turn/smart_turn/fal_smart_turn.py,sha256=neahuTAY9SUQjacRYd19BERiuSHIMSpqzZ9uae_ZlWA,1606
|
|
57
|
-
pipecat/audio/turn/smart_turn/http_smart_turn.py,sha256=
|
|
57
|
+
pipecat/audio/turn/smart_turn/http_smart_turn.py,sha256=HlHpdVbk-1g_AU3qAAy7Xob8M2V3FUqtr38UAk1F1Dw,4783
|
|
58
58
|
pipecat/audio/turn/smart_turn/local_coreml_smart_turn.py,sha256=50kiBeZhnq7FZWZnzdSX8KUmhhQtkme0KH2rbiAJbCU,3140
|
|
59
|
-
pipecat/audio/turn/smart_turn/local_smart_turn.py,sha256=
|
|
60
|
-
pipecat/audio/turn/smart_turn/local_smart_turn_v2.py,sha256=
|
|
59
|
+
pipecat/audio/turn/smart_turn/local_smart_turn.py,sha256=0z2M_MC9xIcelm4d9XqZwzJMe2FM-zOjgnHDAeoMw0g,3564
|
|
60
|
+
pipecat/audio/turn/smart_turn/local_smart_turn_v2.py,sha256=hd_nhEdaxwJ2_G6F2RJru9mC8vyzkmku2YqmtULl7NM,7154
|
|
61
|
+
pipecat/audio/turn/smart_turn/local_smart_turn_v3.py,sha256=x1q437Mp8cEU1S-7W869i1meDtCdjrjPTUCjbSLDVgQ,4649
|
|
62
|
+
pipecat/audio/turn/smart_turn/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
63
|
+
pipecat/audio/turn/smart_turn/data/smart-turn-v3.0.onnx,sha256=B6Ezq6MeLQtSPxf4wuTmXv5tj2he_RLKT-Iev055iZE,8757193
|
|
61
64
|
pipecat/audio/vad/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
62
65
|
pipecat/audio/vad/silero.py,sha256=Cz4_hJjaBKbmUwZVbqMzED8orHOCsnF3zpERgBTw1Rw,7906
|
|
63
66
|
pipecat/audio/vad/vad_analyzer.py,sha256=XkZLEe4z7Ja0lGoYZst1HNYqt5qOwG-vjsk_w8chiNA,7430
|
|
@@ -286,6 +289,8 @@ pipecat/services/rime/tts.py,sha256=XHMSnQUi7gMtWF42u4rBVv6oBDor4KkwkL7O-Sj9MPo,
|
|
|
286
289
|
pipecat/services/riva/__init__.py,sha256=rObSsj504O_TMXhPBg_ymqKslZBhovlR-A0aaRZ0O6A,276
|
|
287
290
|
pipecat/services/riva/stt.py,sha256=dtg8toijmexWB3uipw0EQ7ov3DFgHj40kFFv1Zadmmc,25116
|
|
288
291
|
pipecat/services/riva/tts.py,sha256=idbqx3I2NlWCXtrIFsjEaYapxA3BLIA14ai3aMBh-2w,8158
|
|
292
|
+
pipecat/services/salesforce/__init__.py,sha256=OFvYbcvCadYhcKdBAVLj3ZUXVXQ1HyVyhgxIFf6_Thg,173
|
|
293
|
+
pipecat/services/salesforce/llm.py,sha256=mpozkzldgz3plbMOJcKddiyJxn7x4qqPuJVn22a41Ag,23009
|
|
289
294
|
pipecat/services/sambanova/__init__.py,sha256=oTXExLic-qTcsfsiWmssf3Elclf3IIWoN41_2IpoF18,128
|
|
290
295
|
pipecat/services/sambanova/llm.py,sha256=5XVfPLEk__W8ykFqLdV95ZUhlGGkAaJwmbciLdZYtTc,8976
|
|
291
296
|
pipecat/services/sambanova/stt.py,sha256=ZZgEZ7WQjLFHbCko-3LNTtVajjtfUvbtVLtFcaNadVQ,2536
|
|
@@ -358,6 +363,7 @@ pipecat/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
|
358
363
|
pipecat/utils/base_object.py,sha256=62e5_0R_rcQe-JdzUM0h1wtv1okw-0LPyG78ZKkDyzE,5963
|
|
359
364
|
pipecat/utils/logger_config.py,sha256=5-RmvReZIINeqSXz3ALhEIiMZ_azmpOxnlIkdyCjWWk,5606
|
|
360
365
|
pipecat/utils/network.py,sha256=RRQ7MmTcbeXBJ2aY5UbMCQ6elm5B8Rxkn8XqkJ9S0Nc,825
|
|
366
|
+
pipecat/utils/redis.py,sha256=JmBaC1yY6e8qygUQkAER3DNFCYSCH18hd7NN9qqjDMU,1677
|
|
361
367
|
pipecat/utils/string.py,sha256=TskK9KxQSwbljct0J6y9ffkRcx4xYjTtPhFjEL4M1i8,6720
|
|
362
368
|
pipecat/utils/time.py,sha256=lirjh24suz9EI1pf2kYwvAYb3I-13U_rJ_ZRg3nRiGs,1741
|
|
363
369
|
pipecat/utils/utils.py,sha256=T2y1Mcd9yWiZiIToUiRkhW-n7EFf8juk3kWX3TF8XOQ,2451
|
|
@@ -378,7 +384,7 @@ pipecat/utils/tracing/service_decorators.py,sha256=HwDCqLGijhYD3F8nxDuQmEw-YkRw0
|
|
|
378
384
|
pipecat/utils/tracing/setup.py,sha256=7TEgPNpq6M8lww8OQvf0P9FzYc5A30xICGklVA-fua0,2892
|
|
379
385
|
pipecat/utils/tracing/turn_context_provider.py,sha256=ikon3plFOx0XbMrH6DdeHttNpb-U0gzMZIm3bWLc9eI,2485
|
|
380
386
|
pipecat/utils/tracing/turn_trace_observer.py,sha256=dma16SBJpYSOE58YDWy89QzHyQFc_9gQZszKeWixuwc,9725
|
|
381
|
-
dv_pipecat_ai-0.0.85.
|
|
382
|
-
dv_pipecat_ai-0.0.85.
|
|
383
|
-
dv_pipecat_ai-0.0.85.
|
|
384
|
-
dv_pipecat_ai-0.0.85.
|
|
387
|
+
dv_pipecat_ai-0.0.85.dev11.dist-info/METADATA,sha256=_scIy5gP8k7GUtLAA9NzNVT_T1y__8ROU0gPj1G6FCw,32858
|
|
388
|
+
dv_pipecat_ai-0.0.85.dev11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
389
|
+
dv_pipecat_ai-0.0.85.dev11.dist-info/top_level.txt,sha256=kQzG20CxGf-nSsHmtXHx3hY2-8zHA3jYg8jk0TajqXc,8
|
|
390
|
+
dv_pipecat_ai-0.0.85.dev11.dist-info/RECORD,,
|
|
@@ -14,6 +14,8 @@ from abc import ABC, abstractmethod
|
|
|
14
14
|
from enum import Enum
|
|
15
15
|
from typing import Optional, Tuple
|
|
16
16
|
|
|
17
|
+
from pydantic import BaseModel
|
|
18
|
+
|
|
17
19
|
from pipecat.metrics.metrics import MetricsData
|
|
18
20
|
|
|
19
21
|
|
|
@@ -29,6 +31,12 @@ class EndOfTurnState(Enum):
|
|
|
29
31
|
INCOMPLETE = 2
|
|
30
32
|
|
|
31
33
|
|
|
34
|
+
class BaseTurnParams(BaseModel):
|
|
35
|
+
"""Base class for turn analyzer parameters."""
|
|
36
|
+
|
|
37
|
+
pass
|
|
38
|
+
|
|
39
|
+
|
|
32
40
|
class BaseTurnAnalyzer(ABC):
|
|
33
41
|
"""Abstract base class for analyzing user end of turn.
|
|
34
42
|
|
|
@@ -78,7 +86,7 @@ class BaseTurnAnalyzer(ABC):
|
|
|
78
86
|
|
|
79
87
|
@property
|
|
80
88
|
@abstractmethod
|
|
81
|
-
def params(self):
|
|
89
|
+
def params(self) -> BaseTurnParams:
|
|
82
90
|
"""Get the current turn analyzer parameters.
|
|
83
91
|
|
|
84
92
|
Returns:
|
|
@@ -11,15 +11,17 @@ machine learning models to determine when a user has finished speaking, going
|
|
|
11
11
|
beyond simple silence-based detection.
|
|
12
12
|
"""
|
|
13
13
|
|
|
14
|
+
import asyncio
|
|
14
15
|
import time
|
|
15
16
|
from abc import abstractmethod
|
|
17
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
16
18
|
from typing import Any, Dict, Optional, Tuple
|
|
17
19
|
|
|
18
20
|
import numpy as np
|
|
19
21
|
from loguru import logger
|
|
20
22
|
from pydantic import BaseModel
|
|
21
23
|
|
|
22
|
-
from pipecat.audio.turn.base_turn_analyzer import BaseTurnAnalyzer, EndOfTurnState
|
|
24
|
+
from pipecat.audio.turn.base_turn_analyzer import BaseTurnAnalyzer, BaseTurnParams, EndOfTurnState
|
|
23
25
|
from pipecat.metrics.metrics import MetricsData, SmartTurnMetricsData
|
|
24
26
|
|
|
25
27
|
# Default timing parameters
|
|
@@ -29,7 +31,7 @@ MAX_DURATION_SECONDS = 8 # Max allowed segment duration
|
|
|
29
31
|
USE_ONLY_LAST_VAD_SEGMENT = True
|
|
30
32
|
|
|
31
33
|
|
|
32
|
-
class SmartTurnParams(
|
|
34
|
+
class SmartTurnParams(BaseTurnParams):
|
|
33
35
|
"""Configuration parameters for smart turn analysis.
|
|
34
36
|
|
|
35
37
|
Parameters:
|
|
@@ -77,6 +79,9 @@ class BaseSmartTurn(BaseTurnAnalyzer):
|
|
|
77
79
|
self._speech_triggered = False
|
|
78
80
|
self._silence_ms = 0
|
|
79
81
|
self._speech_start_time = 0
|
|
82
|
+
# Thread executor that will run the model. We only need one thread per
|
|
83
|
+
# analyzer because one analyzer just handles one audio stream.
|
|
84
|
+
self._executor = ThreadPoolExecutor(max_workers=1)
|
|
80
85
|
|
|
81
86
|
@property
|
|
82
87
|
def speech_triggered(self) -> bool:
|
|
@@ -151,7 +156,10 @@ class BaseSmartTurn(BaseTurnAnalyzer):
|
|
|
151
156
|
Tuple containing the end-of-turn state and optional metrics data
|
|
152
157
|
from the ML model analysis.
|
|
153
158
|
"""
|
|
154
|
-
|
|
159
|
+
loop = asyncio.get_running_loop()
|
|
160
|
+
state, result = await loop.run_in_executor(
|
|
161
|
+
self._executor, self._process_speech_segment, self._audio_buffer
|
|
162
|
+
)
|
|
155
163
|
if state == EndOfTurnState.COMPLETE or USE_ONLY_LAST_VAD_SEGMENT:
|
|
156
164
|
self._clear(state)
|
|
157
165
|
logger.debug(f"End of Turn result: {state}")
|
|
@@ -169,9 +177,7 @@ class BaseSmartTurn(BaseTurnAnalyzer):
|
|
|
169
177
|
self._speech_start_time = 0
|
|
170
178
|
self._silence_ms = 0
|
|
171
179
|
|
|
172
|
-
|
|
173
|
-
self, audio_buffer
|
|
174
|
-
) -> Tuple[EndOfTurnState, Optional[MetricsData]]:
|
|
180
|
+
def _process_speech_segment(self, audio_buffer) -> Tuple[EndOfTurnState, Optional[MetricsData]]:
|
|
175
181
|
"""Process accumulated audio segment using ML model."""
|
|
176
182
|
state = EndOfTurnState.INCOMPLETE
|
|
177
183
|
|
|
@@ -203,7 +209,7 @@ class BaseSmartTurn(BaseTurnAnalyzer):
|
|
|
203
209
|
if len(segment_audio) > 0:
|
|
204
210
|
start_time = time.perf_counter()
|
|
205
211
|
try:
|
|
206
|
-
result =
|
|
212
|
+
result = self._predict_endpoint(segment_audio)
|
|
207
213
|
state = (
|
|
208
214
|
EndOfTurnState.COMPLETE
|
|
209
215
|
if result["prediction"] == 1
|
|
@@ -249,6 +255,6 @@ class BaseSmartTurn(BaseTurnAnalyzer):
|
|
|
249
255
|
return state, result_data
|
|
250
256
|
|
|
251
257
|
@abstractmethod
|
|
252
|
-
|
|
258
|
+
def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, Any]:
|
|
253
259
|
"""Predict end-of-turn using ML model from audio data."""
|
|
254
260
|
pass
|
|
File without changes
|
|
Binary file
|
|
@@ -104,11 +104,15 @@ class HttpSmartTurnAnalyzer(BaseSmartTurn):
|
|
|
104
104
|
logger.error(f"Failed to send raw request to Daily Smart Turn: {e}")
|
|
105
105
|
raise Exception("Failed to send raw request to Daily Smart Turn.")
|
|
106
106
|
|
|
107
|
-
|
|
107
|
+
def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, Any]:
|
|
108
108
|
"""Predict end-of-turn using remote HTTP ML service."""
|
|
109
109
|
try:
|
|
110
110
|
serialized_array = self._serialize_array(audio_array)
|
|
111
|
-
|
|
111
|
+
loop = asyncio.get_running_loop()
|
|
112
|
+
future = asyncio.run_coroutine_threadsafe(
|
|
113
|
+
self._send_raw_request(serialized_array), loop
|
|
114
|
+
)
|
|
115
|
+
return future.result()
|
|
112
116
|
except Exception as e:
|
|
113
117
|
logger.error(f"Smart turn prediction failed: {str(e)}")
|
|
114
118
|
# Return an incomplete prediction when a failure occurs
|
|
@@ -64,7 +64,7 @@ class LocalSmartTurnAnalyzer(BaseSmartTurn):
|
|
|
64
64
|
self._turn_model.eval()
|
|
65
65
|
logger.debug("Loaded Local Smart Turn")
|
|
66
66
|
|
|
67
|
-
|
|
67
|
+
def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, Any]:
|
|
68
68
|
"""Predict end-of-turn using local PyTorch model."""
|
|
69
69
|
inputs = self._turn_processor(
|
|
70
70
|
audio_array,
|
|
@@ -73,7 +73,7 @@ class LocalSmartTurnAnalyzerV2(BaseSmartTurn):
|
|
|
73
73
|
self._turn_model.eval()
|
|
74
74
|
logger.debug("Loaded Local Smart Turn v2")
|
|
75
75
|
|
|
76
|
-
|
|
76
|
+
def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, Any]:
|
|
77
77
|
"""Predict end-of-turn using local PyTorch model."""
|
|
78
78
|
inputs = self._turn_processor(
|
|
79
79
|
audio_array,
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
#
|
|
2
|
+
# Copyright (c) 2025, Daily
|
|
3
|
+
#
|
|
4
|
+
# SPDX-License-Identifier: BSD 2-Clause License
|
|
5
|
+
#
|
|
6
|
+
|
|
7
|
+
"""Local turn analyzer for on-device ML inference using the smart-turn-v3 model.
|
|
8
|
+
|
|
9
|
+
This module provides a smart turn analyzer that uses an ONNX model for
|
|
10
|
+
local end-of-turn detection without requiring network connectivity.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from typing import Any, Dict, Optional
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
from loguru import logger
|
|
17
|
+
|
|
18
|
+
from pipecat.audio.turn.smart_turn.base_smart_turn import BaseSmartTurn
|
|
19
|
+
|
|
20
|
+
try:
|
|
21
|
+
import onnxruntime as ort
|
|
22
|
+
from transformers import WhisperFeatureExtractor
|
|
23
|
+
except ModuleNotFoundError as e:
|
|
24
|
+
logger.error(f"Exception: {e}")
|
|
25
|
+
logger.error(
|
|
26
|
+
"In order to use LocalSmartTurnAnalyzerV3, you need to `pip install pipecat-ai[local-smart-turn-v3]`."
|
|
27
|
+
)
|
|
28
|
+
raise Exception(f"Missing module: {e}")
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
|
32
|
+
"""Local turn analyzer using the smart-turn-v3 ONNX model.
|
|
33
|
+
|
|
34
|
+
Provides end-of-turn detection using locally-stored ONNX model,
|
|
35
|
+
enabling offline operation without network dependencies.
|
|
36
|
+
"""
|
|
37
|
+
|
|
38
|
+
def __init__(self, *, smart_turn_model_path: Optional[str] = None, **kwargs):
|
|
39
|
+
"""Initialize the local ONNX smart-turn-v3 analyzer.
|
|
40
|
+
|
|
41
|
+
Args:
|
|
42
|
+
smart_turn_model_path: Path to the ONNX model file. If this is not
|
|
43
|
+
set, the bundled smart-turn-v3.0 model will be used.
|
|
44
|
+
**kwargs: Additional arguments passed to BaseSmartTurn.
|
|
45
|
+
"""
|
|
46
|
+
super().__init__(**kwargs)
|
|
47
|
+
|
|
48
|
+
logger.debug("Loading Local Smart Turn v3 model...")
|
|
49
|
+
|
|
50
|
+
if not smart_turn_model_path:
|
|
51
|
+
# Load bundled model
|
|
52
|
+
model_name = "smart-turn-v3.0.onnx"
|
|
53
|
+
package_path = "pipecat.audio.turn.smart_turn.data"
|
|
54
|
+
|
|
55
|
+
try:
|
|
56
|
+
import importlib_resources as impresources
|
|
57
|
+
|
|
58
|
+
smart_turn_model_path = str(impresources.files(package_path).joinpath(model_name))
|
|
59
|
+
except BaseException:
|
|
60
|
+
from importlib import resources as impresources
|
|
61
|
+
|
|
62
|
+
try:
|
|
63
|
+
with impresources.path(package_path, model_name) as f:
|
|
64
|
+
smart_turn_model_path = f
|
|
65
|
+
except BaseException:
|
|
66
|
+
smart_turn_model_path = str(
|
|
67
|
+
impresources.files(package_path).joinpath(model_name)
|
|
68
|
+
)
|
|
69
|
+
|
|
70
|
+
so = ort.SessionOptions()
|
|
71
|
+
so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
|
|
72
|
+
so.inter_op_num_threads = 1
|
|
73
|
+
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
|
|
74
|
+
|
|
75
|
+
self._feature_extractor = WhisperFeatureExtractor(chunk_length=8)
|
|
76
|
+
self._session = ort.InferenceSession(smart_turn_model_path, sess_options=so)
|
|
77
|
+
|
|
78
|
+
logger.debug("Loaded Local Smart Turn v3")
|
|
79
|
+
|
|
80
|
+
def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, Any]:
|
|
81
|
+
"""Predict end-of-turn using local ONNX model."""
|
|
82
|
+
|
|
83
|
+
def truncate_audio_to_last_n_seconds(audio_array, n_seconds=8, sample_rate=16000):
|
|
84
|
+
"""Truncate audio to last n seconds or pad with zeros to meet n seconds."""
|
|
85
|
+
max_samples = n_seconds * sample_rate
|
|
86
|
+
if len(audio_array) > max_samples:
|
|
87
|
+
return audio_array[-max_samples:]
|
|
88
|
+
elif len(audio_array) < max_samples:
|
|
89
|
+
# Pad with zeros at the beginning
|
|
90
|
+
padding = max_samples - len(audio_array)
|
|
91
|
+
return np.pad(audio_array, (padding, 0), mode="constant", constant_values=0)
|
|
92
|
+
return audio_array
|
|
93
|
+
|
|
94
|
+
# Truncate to 8 seconds (keeping the end) or pad to 8 seconds
|
|
95
|
+
audio_array = truncate_audio_to_last_n_seconds(audio_array, n_seconds=8)
|
|
96
|
+
|
|
97
|
+
# Process audio using Whisper's feature extractor
|
|
98
|
+
inputs = self._feature_extractor(
|
|
99
|
+
audio_array,
|
|
100
|
+
sampling_rate=16000,
|
|
101
|
+
return_tensors="np",
|
|
102
|
+
padding="max_length",
|
|
103
|
+
max_length=8 * 16000,
|
|
104
|
+
truncation=True,
|
|
105
|
+
do_normalize=True,
|
|
106
|
+
)
|
|
107
|
+
|
|
108
|
+
# Extract features and ensure correct shape for ONNX
|
|
109
|
+
input_features = inputs.input_features.squeeze(0).astype(np.float32)
|
|
110
|
+
input_features = np.expand_dims(input_features, axis=0) # Add batch dimension
|
|
111
|
+
|
|
112
|
+
# Run ONNX inference
|
|
113
|
+
outputs = self._session.run(None, {"input_features": input_features})
|
|
114
|
+
|
|
115
|
+
# Extract probability (ONNX model returns sigmoid probabilities)
|
|
116
|
+
probability = outputs[0][0].item()
|
|
117
|
+
|
|
118
|
+
# Make prediction (1 for Complete, 0 for Incomplete)
|
|
119
|
+
prediction = 1 if probability > 0.5 else 0
|
|
120
|
+
|
|
121
|
+
return {
|
|
122
|
+
"prediction": prediction,
|
|
123
|
+
"probability": probability,
|
|
124
|
+
}
|
|
@@ -0,0 +1,587 @@
|
|
|
1
|
+
#
|
|
2
|
+
# Copyright (c) 2024–2025, Daily
|
|
3
|
+
#
|
|
4
|
+
# SPDX-License-Identifier: BSD 2-Clause License
|
|
5
|
+
#
|
|
6
|
+
|
|
7
|
+
"""Salesforce Agent API LLM service implementation."""
|
|
8
|
+
|
|
9
|
+
import asyncio
|
|
10
|
+
import json
|
|
11
|
+
import time
|
|
12
|
+
from dataclasses import dataclass
|
|
13
|
+
from typing import AsyncGenerator, Dict, Optional
|
|
14
|
+
|
|
15
|
+
import httpx
|
|
16
|
+
from loguru import logger
|
|
17
|
+
|
|
18
|
+
from pipecat.frames.frames import (
|
|
19
|
+
Frame,
|
|
20
|
+
LLMFullResponseEndFrame,
|
|
21
|
+
LLMFullResponseStartFrame,
|
|
22
|
+
LLMMessagesFrame,
|
|
23
|
+
LLMTextFrame,
|
|
24
|
+
LLMUpdateSettingsFrame,
|
|
25
|
+
)
|
|
26
|
+
from pipecat.processors.aggregators.openai_llm_context import (
|
|
27
|
+
OpenAILLMContext,
|
|
28
|
+
OpenAILLMContextFrame,
|
|
29
|
+
)
|
|
30
|
+
from pipecat.processors.frame_processor import FrameDirection
|
|
31
|
+
from pipecat.services.llm_service import LLMService
|
|
32
|
+
from pipecat.services.openai.llm import (
|
|
33
|
+
OpenAIAssistantContextAggregator,
|
|
34
|
+
OpenAIContextAggregatorPair,
|
|
35
|
+
OpenAIUserContextAggregator,
|
|
36
|
+
)
|
|
37
|
+
from pipecat.processors.aggregators.llm_response import (
|
|
38
|
+
LLMAssistantAggregatorParams,
|
|
39
|
+
LLMUserAggregatorParams,
|
|
40
|
+
)
|
|
41
|
+
from env_config import api_config
|
|
42
|
+
from pipecat.utils.redis import create_async_redis_client
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@dataclass
|
|
46
|
+
class SalesforceSessionInfo:
|
|
47
|
+
"""Information about an active Salesforce Agent session."""
|
|
48
|
+
|
|
49
|
+
session_id: str
|
|
50
|
+
agent_id: str
|
|
51
|
+
created_at: float
|
|
52
|
+
last_used: float
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class SalesforceAgentLLMService(LLMService):
|
|
56
|
+
"""Salesforce Agent API LLM service implementation.
|
|
57
|
+
|
|
58
|
+
This service integrates with Salesforce Agent API to provide conversational
|
|
59
|
+
AI capabilities using Salesforce's Agentforce platform.
|
|
60
|
+
"""
|
|
61
|
+
|
|
62
|
+
def __init__(
|
|
63
|
+
self,
|
|
64
|
+
*,
|
|
65
|
+
model: str = "salesforce-agent",
|
|
66
|
+
session_timeout_secs: float = 3600.0,
|
|
67
|
+
agent_id: str = api_config.SALESFORCE_AGENT_ID,
|
|
68
|
+
org_domain: str = api_config.SALESFORCE_ORG_DOMAIN,
|
|
69
|
+
client_id: str = api_config.SALESFORCE_CLIENT_ID,
|
|
70
|
+
client_secret: str = api_config.SALESFORCE_CLIENT_SECRET,
|
|
71
|
+
api_host: str = api_config.SALESFORCE_API_HOST,
|
|
72
|
+
redis_url: Optional[str] = None,
|
|
73
|
+
**kwargs,
|
|
74
|
+
):
|
|
75
|
+
"""Initialize Salesforce Agent LLM service.
|
|
76
|
+
|
|
77
|
+
Reads configuration from environment variables:
|
|
78
|
+
- SALESFORCE_AGENT_ID: The Salesforce agent ID to interact with
|
|
79
|
+
- SALESFORCE_ORG_DOMAIN: Salesforce org domain (e.g., https://myorg.my.salesforce.com)
|
|
80
|
+
- SALESFORCE_CLIENT_ID: Connected app client ID for OAuth
|
|
81
|
+
- SALESFORCE_CLIENT_SECRET: Connected app client secret for OAuth
|
|
82
|
+
- SALESFORCE_API_HOST: Salesforce API host base URL (e.g., https://api.salesforce.com)
|
|
83
|
+
|
|
84
|
+
Args:
|
|
85
|
+
model: The model name (defaults to "salesforce-agent").
|
|
86
|
+
session_timeout_secs: Session timeout in seconds (default: 1 hour).
|
|
87
|
+
agent_id: Salesforce agent ID. Defaults to SALESFORCE_AGENT_ID.
|
|
88
|
+
org_domain: Salesforce org domain. Defaults to SALESFORCE_ORG_DOMAIN.
|
|
89
|
+
client_id: Salesforce connected app client ID. Defaults to SALESFORCE_CLIENT_ID.
|
|
90
|
+
client_secret: Salesforce connected app client secret. Defaults to SALESFORCE_CLIENT_SECRET.
|
|
91
|
+
api_host: Salesforce API host base URL. Defaults to SALESFORCE_API_HOST.
|
|
92
|
+
redis_url: Optional Redis URL override for token caching.
|
|
93
|
+
**kwargs: Additional arguments passed to parent LLMService.
|
|
94
|
+
"""
|
|
95
|
+
# Initialize parent LLM service
|
|
96
|
+
super().__init__(**kwargs)
|
|
97
|
+
|
|
98
|
+
self._agent_id = agent_id
|
|
99
|
+
self._org_domain = org_domain
|
|
100
|
+
self._client_id = client_id
|
|
101
|
+
self._client_secret = client_secret
|
|
102
|
+
self._api_host = api_host
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
# Validate required environment variables
|
|
106
|
+
required_vars = {
|
|
107
|
+
"SALESFORCE_AGENT_ID": self._agent_id,
|
|
108
|
+
"SALESFORCE_ORG_DOMAIN": self._org_domain,
|
|
109
|
+
"SALESFORCE_API_HOST": self._api_host,
|
|
110
|
+
"SALESFORCE_CLIENT_ID": self._client_id,
|
|
111
|
+
"SALESFORCE_CLIENT_SECRET": self._client_secret,
|
|
112
|
+
}
|
|
113
|
+
|
|
114
|
+
missing_vars = [var for var, value in required_vars.items() if not value]
|
|
115
|
+
if missing_vars:
|
|
116
|
+
raise ValueError(f"Missing required environment variables: {', '.join(missing_vars)}")
|
|
117
|
+
|
|
118
|
+
logger.info(f"Salesforce LLM initialized - Agent ID: {self._agent_id}")
|
|
119
|
+
|
|
120
|
+
self._session_timeout_secs = session_timeout_secs
|
|
121
|
+
|
|
122
|
+
if redis_url is not None:
|
|
123
|
+
self._redis_url = redis_url
|
|
124
|
+
else:
|
|
125
|
+
self._redis_url = getattr(api_config, "REDIS_URL", None)
|
|
126
|
+
self._redis_client = None
|
|
127
|
+
self._redis_client_init_attempted = False
|
|
128
|
+
self._token_cache_key = f"salesforce_agent_access_token:{self._agent_id}"
|
|
129
|
+
self._token_cache_leeway_secs = 300
|
|
130
|
+
self._sequence_counter = 0
|
|
131
|
+
self._warmup_task: Optional[asyncio.Task] = None
|
|
132
|
+
|
|
133
|
+
# Session management
|
|
134
|
+
self._sessions: Dict[str, SalesforceSessionInfo] = {}
|
|
135
|
+
self._current_session_id: Optional[str] = None
|
|
136
|
+
|
|
137
|
+
# HTTP client for API calls
|
|
138
|
+
self._http_client = httpx.AsyncClient(
|
|
139
|
+
timeout=30.0,
|
|
140
|
+
limits=httpx.Limits(
|
|
141
|
+
max_keepalive_connections=10,
|
|
142
|
+
max_connections=100,
|
|
143
|
+
keepalive_expiry=None,
|
|
144
|
+
),
|
|
145
|
+
)
|
|
146
|
+
|
|
147
|
+
self._schedule_session_warmup()
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
async def __aenter__(self):
|
|
151
|
+
"""Async context manager entry."""
|
|
152
|
+
await self.ensure_session_ready()
|
|
153
|
+
return self
|
|
154
|
+
|
|
155
|
+
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
|
156
|
+
"""Async context manager exit."""
|
|
157
|
+
if self._warmup_task:
|
|
158
|
+
try:
|
|
159
|
+
await asyncio.shield(self._warmup_task)
|
|
160
|
+
except Exception as exc: # pragma: no cover - warmup best effort
|
|
161
|
+
logger.debug(f"Salesforce warmup task failed during exit: {exc}")
|
|
162
|
+
finally:
|
|
163
|
+
self._warmup_task = None
|
|
164
|
+
|
|
165
|
+
await self._cleanup_sessions()
|
|
166
|
+
await self._http_client.aclose()
|
|
167
|
+
|
|
168
|
+
if self._redis_client:
|
|
169
|
+
close_coro = getattr(self._redis_client, "close", None)
|
|
170
|
+
if callable(close_coro):
|
|
171
|
+
try:
|
|
172
|
+
await close_coro()
|
|
173
|
+
except Exception as exc: # pragma: no cover - best effort cleanup
|
|
174
|
+
logger.debug(f"Failed to close Redis client cleanly: {exc}")
|
|
175
|
+
self._redis_client = None
|
|
176
|
+
self._redis_client_init_attempted = False
|
|
177
|
+
|
|
178
|
+
def can_generate_metrics(self) -> bool:
|
|
179
|
+
"""Check if this service can generate processing metrics."""
|
|
180
|
+
return True
|
|
181
|
+
|
|
182
|
+
def _schedule_session_warmup(self):
|
|
183
|
+
"""Kick off background warm-up if an event loop is running."""
|
|
184
|
+
try:
|
|
185
|
+
loop = asyncio.get_running_loop()
|
|
186
|
+
except RuntimeError:
|
|
187
|
+
return
|
|
188
|
+
|
|
189
|
+
if loop.is_closed():
|
|
190
|
+
return
|
|
191
|
+
|
|
192
|
+
async def _warmup():
|
|
193
|
+
try:
|
|
194
|
+
await self.ensure_session_ready()
|
|
195
|
+
except Exception as exc: # pragma: no cover - warmup best effort
|
|
196
|
+
logger.warning(f"Salesforce warmup failed: {exc}")
|
|
197
|
+
raise
|
|
198
|
+
|
|
199
|
+
task = loop.create_task(_warmup())
|
|
200
|
+
|
|
201
|
+
def _on_done(warmup_task: asyncio.Task):
|
|
202
|
+
if warmup_task.cancelled():
|
|
203
|
+
logger.debug("Salesforce warmup task cancelled")
|
|
204
|
+
elif warmup_task.exception():
|
|
205
|
+
logger.warning(f"Salesforce warmup task error: {warmup_task.exception()}")
|
|
206
|
+
self._warmup_task = None
|
|
207
|
+
|
|
208
|
+
task.add_done_callback(_on_done)
|
|
209
|
+
self._warmup_task = task
|
|
210
|
+
|
|
211
|
+
def _get_redis_client(self):
|
|
212
|
+
"""Return a Redis client for token caching if configured."""
|
|
213
|
+
if self._redis_client is None and not self._redis_client_init_attempted:
|
|
214
|
+
self._redis_client_init_attempted = True
|
|
215
|
+
self._redis_client = create_async_redis_client(
|
|
216
|
+
self._redis_url, decode_responses=True, encoding="utf-8", logger=logger
|
|
217
|
+
)
|
|
218
|
+
|
|
219
|
+
return self._redis_client
|
|
220
|
+
|
|
221
|
+
async def _get_cached_access_token(self) -> Optional[str]:
|
|
222
|
+
"""Return cached access token from Redis."""
|
|
223
|
+
redis_client = self._get_redis_client()
|
|
224
|
+
if not redis_client:
|
|
225
|
+
return None
|
|
226
|
+
|
|
227
|
+
try:
|
|
228
|
+
return await redis_client.get(self._token_cache_key)
|
|
229
|
+
except Exception as exc: # pragma: no cover - cache failures shouldn't break flow
|
|
230
|
+
logger.warning(f"Failed to read Salesforce token from Redis: {exc}")
|
|
231
|
+
return None
|
|
232
|
+
|
|
233
|
+
async def _set_cached_access_token(self, token: str, expires_in: Optional[int]):
|
|
234
|
+
"""Persist access token in Redis with TTL matching Salesforce expiry."""
|
|
235
|
+
redis_client = self._get_redis_client()
|
|
236
|
+
if not redis_client:
|
|
237
|
+
return
|
|
238
|
+
|
|
239
|
+
ttl_seconds = 3600
|
|
240
|
+
if expires_in is not None:
|
|
241
|
+
try:
|
|
242
|
+
ttl_seconds = max(int(expires_in) - self._token_cache_leeway_secs, 30)
|
|
243
|
+
except (TypeError, ValueError):
|
|
244
|
+
logger.debug("Unable to parse Salesforce token expiry; falling back to default TTL")
|
|
245
|
+
|
|
246
|
+
try:
|
|
247
|
+
await redis_client.set(self._token_cache_key, token, ex=ttl_seconds)
|
|
248
|
+
except Exception as exc: # pragma: no cover - cache failures shouldn't break flow
|
|
249
|
+
logger.warning(f"Failed to store Salesforce token in Redis: {exc}")
|
|
250
|
+
|
|
251
|
+
async def _get_access_token(self) -> str:
|
|
252
|
+
"""Get OAuth access token using client credentials."""
|
|
253
|
+
cached_token = await self._get_cached_access_token()
|
|
254
|
+
if cached_token:
|
|
255
|
+
return cached_token
|
|
256
|
+
|
|
257
|
+
token_url = f"{self._org_domain}/services/oauth2/token"
|
|
258
|
+
data = {
|
|
259
|
+
"grant_type": "client_credentials",
|
|
260
|
+
"client_id": self._client_id,
|
|
261
|
+
"client_secret": self._client_secret,
|
|
262
|
+
}
|
|
263
|
+
|
|
264
|
+
try:
|
|
265
|
+
response = await self._http_client.post(token_url, data=data)
|
|
266
|
+
response.raise_for_status()
|
|
267
|
+
token_data = response.json()
|
|
268
|
+
access_token = token_data["access_token"]
|
|
269
|
+
await self._set_cached_access_token(access_token, token_data.get("expires_in"))
|
|
270
|
+
return access_token
|
|
271
|
+
except Exception as e:
|
|
272
|
+
logger.error(f"Failed to get access token: {e}")
|
|
273
|
+
raise
|
|
274
|
+
|
|
275
|
+
async def _create_session(self) -> str:
|
|
276
|
+
"""Create a new Salesforce Agent session."""
|
|
277
|
+
access_token = await self._get_access_token()
|
|
278
|
+
session_url = f"{self._api_host}/einstein/ai-agent/v1/agents/{self._agent_id}/sessions"
|
|
279
|
+
|
|
280
|
+
headers = {
|
|
281
|
+
"Authorization": f"Bearer {access_token}",
|
|
282
|
+
"Content-Type": "application/json",
|
|
283
|
+
}
|
|
284
|
+
|
|
285
|
+
external_session_key = f"pipecat-{int(time.time())}-{id(self)}"
|
|
286
|
+
|
|
287
|
+
payload = {
|
|
288
|
+
"externalSessionKey": external_session_key,
|
|
289
|
+
"instanceConfig": {"endpoint": self._org_domain},
|
|
290
|
+
"tz": "America/Los_Angeles",
|
|
291
|
+
"variables": [{"name": "$Context.EndUserLanguage", "type": "Text", "value": "en_US"}],
|
|
292
|
+
"featureSupport": "Streaming",
|
|
293
|
+
"streamingCapabilities": {"chunkTypes": ["Text"]},
|
|
294
|
+
"bypassUser": True,
|
|
295
|
+
}
|
|
296
|
+
|
|
297
|
+
try:
|
|
298
|
+
response = await self._http_client.post(session_url, headers=headers, json=payload)
|
|
299
|
+
response.raise_for_status()
|
|
300
|
+
session_data = response.json()
|
|
301
|
+
session_id = session_data["sessionId"]
|
|
302
|
+
|
|
303
|
+
# Store session info
|
|
304
|
+
current_time = time.time()
|
|
305
|
+
self._sessions[session_id] = SalesforceSessionInfo(
|
|
306
|
+
session_id=session_id,
|
|
307
|
+
agent_id=self._agent_id,
|
|
308
|
+
created_at=current_time,
|
|
309
|
+
last_used=current_time,
|
|
310
|
+
)
|
|
311
|
+
|
|
312
|
+
logger.debug(f"Created Salesforce Agent session: {session_id}")
|
|
313
|
+
return session_id
|
|
314
|
+
|
|
315
|
+
except Exception as e:
|
|
316
|
+
logger.error(f"Failed to create Salesforce Agent session: {e}")
|
|
317
|
+
raise
|
|
318
|
+
|
|
319
|
+
async def _get_or_create_session(self) -> str:
|
|
320
|
+
"""Get existing session or create a new one."""
|
|
321
|
+
current_time = time.time()
|
|
322
|
+
|
|
323
|
+
# Check if current session is still valid
|
|
324
|
+
if self._current_session_id and self._current_session_id in self._sessions:
|
|
325
|
+
session = self._sessions[self._current_session_id]
|
|
326
|
+
if current_time - session.last_used < self._session_timeout_secs:
|
|
327
|
+
session.last_used = current_time
|
|
328
|
+
return self._current_session_id
|
|
329
|
+
else:
|
|
330
|
+
# Session expired, remove it
|
|
331
|
+
self._sessions.pop(self._current_session_id, None)
|
|
332
|
+
self._current_session_id = None
|
|
333
|
+
|
|
334
|
+
# Create new session
|
|
335
|
+
self._current_session_id = await self._create_session()
|
|
336
|
+
return self._current_session_id
|
|
337
|
+
|
|
338
|
+
async def ensure_session_ready(self) -> str:
    """Ensure a Salesforce session is ready for use and return its id."""
    session_id = await self._get_or_create_session()
    return session_id
|
|
341
|
+
|
|
342
|
+
async def _cleanup_sessions(self):
|
|
343
|
+
"""Clean up expired sessions."""
|
|
344
|
+
current_time = time.time()
|
|
345
|
+
expired_sessions = []
|
|
346
|
+
|
|
347
|
+
for session_id, session in self._sessions.items():
|
|
348
|
+
if current_time - session.last_used > self._session_timeout_secs:
|
|
349
|
+
expired_sessions.append(session_id)
|
|
350
|
+
|
|
351
|
+
for session_id in expired_sessions:
|
|
352
|
+
try:
|
|
353
|
+
# End the session via API
|
|
354
|
+
access_token = await self._get_access_token()
|
|
355
|
+
url = f"{self._api_host}/einstein/ai-agent/v1/sessions/{session_id}"
|
|
356
|
+
headers = {
|
|
357
|
+
"Authorization": f"Bearer {access_token}",
|
|
358
|
+
"x-session-end-reason": "UserRequest",
|
|
359
|
+
}
|
|
360
|
+
await self._http_client.delete(url, headers=headers)
|
|
361
|
+
except Exception as e:
|
|
362
|
+
logger.warning(f"Failed to end session {session_id}: {e}")
|
|
363
|
+
finally:
|
|
364
|
+
self._sessions.pop(session_id, None)
|
|
365
|
+
if self._current_session_id == session_id:
|
|
366
|
+
self._current_session_id = None
|
|
367
|
+
|
|
368
|
+
def _extract_user_message(self, context: OpenAILLMContext) -> str:
|
|
369
|
+
"""Extract the last user message from context.
|
|
370
|
+
|
|
371
|
+
Similar to Vistaar pattern - extract only the most recent user message.
|
|
372
|
+
|
|
373
|
+
Args:
|
|
374
|
+
context: The OpenAI LLM context containing messages.
|
|
375
|
+
|
|
376
|
+
Returns:
|
|
377
|
+
The last user message as a string.
|
|
378
|
+
"""
|
|
379
|
+
messages = context.get_messages()
|
|
380
|
+
|
|
381
|
+
# Find the last user message (iterate in reverse for efficiency)
|
|
382
|
+
for message in reversed(messages):
|
|
383
|
+
if message.get("role") == "user":
|
|
384
|
+
content = message.get("content", "")
|
|
385
|
+
|
|
386
|
+
# Handle content that might be a list (for multimodal messages)
|
|
387
|
+
if isinstance(content, list):
|
|
388
|
+
text_parts = [
|
|
389
|
+
item.get("text", "") for item in content if item.get("type") == "text"
|
|
390
|
+
]
|
|
391
|
+
content = " ".join(text_parts)
|
|
392
|
+
|
|
393
|
+
if isinstance(content, str):
|
|
394
|
+
return content.strip()
|
|
395
|
+
|
|
396
|
+
return ""
|
|
397
|
+
|
|
398
|
+
def _generate_sequence_id(self) -> int:
|
|
399
|
+
"""Generate a sequence ID for the message."""
|
|
400
|
+
self._sequence_counter += 1
|
|
401
|
+
return self._sequence_counter
|
|
402
|
+
|
|
403
|
+
async def _stream_salesforce_response(self, session_id: str, user_message: str) -> AsyncGenerator[str, None]:
    """Yield response text chunks from the Salesforce Agent SSE endpoint.

    Args:
        session_id: Active Salesforce session to post the message to.
        user_message: The user's text to send.

    Yields:
        Text chunks as they arrive; stops on the EndOfTurn event.
    """
    access_token = await self._get_access_token()
    url = f"{self._api_host}/einstein/ai-agent/v1/sessions/{session_id}/messages/stream"

    headers = {
        "Authorization": f"Bearer {access_token}",
        "Content-Type": "application/json",
        "Accept": "text/event-stream",
    }

    message_data = {
        "message": {
            "sequenceId": self._generate_sequence_id(),
            "type": "Text",
            "text": user_message,
        },
        "variables": [
            {"name": "$Context.EndUserLanguage", "type": "Text", "value": "en_US"}
        ],
    }

    try:
        logger.info(f"🌐 Salesforce API request: {user_message[:50]}...")
        async with self._http_client.stream("POST", url, headers=headers, json=message_data) as response:
            response.raise_for_status()

            async for line in response.aiter_lines():
                # Only SSE data lines carry payloads; skip everything else.
                if not line or not line.startswith("data: "):
                    continue

                try:
                    event = json.loads(line[6:])
                except json.JSONDecodeError as e:
                    logger.warning(f"JSON decode error: {e}, line: {line}")
                    continue

                message = event.get("message", {})
                kind = message.get("type")

                if kind == "TextChunk":
                    chunk = message.get("text", "") or message.get("message", "")
                    if chunk:
                        yield chunk
                elif kind == "EndOfTurn":
                    logger.info("🏁 Salesforce response complete")
                    break
                elif kind == "Inform":
                    # Skip INFORM events to avoid duplication
                    continue
    except Exception as e:
        logger.error(f"Failed to stream from Salesforce Agent API: {e}")
        raise
|
|
463
|
+
|
|
464
|
+
async def _process_context(self, context: "OpenAILLMContext"):
    """Process the LLM context and generate a streaming response.

    Extracts the latest user message, streams the Salesforce reply, and
    emits start/end response frames both downstream and upstream.

    Args:
        context: The OpenAI LLM context containing messages to process.
    """
    logger.info(f"🔄 Salesforce processing context with {len(context.get_messages())} messages")

    # Nothing to do without a user turn.
    user_message = self._extract_user_message(context)
    if not user_message:
        logger.warning("Salesforce: No user message found in context")
        return

    try:
        logger.info(f"🎯 Salesforce extracted query: {user_message}")

        # Announce the response in both directions and start metrics.
        await self.push_frame(LLMFullResponseStartFrame())
        await self.push_frame(LLMFullResponseStartFrame(), FrameDirection.UPSTREAM)
        await self.start_processing_metrics()
        await self.start_ttfb_metrics()

        session_id = await self._get_or_create_session()

        awaiting_first_chunk = True
        async for text_chunk in self._stream_salesforce_response(session_id, user_message):
            if awaiting_first_chunk:
                # TTFB measured at the first streamed chunk.
                await self.stop_ttfb_metrics()
                awaiting_first_chunk = False
            await self.push_frame(LLMTextFrame(text=text_chunk))

    except Exception as e:
        logger.error(f"Salesforce context processing error: {type(e).__name__}: {str(e)}")
        import traceback

        logger.error(f"Salesforce traceback: {traceback.format_exc()}")
        raise
    finally:
        # Always close out metrics and the response frames, even on error.
        await self.stop_processing_metrics()
        await self.push_frame(LLMFullResponseEndFrame())
        await self.push_frame(LLMFullResponseEndFrame(), FrameDirection.UPSTREAM)
|
|
511
|
+
|
|
512
|
+
async def process_frame(self, frame: Frame, direction: FrameDirection):
    """Process frames for LLM completion requests.

    Context-bearing frames trigger a Salesforce completion; settings frames
    are delegated to the base class; everything else is forwarded unchanged.

    Args:
        frame: The frame to process.
        direction: The direction of frame processing.
    """
    context = None
    if isinstance(frame, OpenAILLMContextFrame):
        context = frame.context
        logger.info(f"🔍 Received OpenAILLMContextFrame with {len(context.get_messages())} messages")
    elif isinstance(frame, LLMMessagesFrame):
        context = OpenAILLMContext.from_messages(frame.messages)
        logger.info(f"🔍 Received LLMMessagesFrame with {len(frame.messages)} messages")
    elif isinstance(frame, LLMUpdateSettingsFrame):
        # Settings frames: let the base class apply them, then log.
        await super().process_frame(frame, direction)
        settings = frame.settings
        logger.debug(f"Updated Salesforce settings: {settings}")
    else:
        # Non-context frames: base handling, then pass through.
        await super().process_frame(frame, direction)
        await self.push_frame(frame, direction)

    if context is None:
        return

    try:
        await self._process_context(context)
    except httpx.TimeoutException:
        logger.error("Timeout while processing Salesforce request")
        await self._call_event_handler("on_completion_timeout")
    except Exception as e:
        logger.error(f"Error processing Salesforce request: {e}")
        raise
|
|
545
|
+
|
|
546
|
+
def create_context_aggregator(
    self,
    context: OpenAILLMContext,
    *,
    user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
    assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
) -> OpenAIContextAggregatorPair:
    """Create context aggregators for Salesforce LLM.

    Salesforce speaks an OpenAI-compatible message format, so the OpenAI
    user/assistant aggregators are reused directly.

    Args:
        context: The LLM context to create aggregators for.
        user_params: Parameters for user message aggregation.
        assistant_params: Parameters for assistant message aggregation.

    Returns:
        OpenAIContextAggregatorPair: A pair of OpenAI context aggregators,
        compatible with Salesforce's OpenAI-like message format.
    """
    context.set_llm_adapter(self.get_llm_adapter())
    return OpenAIContextAggregatorPair(
        _user=OpenAIUserContextAggregator(context, params=user_params),
        _assistant=OpenAIAssistantContextAggregator(context, params=assistant_params),
    )
|
|
571
|
+
|
|
572
|
+
def get_llm_adapter(self):
    """Get the LLM adapter for this service."""
    # Local import keeps the adapter dependency lazy, as in the rest of the file.
    from pipecat.adapters.services.open_ai_adapter import OpenAILLMAdapter as _Adapter

    return _Adapter()
|
|
576
|
+
|
|
577
|
+
async def close(self):
    """Release all resources: end tracked sessions, then close the client."""
    # Session teardown must happen while the HTTP client is still usable.
    await self._cleanup_sessions()
    await self._http_client.aclose()
|
|
581
|
+
|
|
582
|
+
def __del__(self):
|
|
583
|
+
"""Ensure the client is closed on deletion."""
|
|
584
|
+
try:
|
|
585
|
+
asyncio.create_task(self._http_client.aclose())
|
|
586
|
+
except:
|
|
587
|
+
pass
|
pipecat/utils/redis.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
"""Async Redis helper utilities."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any, Optional, TYPE_CHECKING
|
|
6
|
+
|
|
7
|
+
from urllib.parse import urlparse
|
|
8
|
+
|
|
9
|
+
try:
|
|
10
|
+
import redis.asyncio as redis
|
|
11
|
+
except ImportError: # pragma: no cover - Redis is optional
|
|
12
|
+
redis = None
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING: # pragma: no cover - typing aid
|
|
16
|
+
from redis.asyncio import Redis
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def create_async_redis_client(
|
|
20
|
+
url: Optional[str],
|
|
21
|
+
*,
|
|
22
|
+
decode_responses: bool = True,
|
|
23
|
+
encoding: str = "utf-8",
|
|
24
|
+
logger: Optional[Any] = None,
|
|
25
|
+
**kwargs,
|
|
26
|
+
) -> Optional["Redis"]:
|
|
27
|
+
"""Return a configured async Redis client or None if unavailable.
|
|
28
|
+
|
|
29
|
+
Args:
|
|
30
|
+
url: Redis connection URL.
|
|
31
|
+
decode_responses: Whether to decode responses to str.
|
|
32
|
+
encoding: Character encoding to use with decoded responses.
|
|
33
|
+
logger: Optional logger supporting .warning() for diagnostics.
|
|
34
|
+
**kwargs: Additional keyword arguments forwarded to Redis.from_url.
|
|
35
|
+
"""
|
|
36
|
+
if redis is None:
|
|
37
|
+
return None
|
|
38
|
+
|
|
39
|
+
if not url or url in {"redis_url", "REDIS_URL"}:
|
|
40
|
+
return None
|
|
41
|
+
|
|
42
|
+
parsed = urlparse(url)
|
|
43
|
+
connection_kwargs = {
|
|
44
|
+
"decode_responses": decode_responses,
|
|
45
|
+
"encoding": encoding,
|
|
46
|
+
}
|
|
47
|
+
connection_kwargs.update(kwargs)
|
|
48
|
+
|
|
49
|
+
if parsed.scheme == "rediss":
|
|
50
|
+
connection_kwargs.setdefault("ssl_cert_reqs", "none")
|
|
51
|
+
connection_kwargs.setdefault("ssl_check_hostname", False)
|
|
52
|
+
|
|
53
|
+
try:
|
|
54
|
+
return redis.Redis.from_url(url, **connection_kwargs)
|
|
55
|
+
except Exception as exc: # pragma: no cover - best effort logging
|
|
56
|
+
if logger is not None:
|
|
57
|
+
logger.warning(f"Failed to create Redis client: {exc}")
|
|
58
|
+
return None
|
|
File without changes
|
{dv_pipecat_ai-0.0.85.dev7.dist-info → dv_pipecat_ai-0.0.85.dev11.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|
|
File without changes
|