dv-pipecat-ai 0.0.74.dev770__py3-none-any.whl → 0.0.82.dev776__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dv-pipecat-ai might be problematic.
- {dv_pipecat_ai-0.0.74.dev770.dist-info → dv_pipecat_ai-0.0.82.dev776.dist-info}/METADATA +137 -93
- dv_pipecat_ai-0.0.82.dev776.dist-info/RECORD +340 -0
- pipecat/__init__.py +17 -0
- pipecat/adapters/base_llm_adapter.py +36 -1
- pipecat/adapters/schemas/direct_function.py +296 -0
- pipecat/adapters/schemas/function_schema.py +15 -6
- pipecat/adapters/schemas/tools_schema.py +55 -7
- pipecat/adapters/services/anthropic_adapter.py +22 -3
- pipecat/adapters/services/aws_nova_sonic_adapter.py +23 -3
- pipecat/adapters/services/bedrock_adapter.py +22 -3
- pipecat/adapters/services/gemini_adapter.py +16 -3
- pipecat/adapters/services/open_ai_adapter.py +17 -2
- pipecat/adapters/services/open_ai_realtime_adapter.py +23 -3
- pipecat/audio/filters/base_audio_filter.py +30 -6
- pipecat/audio/filters/koala_filter.py +37 -2
- pipecat/audio/filters/krisp_filter.py +59 -6
- pipecat/audio/filters/noisereduce_filter.py +37 -0
- pipecat/audio/interruptions/base_interruption_strategy.py +25 -5
- pipecat/audio/interruptions/min_words_interruption_strategy.py +21 -4
- pipecat/audio/mixers/base_audio_mixer.py +30 -7
- pipecat/audio/mixers/soundfile_mixer.py +53 -6
- pipecat/audio/resamplers/base_audio_resampler.py +17 -9
- pipecat/audio/resamplers/resampy_resampler.py +26 -1
- pipecat/audio/resamplers/soxr_resampler.py +32 -1
- pipecat/audio/resamplers/soxr_stream_resampler.py +101 -0
- pipecat/audio/utils.py +194 -1
- pipecat/audio/vad/silero.py +60 -3
- pipecat/audio/vad/vad_analyzer.py +114 -30
- pipecat/clocks/base_clock.py +19 -0
- pipecat/clocks/system_clock.py +25 -0
- pipecat/extensions/voicemail/__init__.py +0 -0
- pipecat/extensions/voicemail/voicemail_detector.py +707 -0
- pipecat/frames/frames.py +590 -156
- pipecat/metrics/metrics.py +64 -1
- pipecat/observers/base_observer.py +58 -19
- pipecat/observers/loggers/debug_log_observer.py +56 -64
- pipecat/observers/loggers/llm_log_observer.py +8 -1
- pipecat/observers/loggers/transcription_log_observer.py +19 -7
- pipecat/observers/loggers/user_bot_latency_log_observer.py +32 -5
- pipecat/observers/turn_tracking_observer.py +26 -1
- pipecat/pipeline/base_pipeline.py +5 -7
- pipecat/pipeline/base_task.py +52 -9
- pipecat/pipeline/parallel_pipeline.py +121 -177
- pipecat/pipeline/pipeline.py +129 -20
- pipecat/pipeline/runner.py +50 -1
- pipecat/pipeline/sync_parallel_pipeline.py +132 -32
- pipecat/pipeline/task.py +263 -280
- pipecat/pipeline/task_observer.py +85 -34
- pipecat/pipeline/to_be_updated/merge_pipeline.py +32 -2
- pipecat/processors/aggregators/dtmf_aggregator.py +29 -22
- pipecat/processors/aggregators/gated.py +25 -24
- pipecat/processors/aggregators/gated_openai_llm_context.py +22 -2
- pipecat/processors/aggregators/llm_response.py +398 -89
- pipecat/processors/aggregators/openai_llm_context.py +161 -13
- pipecat/processors/aggregators/sentence.py +25 -14
- pipecat/processors/aggregators/user_response.py +28 -3
- pipecat/processors/aggregators/vision_image_frame.py +24 -14
- pipecat/processors/async_generator.py +28 -0
- pipecat/processors/audio/audio_buffer_processor.py +78 -37
- pipecat/processors/consumer_processor.py +25 -6
- pipecat/processors/filters/frame_filter.py +23 -0
- pipecat/processors/filters/function_filter.py +30 -0
- pipecat/processors/filters/identity_filter.py +17 -2
- pipecat/processors/filters/null_filter.py +24 -1
- pipecat/processors/filters/stt_mute_filter.py +56 -21
- pipecat/processors/filters/wake_check_filter.py +46 -3
- pipecat/processors/filters/wake_notifier_filter.py +21 -3
- pipecat/processors/frame_processor.py +488 -131
- pipecat/processors/frameworks/langchain.py +38 -3
- pipecat/processors/frameworks/rtvi.py +719 -34
- pipecat/processors/gstreamer/pipeline_source.py +41 -0
- pipecat/processors/idle_frame_processor.py +26 -3
- pipecat/processors/logger.py +23 -0
- pipecat/processors/metrics/frame_processor_metrics.py +77 -4
- pipecat/processors/metrics/sentry.py +42 -4
- pipecat/processors/producer_processor.py +34 -14
- pipecat/processors/text_transformer.py +22 -10
- pipecat/processors/transcript_processor.py +48 -29
- pipecat/processors/user_idle_processor.py +31 -21
- pipecat/runner/__init__.py +1 -0
- pipecat/runner/daily.py +132 -0
- pipecat/runner/livekit.py +148 -0
- pipecat/runner/run.py +543 -0
- pipecat/runner/types.py +67 -0
- pipecat/runner/utils.py +515 -0
- pipecat/serializers/base_serializer.py +42 -0
- pipecat/serializers/exotel.py +17 -6
- pipecat/serializers/genesys.py +95 -0
- pipecat/serializers/livekit.py +33 -0
- pipecat/serializers/plivo.py +16 -15
- pipecat/serializers/protobuf.py +37 -1
- pipecat/serializers/telnyx.py +18 -17
- pipecat/serializers/twilio.py +32 -16
- pipecat/services/ai_service.py +5 -3
- pipecat/services/anthropic/llm.py +113 -43
- pipecat/services/assemblyai/models.py +63 -5
- pipecat/services/assemblyai/stt.py +64 -11
- pipecat/services/asyncai/__init__.py +0 -0
- pipecat/services/asyncai/tts.py +501 -0
- pipecat/services/aws/llm.py +185 -111
- pipecat/services/aws/stt.py +217 -23
- pipecat/services/aws/tts.py +118 -52
- pipecat/services/aws/utils.py +101 -5
- pipecat/services/aws_nova_sonic/aws.py +82 -64
- pipecat/services/aws_nova_sonic/context.py +15 -6
- pipecat/services/azure/common.py +10 -2
- pipecat/services/azure/image.py +32 -0
- pipecat/services/azure/llm.py +9 -7
- pipecat/services/azure/stt.py +65 -2
- pipecat/services/azure/tts.py +154 -23
- pipecat/services/cartesia/stt.py +125 -8
- pipecat/services/cartesia/tts.py +102 -38
- pipecat/services/cerebras/llm.py +15 -23
- pipecat/services/deepgram/stt.py +19 -11
- pipecat/services/deepgram/tts.py +36 -0
- pipecat/services/deepseek/llm.py +14 -23
- pipecat/services/elevenlabs/tts.py +330 -64
- pipecat/services/fal/image.py +43 -0
- pipecat/services/fal/stt.py +48 -10
- pipecat/services/fireworks/llm.py +14 -21
- pipecat/services/fish/tts.py +109 -9
- pipecat/services/gemini_multimodal_live/__init__.py +1 -0
- pipecat/services/gemini_multimodal_live/events.py +83 -2
- pipecat/services/gemini_multimodal_live/file_api.py +189 -0
- pipecat/services/gemini_multimodal_live/gemini.py +218 -21
- pipecat/services/gladia/config.py +17 -10
- pipecat/services/gladia/stt.py +82 -36
- pipecat/services/google/frames.py +40 -0
- pipecat/services/google/google.py +2 -0
- pipecat/services/google/image.py +39 -2
- pipecat/services/google/llm.py +176 -58
- pipecat/services/google/llm_openai.py +26 -4
- pipecat/services/google/llm_vertex.py +37 -15
- pipecat/services/google/rtvi.py +41 -0
- pipecat/services/google/stt.py +65 -17
- pipecat/services/google/test-google-chirp.py +45 -0
- pipecat/services/google/tts.py +390 -19
- pipecat/services/grok/llm.py +8 -6
- pipecat/services/groq/llm.py +8 -6
- pipecat/services/groq/stt.py +13 -9
- pipecat/services/groq/tts.py +40 -0
- pipecat/services/hamsa/__init__.py +9 -0
- pipecat/services/hamsa/stt.py +241 -0
- pipecat/services/heygen/__init__.py +5 -0
- pipecat/services/heygen/api.py +281 -0
- pipecat/services/heygen/client.py +620 -0
- pipecat/services/heygen/video.py +338 -0
- pipecat/services/image_service.py +5 -3
- pipecat/services/inworld/__init__.py +1 -0
- pipecat/services/inworld/tts.py +592 -0
- pipecat/services/llm_service.py +127 -45
- pipecat/services/lmnt/tts.py +80 -7
- pipecat/services/mcp_service.py +85 -44
- pipecat/services/mem0/memory.py +42 -13
- pipecat/services/minimax/tts.py +74 -15
- pipecat/services/mistral/__init__.py +0 -0
- pipecat/services/mistral/llm.py +185 -0
- pipecat/services/moondream/vision.py +55 -10
- pipecat/services/neuphonic/tts.py +275 -48
- pipecat/services/nim/llm.py +8 -6
- pipecat/services/ollama/llm.py +27 -7
- pipecat/services/openai/base_llm.py +54 -16
- pipecat/services/openai/image.py +30 -0
- pipecat/services/openai/llm.py +7 -5
- pipecat/services/openai/stt.py +13 -9
- pipecat/services/openai/tts.py +42 -10
- pipecat/services/openai_realtime_beta/azure.py +11 -9
- pipecat/services/openai_realtime_beta/context.py +7 -5
- pipecat/services/openai_realtime_beta/events.py +10 -7
- pipecat/services/openai_realtime_beta/openai.py +37 -18
- pipecat/services/openpipe/llm.py +30 -24
- pipecat/services/openrouter/llm.py +9 -7
- pipecat/services/perplexity/llm.py +15 -19
- pipecat/services/piper/tts.py +26 -12
- pipecat/services/playht/tts.py +227 -65
- pipecat/services/qwen/llm.py +8 -6
- pipecat/services/rime/tts.py +128 -17
- pipecat/services/riva/stt.py +160 -22
- pipecat/services/riva/tts.py +67 -2
- pipecat/services/sambanova/llm.py +19 -17
- pipecat/services/sambanova/stt.py +14 -8
- pipecat/services/sarvam/tts.py +60 -13
- pipecat/services/simli/video.py +82 -21
- pipecat/services/soniox/__init__.py +0 -0
- pipecat/services/soniox/stt.py +398 -0
- pipecat/services/speechmatics/stt.py +29 -17
- pipecat/services/stt_service.py +47 -11
- pipecat/services/tavus/video.py +94 -25
- pipecat/services/together/llm.py +8 -6
- pipecat/services/tts_service.py +77 -53
- pipecat/services/ultravox/stt.py +46 -43
- pipecat/services/vision_service.py +5 -3
- pipecat/services/websocket_service.py +12 -11
- pipecat/services/whisper/base_stt.py +58 -12
- pipecat/services/whisper/stt.py +69 -58
- pipecat/services/xtts/tts.py +59 -2
- pipecat/sync/base_notifier.py +19 -0
- pipecat/sync/event_notifier.py +24 -0
- pipecat/tests/utils.py +73 -5
- pipecat/transcriptions/language.py +24 -0
- pipecat/transports/base_input.py +112 -8
- pipecat/transports/base_output.py +235 -13
- pipecat/transports/base_transport.py +119 -0
- pipecat/transports/local/audio.py +76 -0
- pipecat/transports/local/tk.py +84 -0
- pipecat/transports/network/fastapi_websocket.py +174 -15
- pipecat/transports/network/small_webrtc.py +383 -39
- pipecat/transports/network/webrtc_connection.py +214 -8
- pipecat/transports/network/websocket_client.py +171 -1
- pipecat/transports/network/websocket_server.py +147 -9
- pipecat/transports/services/daily.py +792 -70
- pipecat/transports/services/helpers/daily_rest.py +122 -129
- pipecat/transports/services/livekit.py +339 -4
- pipecat/transports/services/tavus.py +273 -38
- pipecat/utils/asyncio/task_manager.py +92 -186
- pipecat/utils/base_object.py +83 -1
- pipecat/utils/network.py +2 -0
- pipecat/utils/string.py +114 -58
- pipecat/utils/text/base_text_aggregator.py +44 -13
- pipecat/utils/text/base_text_filter.py +46 -0
- pipecat/utils/text/markdown_text_filter.py +70 -14
- pipecat/utils/text/pattern_pair_aggregator.py +18 -14
- pipecat/utils/text/simple_text_aggregator.py +43 -2
- pipecat/utils/text/skip_tags_aggregator.py +21 -13
- pipecat/utils/time.py +36 -0
- pipecat/utils/tracing/class_decorators.py +32 -7
- pipecat/utils/tracing/conversation_context_provider.py +12 -2
- pipecat/utils/tracing/service_attributes.py +80 -64
- pipecat/utils/tracing/service_decorators.py +48 -21
- pipecat/utils/tracing/setup.py +13 -7
- pipecat/utils/tracing/turn_context_provider.py +12 -2
- pipecat/utils/tracing/turn_trace_observer.py +27 -0
- pipecat/utils/utils.py +14 -14
- dv_pipecat_ai-0.0.74.dev770.dist-info/RECORD +0 -319
- pipecat/examples/daily_runner.py +0 -64
- pipecat/examples/run.py +0 -265
- pipecat/utils/asyncio/watchdog_async_iterator.py +0 -72
- pipecat/utils/asyncio/watchdog_event.py +0 -42
- pipecat/utils/asyncio/watchdog_priority_queue.py +0 -48
- pipecat/utils/asyncio/watchdog_queue.py +0 -48
- {dv_pipecat_ai-0.0.74.dev770.dist-info → dv_pipecat_ai-0.0.82.dev776.dist-info}/WHEEL +0 -0
- {dv_pipecat_ai-0.0.74.dev770.dist-info → dv_pipecat_ai-0.0.82.dev776.dist-info}/licenses/LICENSE +0 -0
- {dv_pipecat_ai-0.0.74.dev770.dist-info → dv_pipecat_ai-0.0.82.dev776.dist-info}/top_level.txt +0 -0
- /pipecat/{examples → extensions}/__init__.py +0 -0
--- /dev/null
+++ b/pipecat/services/mistral/llm.py
@@ -0,0 +1,185 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+"""Mistral LLM service implementation using OpenAI-compatible interface."""
+
+from typing import List, Sequence
+
+from loguru import logger
+from openai import AsyncStream
+from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
+
+from pipecat.frames.frames import FunctionCallFromLLM
+from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
+from pipecat.services.openai.llm import OpenAILLMService
+
+
+class MistralLLMService(OpenAILLMService):
+    """A service for interacting with Mistral's API using the OpenAI-compatible interface.
+
+    This service extends OpenAILLMService to connect to Mistral's API endpoint while
+    maintaining full compatibility with OpenAI's interface and functionality.
+    """
+
+    def __init__(
+        self,
+        *,
+        api_key: str,
+        base_url: str = "https://api.mistral.ai/v1",
+        model: str = "mistral-small-latest",
+        **kwargs,
+    ):
+        """Initialize the Mistral LLM service.
+
+        Args:
+            api_key: The API key for accessing Mistral's API.
+            base_url: The base URL for Mistral API. Defaults to "https://api.mistral.ai/v1".
+            model: The model identifier to use. Defaults to "mistral-small-latest".
+            **kwargs: Additional keyword arguments passed to OpenAILLMService.
+        """
+        super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs)
+
+    def create_client(self, api_key=None, base_url=None, **kwargs):
+        """Create OpenAI-compatible client for Mistral API endpoint.
+
+        Args:
+            api_key: The API key for authentication. If None, uses instance key.
+            base_url: The base URL for the API. If None, uses instance URL.
+            **kwargs: Additional arguments passed to the client constructor.
+
+        Returns:
+            An OpenAI-compatible client configured for Mistral API.
+        """
+        logger.debug(f"Creating Mistral client with api {base_url}")
+        return super().create_client(api_key, base_url, **kwargs)
+
+    def _apply_mistral_assistant_prefix(
+        self, messages: List[ChatCompletionMessageParam]
+    ) -> List[ChatCompletionMessageParam]:
+        """Apply Mistral's assistant message prefix requirement.
+
+        Mistral requires assistant messages to have prefix=True when they
+        are the final message in a conversation. According to Mistral's API:
+        - Assistant messages with prefix=True MUST be the last message
+        - Only add prefix=True to the final assistant message when needed
+        - This allows assistant messages to be accepted as the last message
+
+        Args:
+            messages: The original list of messages.
+
+        Returns:
+            Messages with Mistral prefix requirement applied to final assistant message.
+        """
+        if not messages:
+            return messages
+
+        # Create a copy to avoid modifying the original
+        fixed_messages = [dict(msg) for msg in messages]
+
+        # Get the last message
+        last_message = fixed_messages[-1]
+
+        # Only add prefix=True to the last message if it's an assistant message
+        # and Mistral would otherwise reject it
+        if last_message.get("role") == "assistant" and "prefix" not in last_message:
+            last_message["prefix"] = True
+
+        return fixed_messages
+
+    async def run_function_calls(self, function_calls: Sequence[FunctionCallFromLLM]):
+        """Execute function calls, filtering out already-completed ones.
+
+        Mistral and OpenAI have different function call detection patterns:
+
+        OpenAI (Stream-based detection):
+
+        - Detects function calls only from streaming chunks as the LLM generates them
+        - Second LLM completion doesn't re-detect existing tool_calls in message history
+        - Function calls execute exactly once
+
+        Mistral (Message-based detection):
+
+        - Detects function calls from the complete message history on each completion
+        - Second LLM completion with the response re-detects the same tool_calls from
+          previous messages
+        - Without filtering, function calls would execute twice
+
+        This method prevents duplicate execution by:
+
+        1. Checking message history for existing tool result messages
+        2. Filtering out function calls that already have corresponding results
+        3. Only executing function calls that haven't been completed yet
+
+        Note: This filtering prevents duplicate function execution, but the
+        on_function_calls_started event may still fire twice due to the detection
+        pattern difference. This is expected behavior.
+
+        Args:
+            function_calls: The function calls to potentially execute.
+        """
+        if not function_calls:
+            return
+
+        # Filter out function calls that already have results
+        calls_to_execute = []
+
+        # Get messages from the first function call's context (they should all have the same context)
+        messages = function_calls[0].context.get_messages() if function_calls else []
+
+        # Get all tool_call_ids that already have results
+        executed_call_ids = set()
+        for msg in messages:
+            if msg.get("role") == "tool" and msg.get("tool_call_id"):
+                executed_call_ids.add(msg.get("tool_call_id"))
+
+        # Only include function calls that haven't been executed yet
+        for call in function_calls:
+            if call.tool_call_id not in executed_call_ids:
+                calls_to_execute.append(call)
+            else:
+                logger.trace(
+                    f"Skipping already-executed function call: {call.function_name}:{call.tool_call_id}"
+                )
+
+        # Call parent method with filtered list
+        if calls_to_execute:
+            await super().run_function_calls(calls_to_execute)
+
+    def build_chat_completion_params(
+        self, context: OpenAILLMContext, messages: List[ChatCompletionMessageParam]
+    ) -> dict:
+        """Build parameters for Mistral chat completion request.
+
+        Handles Mistral-specific requirements including:
+        - Assistant message prefix requirement for API compatibility
+        - Parameter mapping (random_seed instead of seed)
+        - Core completion settings
+        """
+        # Apply Mistral's assistant prefix requirement for API compatibility
+        fixed_messages = self._apply_mistral_assistant_prefix(messages)
+
+        params = {
+            "model": self.model_name,
+            "stream": True,
+            "messages": fixed_messages,
+            "tools": context.tools,
+            "tool_choice": context.tool_choice,
+            "frequency_penalty": self._settings["frequency_penalty"],
+            "presence_penalty": self._settings["presence_penalty"],
+            "temperature": self._settings["temperature"],
+            "top_p": self._settings["top_p"],
+            "max_tokens": self._settings["max_tokens"],
+        }
+
+        # Handle Mistral-specific parameter mapping
+        # Mistral uses "random_seed" instead of "seed"
+        if self._settings["seed"]:
+            params["random_seed"] = self._settings["seed"]
+
+        # Add any extra parameters
+        params.update(self._settings["extra"])
+
+        return params
--- a/pipecat/services/moondream/vision.py
+++ b/pipecat/services/moondream/vision.py
@@ -4,6 +4,12 @@
 # SPDX-License-Identifier: BSD 2-Clause License
 #
 
+"""Moondream vision service implementation.
+
+This module provides integration with the Moondream vision-language model
+for image analysis and description generation.
+"""
+
 import asyncio
 from typing import AsyncGenerator
 
@@ -23,7 +29,15 @@ except ModuleNotFoundError as e:
 
 
 def detect_device():
-    """
+    """Detect the appropriate device to run on.
+
+    Detects available hardware acceleration and selects the best device
+    and data type for optimal performance.
+
+    Returns:
+        tuple: A tuple containing (device, dtype) where device is a torch.device
+            and dtype is the recommended torch data type for that device.
+    """
     try:
         import intel_extension_for_pytorch
 
@@ -40,9 +54,24 @@ def detect_device():
 
 
 class MoondreamService(VisionService):
+    """Moondream vision-language model service.
+
+    Provides image analysis and description generation using the Moondream
+    vision-language model. Supports various hardware acceleration options
+    including CUDA, MPS, and Intel XPU.
+    """
+
     def __init__(
-        self, *, model="vikhyatk/moondream2", revision="
+        self, *, model="vikhyatk/moondream2", revision="2025-01-09", use_cpu=False, **kwargs
     ):
+        """Initialize the Moondream service.
+
+        Args:
+            model: Hugging Face model identifier for the Moondream model.
+            revision: Specific model revision to use.
+            use_cpu: Whether to force CPU usage instead of hardware acceleration.
+            **kwargs: Additional arguments passed to the parent VisionService.
+        """
         super().__init__(**kwargs)
 
         self.set_model_name(model)
@@ -53,18 +82,28 @@ class MoondreamService(VisionService):
         device = torch.device("cpu")
         dtype = torch.float32
 
-        self._tokenizer = AutoTokenizer.from_pretrained(model, revision=revision)
-
         logger.debug("Loading Moondream model...")
 
         self._model = AutoModelForCausalLM.from_pretrained(
-            model,
-
-
+            model,
+            trust_remote_code=True,
+            revision=revision,
+            device_map={"": device},
+            torch_dtype=dtype,
+        ).eval()
 
         logger.debug("Loaded Moondream model")
 
     async def run_vision(self, frame: VisionImageRawFrame) -> AsyncGenerator[Frame, None]:
+        """Analyze an image and generate a description.
+
+        Args:
+            frame: Vision frame containing the image data and optional question text.
+
+        Yields:
+            Frame: TextFrame containing the generated image description, or ErrorFrame
+                if analysis fails.
+        """
         if not self._model:
             logger.error(f"{self} error: Moondream model not available ({self.model_name})")
             yield ErrorFrame("Moondream model not available")
@@ -73,11 +112,17 @@ class MoondreamService(VisionService):
         logger.debug(f"Analyzing image: {frame}")
 
         def get_image_description(frame: VisionImageRawFrame):
+            """Generate description for the given image frame.
+
+            Args:
+                frame: Vision frame containing image data and question.
+
+            Returns:
+                str: Generated description of the image.
+            """
             image = Image.frombytes(frame.format, frame.size, frame.image)
             image_embeds = self._model.encode_image(image)
-            description = self._model.answer_question(
-                image_embeds=image_embeds, question=frame.text, tokenizer=self._tokenizer
-            )
+            description = self._model.query(image_embeds, frame.text)["answer"]
             return description
 
         description = await asyncio.to_thread(get_image_description, frame)