smart_prompt 0.4.4 → 0.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -10
- data/README.cn.md +307 -64
- data/README.md +311 -64
- data/Rakefile +10 -1
- data/config/anthropic_config.yml +151 -0
- data/config/image_generation_config.yml +22 -0
- data/config/multimodal_config.yml +85 -0
- data/config/sensenova_config.yml +63 -0
- data/config/zhipu_config.yml +73 -0
- data/examples/anthropic_basic_chat.rb +143 -0
- data/examples/anthropic_example.rb +232 -0
- data/examples/anthropic_multimodal.rb +212 -0
- data/examples/anthropic_streaming.rb +312 -0
- data/examples/anthropic_tool_calling.rb +393 -0
- data/examples/automatic_cleanup_example.rb +109 -0
- data/examples/history_management_examples.rb +522 -0
- data/examples/image_generation_example.rb +130 -0
- data/examples/monitoring_example.rb +121 -0
- data/examples/multimodal_example.rb +63 -0
- data/examples/relevance_based_strategy_example.rb +87 -0
- data/examples/sensenova_example.rb +129 -0
- data/examples/stt_example.rb +287 -0
- data/examples/tts_example.rb +244 -0
- data/examples/video_generation_example.rb +189 -0
- data/examples/zhipu_example.rb +151 -0
- data/lib/smart_prompt/anthropic_adapter.rb +363 -281
- data/lib/smart_prompt/compression_engine.rb +201 -0
- data/lib/smart_prompt/context_strategy.rb +22 -0
- data/lib/smart_prompt/conversation.rb +81 -191
- data/lib/smart_prompt/engine.rb +36 -19
- data/lib/smart_prompt/history_manager.rb +596 -0
- data/lib/smart_prompt/hybrid_strategy.rb +222 -0
- data/lib/smart_prompt/image_generation_adapter.rb +297 -0
- data/lib/smart_prompt/lru_cache.rb +133 -0
- data/lib/smart_prompt/message.rb +57 -0
- data/lib/smart_prompt/multimodal_adapter.rb +277 -0
- data/lib/smart_prompt/openai_adapter.rb +1 -25
- data/lib/smart_prompt/persistence_layer.rb +197 -0
- data/lib/smart_prompt/relevance_based_strategy.rb +221 -0
- data/lib/smart_prompt/sensenova_adapter.rb +410 -0
- data/lib/smart_prompt/session.rb +140 -0
- data/lib/smart_prompt/sliding_window_strategy.rb +100 -0
- data/lib/smart_prompt/stt_adapter.rb +381 -0
- data/lib/smart_prompt/summary_based_strategy.rb +152 -0
- data/lib/smart_prompt/token_counter.rb +74 -0
- data/lib/smart_prompt/tts_adapter.rb +403 -0
- data/lib/smart_prompt/version.rb +1 -1
- data/lib/smart_prompt/video_generation_adapter.rb +330 -0
- data/lib/smart_prompt/worker.rb +25 -3
- data/lib/smart_prompt/zhipu_adapter.rb +616 -0
- data/lib/smart_prompt.rb +22 -2
- data/workers/history_management_examples.rb +407 -0
- data/workers/image_generation_workers.rb +119 -0
- data/workers/multimodal_workers.rb +110 -0
- data/workers/sensenova_workers.rb +62 -0
- data/workers/stt_workers.rb +195 -0
- data/workers/tts_workers.rb +388 -0
- data/workers/video_generation_workers.rb +264 -0
- data/workers/zhipu_workers.rb +113 -0
- metadata +84 -8
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
require 'time'
|
|
2
|
+
|
|
3
|
+
module SmartPrompt
  # Session represents an isolated conversation session with its own message history.
  #
  # After every #add_message the optional limits in +config+ (:max_messages,
  # :max_tokens, symbol keys) are enforced by dropping the oldest non-system
  # messages. When trimming happens, the retained system messages are placed
  # ahead of the retained non-system messages.
  class Session
    attr_reader :id, :messages, :metadata, :created_at, :updated_at, :config

    # @param id [Object] identifier of this session
    # @param config [Hash] session options; :max_messages and :max_tokens are read when enforcing limits
    def initialize(id, config = {})
      @id = id
      @messages = []
      @metadata = {}
      @config = config
      @token_cache = {}
      @importance_scores = {}
      @created_at = Time.now
      @updated_at = Time.now
      @token_counter = TokenCounter.new
    end

    # Add a message to the session.
    #
    # @param message_data [Message, Object] a Message instance, or data accepted by Message.new
    # @return [Message] the message that was appended
    def add_message(message_data)
      message = message_data.is_a?(Message) ? message_data : Message.new(message_data)

      # Calculate token count for the message so :max_tokens can be enforced
      message.calculate_tokens(@token_counter)

      @messages << message
      @updated_at = Time.now
      enforce_limits
      # Importance scores depend on message positions and on the total count,
      # both of which may have just changed, so any cached score is stale.
      reset_importance_scores
      message
    end

    # Get messages from the session.
    #
    # @param count [Integer, nil] when given, only the last +count+ messages
    # @return [Array<Message>] a new array when +count+ is given, otherwise the
    #   session's own internal array (mutating it mutates the session)
    def get_messages(count = nil)
      count ? @messages.last(count) : @messages
    end

    # Calculate total token count for all messages (missing counts treated as 0).
    #
    # @return [Integer]
    def total_tokens
      @messages.sum { |msg| msg.token_count || 0 }
    end

    # Get the number of messages in the session.
    #
    # @return [Integer]
    def message_count
      @messages.length
    end

    # Remove messages from the session.
    #
    # @param preserve_system [Boolean] keep system messages when true (default)
    def clear(preserve_system: true)
      @messages = preserve_system ? @messages.select(&:system_message?) : []
      reset_importance_scores
      @updated_at = Time.now
    end

    # Get the (memoised) importance score for the message at +message_index+.
    # The cache is reset whenever the message list changes.
    #
    # @param message_index [Integer]
    # @return [Float]
    def get_importance_score(message_index)
      @importance_scores[message_index] ||= calculate_importance(message_index)
    end

    # Convert session to hash format for serialization.
    # Timestamps are rendered with Time#iso8601.
    #
    # @return [Hash]
    def to_h
      {
        id: @id,
        messages: @messages.map(&:to_h),
        metadata: @metadata,
        created_at: @created_at.iso8601,
        updated_at: @updated_at.iso8601,
        config: @config
      }
    end

    private

    # Drop all memoised importance scores.
    def reset_importance_scores
      @importance_scores = {}
    end

    # Enforce message count and token limits from the session config.
    def enforce_limits
      max_messages = @config[:max_messages]
      max_tokens = @config[:max_tokens]

      # Enforce message count limit
      if max_messages && @messages.length > max_messages
        remove_oldest_messages_to_limit(max_messages)
      end

      # Enforce token limit
      if max_tokens && total_tokens > max_tokens
        remove_oldest_messages_to_token_limit(max_tokens)
      end
    end

    # Remove oldest non-system messages to meet the message count limit.
    # System messages are always kept and are placed first in the result.
    def remove_oldest_messages_to_limit(max_messages)
      system_messages = @messages.select(&:system_message?)
      non_system_messages = @messages.reject(&:system_message?)

      # Keep only the most recent non-system messages (never a negative count)
      messages_to_keep = [max_messages - system_messages.length, 0].max

      kept_non_system = non_system_messages.last(messages_to_keep)
      @messages = system_messages + kept_non_system
    end

    # Remove oldest non-system messages to meet the token limit.
    # System message tokens are reserved first; the remaining budget is filled
    # with the newest non-system messages until one no longer fits.
    def remove_oldest_messages_to_token_limit(max_tokens)
      system_messages = @messages.select(&:system_message?)
      non_system_messages = @messages.reject(&:system_message?)

      system_tokens = system_messages.sum { |msg| msg.token_count || 0 }
      available_tokens = max_tokens - system_tokens

      kept_messages = []
      current_tokens = 0

      non_system_messages.reverse_each do |msg|
        msg_tokens = msg.token_count || 0
        break if current_tokens + msg_tokens > available_tokens

        kept_messages.unshift(msg)
        current_tokens += msg_tokens
      end

      @messages = system_messages + kept_messages
    end

    # Calculate the importance score for a message.
    # Simple recency-based scoring: index / message count, so newer messages
    # (higher indices) score higher.
    def calculate_importance(message_index)
      return 0.0 if @messages.empty?

      message_index.to_f / @messages.length
    end
  end
end
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
require_relative 'context_strategy'
|
|
2
|
+
|
|
3
|
+
module SmartPrompt
  # SlidingWindowStrategy implements a simple context selection strategy
  # that keeps the most recent N messages (sliding window approach).
  #
  # This strategy:
  # - Preserves system messages regardless of window size (default). With
  #   :preserve_system set to false they are treated like any other message
  #   and survive only while they are among the newest N.
  # - Keeps the most recent N non-system messages
  # - Trims messages to fit within token limits if specified
  # - Is efficient and predictable for simple conversation flows
  class SlidingWindowStrategy
    include ContextStrategy

    # Initialize the sliding window strategy
    # @param config [Hash] Configuration options
    # @option config [Integer] :window_size (10) Number of recent messages to keep
    # @option config [Boolean] :preserve_system (true) Whether to always keep system messages
    def initialize(config = {})
      @window_size = config[:window_size] || 10
      @preserve_system = config[:preserve_system] != false
    end

    # Select messages using sliding window approach
    # @param messages [Array<Message>] All messages in the session
    # @param max_tokens [Integer, nil] Maximum token limit for selected messages
    # @param current_message [Message, nil] Not used in this strategy
    # @return [Array<Message>] Selected messages (system messages first when preserved)
    def select_messages(messages, max_tokens, current_message = nil)
      return [] if messages.nil? || messages.empty?

      if @preserve_system
        # System messages are kept unconditionally; the window applies only to
        # the remaining messages.
        system_messages = messages.select(&:system_message?)
        recent_messages = messages.reject(&:system_message?).last(@window_size)
      else
        # Not preserving system messages means they compete for a window slot
        # like any other message -- not that they are always discarded.
        system_messages = []
        recent_messages = messages.last(@window_size)
      end

      # Combine system messages (at the beginning) with recent messages
      selected = system_messages + recent_messages

      # Log selection decision
      log_debug "SlidingWindowStrategy: selected #{selected.count}/#{messages.count} messages (window_size=#{@window_size}, system=#{system_messages.count}, recent=#{recent_messages.count})"

      # Trim to token limit if specified
      result = max_tokens ? trim_to_token_limit(selected, max_tokens) : selected

      if max_tokens && result.count < selected.count
        tokens_before = selected.sum { |m| m.token_count || 0 }
        tokens_after = result.sum { |m| m.token_count || 0 }
        log_debug "SlidingWindowStrategy: trimmed to token limit #{max_tokens}: #{selected.count} -> #{result.count} messages, #{tokens_before} -> #{tokens_after} tokens"
      end

      result
    end

    # Determine if compression should be triggered
    # Recommends compression when message count exceeds 2x window size
    # @param session [Session] The session to evaluate
    # @return [Boolean] true if message count > 2 * window_size
    def should_compress?(session)
      session.message_count > @window_size * 2
    end

    private

    # Trim messages to fit within token limit
    # Walks from newest to oldest and stops at the first message that no longer
    # fits, so older messages (including leading system messages) go first.
    # @param messages [Array<Message>] Messages to trim
    # @param max_tokens [Integer] Maximum token limit
    # @return [Array<Message>] Trimmed messages, in their original order
    def trim_to_token_limit(messages, max_tokens)
      return messages unless max_tokens
      return [] if messages.empty?

      total = 0
      selected = []

      messages.reverse_each do |msg|
        msg_tokens = msg.token_count || 0
        # Stop adding messages once the next one would exceed the limit
        break if total + msg_tokens > max_tokens

        selected.unshift(msg)
        total += msg_tokens
      end

      selected
    end

    # Logging helper: emits a debug line only when SmartPrompt.logger is set
    def log_debug(message)
      return unless SmartPrompt.logger

      SmartPrompt.logger.debug "[SlidingWindowStrategy] #{message}"
    end
  end
end
|
|
@@ -0,0 +1,381 @@
|
|
|
1
|
+
require "openai"
|
|
2
|
+
require "base64"
|
|
3
|
+
require "net/http"
|
|
4
|
+
require "uri"
|
|
5
|
+
|
|
6
|
+
module SmartPrompt
  # Speech-to-text (STT) adapter.
  #
  # Audio is sent to "#{config['url']}/audio/transcriptions" with hand-built
  # Net::HTTP requests: a multipart upload for local files, a JSON body for
  # remote URLs. An OpenAI::Client is also built from the same configuration.
  # Chat and embeddings operations are not supported.
  class STTAdapter < LLMAdapter
    # Supported audio formats (lower-case file extensions without the dot)
    SUPPORTED_AUDIO_FORMATS = %w[mp3 mp4 mpeg mpga m4a wav webm].freeze

    # Supported languages for speech recognition
    SUPPORTED_LANGUAGES = %w[zh en ja ko].freeze

    # Values accepted for the response_format argument of #transcribe_audio
    SUPPORTED_RESPONSE_FORMATS = %w[json text srt vtt].freeze

    # Response formats whose HTTP body is plain text rather than a JSON document
    PLAIN_TEXT_RESPONSE_FORMATS = %w[text srt vtt].freeze

    # Maximum file size (25MB)
    MAX_FILE_SIZE = 25 * 1024 * 1024

    # Timeout (seconds) used for the OpenAI client and the raw HTTP requests
    REQUEST_TIMEOUT = 120

    # @param config [Hash] adapter configuration; "api_key", "url" and "model" are read
    # @raise [LLMAPIError] when the OpenAI client rejects the configuration
    # @raise [Error] on any other failure while building the client
    def initialize(config)
      super
      # Resolved once and reused for the raw HTTP requests below, which must
      # not send the unresolved `ENV[...]` config string as the bearer token.
      @api_key = resolve_api_key(@config["api_key"])
      begin
        @client = OpenAI::Client.new(
          access_token: @api_key,
          uri_base: @config["url"],
          request_timeout: REQUEST_TIMEOUT,
        )
        # Only report success once the client actually exists
        SmartPrompt.logger.info "Successfully created an STT client."
      rescue OpenAI::ConfigurationError => e
        SmartPrompt.logger.error "Failed to initialize STT client: #{e.message}"
        raise LLMAPIError, "Invalid STT configuration: #{e.message}"
      rescue OpenAI::Error => e
        SmartPrompt.logger.error "Failed to initialize STT client: #{e.message}"
        raise LLMAPIError, "STT authentication failed: #{e.message}"
      rescue SocketError => e
        SmartPrompt.logger.error "Failed to initialize STT client: #{e.message}"
        raise LLMAPIError, "Network error: Unable to connect to STT API"
      rescue => e
        SmartPrompt.logger.error "Failed to initialize STT client: #{e.message}"
        raise Error, "Unexpected error initializing STT client: #{e.message}"
      end
    end

    # Speech-to-text transcription of a local audio file.
    #
    # @param audio_file [String] path to the audio file
    # @param model [String, nil] model name; defaults to config["model"]
    # @param language [String, nil] one of SUPPORTED_LANGUAGES
    # @param prompt [String, nil] optional prompt forwarded to the API
    # @param temperature [Float]
    # @param response_format [String] one of SUPPORTED_RESPONSE_FORMATS
    # @return [Hash] :text, :language, :duration (estimated), :file_size, :format
    # @raise [Error] on invalid parameters or unexpected failures
    # @raise [LLMAPIError] when the API fails or returns no text
    def transcribe_audio(audio_file, model: nil, language: nil, prompt: nil, temperature: 0.0, response_format: "json")
      SmartPrompt.logger.info "STTAdapter: Transcribing audio to text"

      model_name = model || @config["model"]

      # Validate parameters
      validate_stt_parameters(audio_file, language, response_format)

      audio_data = nil
      begin
        # Prepare audio file (opens a file handle that is closed in `ensure`)
        audio_data = prepare_audio_file(audio_file)

        parameters = {
          model: model_name,
          file: audio_data[:file],
          temperature: temperature,
          response_format: response_format
        }

        # Add optional parameters
        parameters[:language] = language if language
        parameters[:prompt] = prompt if prompt

        SmartPrompt.logger.info "STT parameters: #{parameters.except(:file)}"

        # Custom implementation for STT since OpenAI gem doesn't support audio transcription endpoints
        response = submit_stt_request(parameters)

        @last_response = response

        unless response["text"]
          SmartPrompt.logger.error "No text in STT response"
          raise LLMAPIError, "No text in STT response"
        end

        SmartPrompt.logger.info "STT transcription successful, transcribed #{response['text'].length} characters"
        {
          text: response["text"],
          language: language,
          duration: audio_data[:duration],
          file_size: audio_data[:file_size],
          format: audio_data[:format]
        }
      rescue LLMAPIError
        # Already a specific API error (and already logged); don't re-wrap it as
        # a generic "unexpected" error. NOTE(review): assumes LLMAPIError is a
        # SmartPrompt::Error subclass so `rescue Error` callers still catch it.
        raise
      rescue OpenAI::Error => e
        SmartPrompt.logger.error "STT API error: #{e.message}"
        raise LLMAPIError, "STT API error: #{e.message}"
      rescue => e
        SmartPrompt.logger.error "Unexpected error during STT transcription: #{e.message}"
        raise Error, "Unexpected error during STT transcription: #{e.message}"
      ensure
        # Release the descriptor opened by prepare_audio_file
        audio_data[:file].close if audio_data
      end
    end

    # Transcribe audio from a remote URL (sent as a JSON request body).
    #
    # @param audio_url [String] URL of the audio
    # @return [Hash] :text, :language, :audio_url
    # @raise [LLMAPIError] when the API fails or returns no text
    # @raise [Error] on any other failure
    def transcribe_audio_url(audio_url, model: nil, language: nil, prompt: nil, temperature: 0.0, response_format: "json")
      SmartPrompt.logger.info "STTAdapter: Transcribing audio from URL"

      model_name = model || @config["model"]

      begin
        parameters = {
          model: model_name,
          audio_url: audio_url,
          temperature: temperature,
          response_format: response_format
        }

        # Add optional parameters
        parameters[:language] = language if language
        parameters[:prompt] = prompt if prompt

        SmartPrompt.logger.info "STT URL parameters: #{parameters}"

        # Custom implementation for URL-based STT
        response = submit_stt_url_request(parameters)

        @last_response = response

        unless response["text"]
          SmartPrompt.logger.error "No text in STT URL response"
          raise LLMAPIError, "No text in STT URL response"
        end

        SmartPrompt.logger.info "STT URL transcription successful, transcribed #{response['text'].length} characters"
        {
          text: response["text"],
          language: language,
          audio_url: audio_url
        }
      rescue LLMAPIError
        # Keep the specific API error instead of wrapping it
        raise
      rescue => e
        SmartPrompt.logger.error "Error in URL transcription: #{e.message}"
        raise Error, "Error in URL transcription: #{e.message}"
      end
    end

    # Batch transcription: transcribes each file independently, collecting
    # per-file failures instead of aborting the batch.
    #
    # @param audio_files [Array<String>] paths to audio files
    # @return [Hash] :total_files, :successful, :failed, :results
    def transcribe_batch(audio_files, model: nil, language: nil, prompt: nil, temperature: 0.0)
      SmartPrompt.logger.info "STTAdapter: Batch transcribing #{audio_files.size} audio files"

      results = audio_files.each_with_index.map do |audio_file, index|
        begin
          SmartPrompt.logger.info "Transcribing file #{index + 1}/#{audio_files.size}: #{File.basename(audio_file)}"

          result = transcribe_audio(
            audio_file,
            model: model,
            language: language,
            prompt: prompt,
            temperature: temperature
          )

          { file: audio_file, index: index, transcription: result, success: true }
        rescue => e
          SmartPrompt.logger.error "Failed to transcribe #{audio_file}: #{e.message}"
          { file: audio_file, index: index, error: e.message, success: false }
        end
      end

      {
        total_files: audio_files.size,
        successful: results.count { |r| r[:success] },
        failed: results.count { |r| !r[:success] },
        results: results
      }
    end

    # Get audio file information (existence, format and size are validated).
    #
    # @param audio_file [String] path to the audio file
    # @return [Hash] :file_path, :file_name, :file_size, :format, :estimated_duration, :supported
    # @raise [Error] when the file is missing, unsupported, or too large
    def get_audio_info(audio_file)
      SmartPrompt.logger.info "STTAdapter: Getting audio file information"

      begin
        unless File.exist?(audio_file)
          raise Error, "Audio file not found: #{audio_file}"
        end

        file_ext = File.extname(audio_file).downcase.delete(".")
        unless SUPPORTED_AUDIO_FORMATS.include?(file_ext)
          raise Error, "Unsupported audio format: #{file_ext}"
        end

        file_size = File.size(audio_file)
        if file_size > MAX_FILE_SIZE
          raise Error, "Audio file too large (max #{MAX_FILE_SIZE / (1024 * 1024)}MB)"
        end

        # Rough size-based estimate; actual duration may vary
        duration = estimate_audio_duration(file_size, file_ext)

        {
          file_path: audio_file,
          file_name: File.basename(audio_file),
          file_size: file_size,
          format: file_ext,
          estimated_duration: duration,
          supported: true
        }
      rescue => e
        SmartPrompt.logger.error "Error getting audio info: #{e.message}"
        raise Error, "Error getting audio info: #{e.message}"
      end
    end

    # Language detection (basic, character-range based).
    #
    # @param text [String]
    # @return [String] "ja", "ko", "zh" or "en" (fallback)
    def detect_language(text)
      SmartPrompt.logger.info "STTAdapter: Detecting language from text"

      # Kana and Hangul are checked before Han ideographs: Japanese text usually
      # mixes kanji (Han) with kana, so testing Han first would classify it "zh".
      if text =~ /[\u3040-\u309f\u30a0-\u30ff]/
        "ja"
      elsif text =~ /[\uac00-\ud7af]/
        "ko"
      elsif text =~ /[\u4e00-\u9fff]/
        "zh"
      else
        "en"
      end
    end

    # Override send_request to provide a meaningful error for chat operations.
    # Public so callers reach this NotImplementedError instead of a
    # private-method NoMethodError.
    def send_request(messages, model = nil, temperature = 0.7, tools = nil, proc = nil)
      SmartPrompt.logger.error "STTAdapter does not support chat operations. Use transcribe_audio or transcribe_audio_url methods instead."
      raise NotImplementedError, "STTAdapter does not support chat operations"
    end

    # Override embeddings method (not supported by this adapter).
    def embeddings(text, model)
      SmartPrompt.logger.error "STTAdapter does not support embeddings operations."
      raise NotImplementedError, "STTAdapter does not support embeddings operations"
    end

    private

    # Resolve the api_key config value.
    # NOTE(security): a value of the form `ENV[...]` is passed to `eval`, which
    # executes arbitrary Ruby taken from the configuration. Only load configs
    # from trusted sources; consider replacing this with an explicit ENV lookup.
    def resolve_api_key(api_key)
      if api_key.is_a?(String) && api_key.start_with?("ENV[") && api_key.end_with?("]")
        eval(api_key)
      else
        api_key
      end
    end

    # Raise Error unless the file exists, has a supported extension and size,
    # and language / response_format are supported.
    def validate_stt_parameters(audio_file, language, response_format)
      unless File.exist?(audio_file)
        raise Error, "Audio file not found: #{audio_file}"
      end

      file_ext = File.extname(audio_file).downcase.delete(".")
      unless SUPPORTED_AUDIO_FORMATS.include?(file_ext)
        raise Error, "Unsupported audio format: #{file_ext}"
      end

      file_size = File.size(audio_file)
      if file_size > MAX_FILE_SIZE
        raise Error, "Audio file too large (max #{MAX_FILE_SIZE / (1024 * 1024)}MB)"
      end

      if language && !SUPPORTED_LANGUAGES.include?(language)
        raise Error, "Unsupported language: #{language}"
      end

      unless SUPPORTED_RESPONSE_FORMATS.include?(response_format)
        raise Error, "Unsupported response format: #{response_format}"
      end
    end

    # Open the audio file in binary mode and gather metadata.
    # The caller is responsible for closing the returned :file handle.
    def prepare_audio_file(audio_file)
      file_ext = File.extname(audio_file).downcase.delete(".")
      file_size = File.size(audio_file)
      duration = estimate_audio_duration(file_size, file_ext)

      {
        file: File.open(audio_file, "rb"),
        format: file_ext,
        file_size: file_size,
        duration: duration
      }
    end

    # Rough duration estimate (seconds, integer) from file size and format.
    def estimate_audio_duration(file_size, format)
      case format
      when "mp3", "m4a"
        # Average bitrate ~128kbps
        (file_size * 8) / (128 * 1024)
      when "wav"
        # CD quality: 44.1kHz, 16-bit, stereo
        (file_size / (44100 * 2 * 2)).to_i
      when "webm"
        # Variable bitrate, ~96kbps
        (file_size * 8) / (96 * 1024)
      else
        (file_size * 8) / (128 * 1024)
      end
    end

    # Upload a local file to the transcription endpoint as multipart/form-data.
    def submit_stt_request(parameters)
      uri = transcriptions_uri
      file = parameters[:file]
      boundary = "----WebKitFormBoundary#{Time.now.to_i}"

      # Build the body as BINARY and append binary copies of every text part:
      # mixing raw audio bytes with non-ASCII UTF-8 (file name, prompt) in one
      # string otherwise raises Encoding::CompatibilityError.
      body = String.new(encoding: Encoding::BINARY)
      body << "--#{boundary}\r\n".b
      body << "Content-Disposition: form-data; name=\"file\"; filename=\"#{File.basename(file.path)}\"\r\n".b
      body << "Content-Type: audio/#{File.extname(file.path).delete('.')}\r\n\r\n".b
      body << file.read.b
      body << "\r\n".b

      parameters.except(:file).each do |key, value|
        body << "--#{boundary}\r\n".b
        body << "Content-Disposition: form-data; name=\"#{key}\"\r\n\r\n".b
        body << "#{value}\r\n".b
      end

      body << "--#{boundary}--\r\n".b

      request = Net::HTTP::Post.new(uri.request_uri)
      request['Content-Type'] = "multipart/form-data; boundary=#{boundary}"
      request['Authorization'] = "Bearer #{@api_key}"
      request.body = body

      response = build_http(uri).request(request)
      parse_stt_response(response, parameters[:response_format], "STT API error")
    end

    # Send a URL-based transcription request as a JSON body.
    def submit_stt_url_request(parameters)
      uri = transcriptions_uri

      request = Net::HTTP::Post.new(uri.request_uri)
      request['Content-Type'] = 'application/json'
      request['Authorization'] = "Bearer #{@api_key}"
      request.body = parameters.to_json

      response = build_http(uri).request(request)
      parse_stt_response(response, parameters[:response_format], "STT URL API error")
    end

    # Endpoint URI; a trailing "/" on config["url"] is tolerated.
    def transcriptions_uri
      URI.parse("#{@config['url'].to_s.chomp('/')}/audio/transcriptions")
    end

    # Net::HTTP connection with TLS for https and the adapter's timeout.
    def build_http(uri)
      http = Net::HTTP.new(uri.host, uri.port)
      http.use_ssl = (uri.scheme == 'https')
      http.read_timeout = REQUEST_TIMEOUT
      http
    end

    # Turn an HTTP response into a Hash with a "text" key, or raise LLMAPIError.
    # text/srt/vtt responses are plain text, not JSON, so they are wrapped
    # rather than parsed.
    def parse_stt_response(response, response_format, error_label)
      unless response.is_a?(Net::HTTPSuccess)
        message = "#{error_label}: #{response.code} - #{response.body}"
        SmartPrompt.logger.error message
        raise LLMAPIError, message
      end

      if PLAIN_TEXT_RESPONSE_FORMATS.include?(response_format.to_s)
        { "text" => response.body.to_s.dup.force_encoding(Encoding::UTF_8) }
      else
        JSON.parse(response.body)
      end
    end
  end
end
|