ruby_llm 1.0.1 → 1.1.0rc1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. checksums.yaml +4 -4
  2. data/README.md +28 -12
  3. data/lib/ruby_llm/active_record/acts_as.rb +46 -7
  4. data/lib/ruby_llm/aliases.json +65 -0
  5. data/lib/ruby_llm/aliases.rb +56 -0
  6. data/lib/ruby_llm/chat.rb +10 -9
  7. data/lib/ruby_llm/configuration.rb +4 -0
  8. data/lib/ruby_llm/error.rb +15 -4
  9. data/lib/ruby_llm/models.json +1163 -303
  10. data/lib/ruby_llm/models.rb +40 -11
  11. data/lib/ruby_llm/provider.rb +32 -39
  12. data/lib/ruby_llm/providers/anthropic/capabilities.rb +8 -9
  13. data/lib/ruby_llm/providers/anthropic/chat.rb +31 -4
  14. data/lib/ruby_llm/providers/anthropic/streaming.rb +12 -6
  15. data/lib/ruby_llm/providers/anthropic.rb +4 -0
  16. data/lib/ruby_llm/providers/bedrock/capabilities.rb +168 -0
  17. data/lib/ruby_llm/providers/bedrock/chat.rb +108 -0
  18. data/lib/ruby_llm/providers/bedrock/models.rb +84 -0
  19. data/lib/ruby_llm/providers/bedrock/signing.rb +831 -0
  20. data/lib/ruby_llm/providers/bedrock/streaming/base.rb +46 -0
  21. data/lib/ruby_llm/providers/bedrock/streaming/content_extraction.rb +63 -0
  22. data/lib/ruby_llm/providers/bedrock/streaming/message_processing.rb +79 -0
  23. data/lib/ruby_llm/providers/bedrock/streaming/payload_processing.rb +90 -0
  24. data/lib/ruby_llm/providers/bedrock/streaming/prelude_handling.rb +91 -0
  25. data/lib/ruby_llm/providers/bedrock/streaming.rb +36 -0
  26. data/lib/ruby_llm/providers/bedrock.rb +83 -0
  27. data/lib/ruby_llm/providers/deepseek/chat.rb +17 -0
  28. data/lib/ruby_llm/providers/deepseek.rb +5 -0
  29. data/lib/ruby_llm/providers/gemini/capabilities.rb +50 -34
  30. data/lib/ruby_llm/providers/gemini/chat.rb +8 -15
  31. data/lib/ruby_llm/providers/gemini/images.rb +5 -10
  32. data/lib/ruby_llm/providers/gemini/streaming.rb +35 -76
  33. data/lib/ruby_llm/providers/gemini/tools.rb +12 -12
  34. data/lib/ruby_llm/providers/gemini.rb +4 -0
  35. data/lib/ruby_llm/providers/openai/capabilities.rb +146 -206
  36. data/lib/ruby_llm/providers/openai/streaming.rb +9 -13
  37. data/lib/ruby_llm/providers/openai.rb +4 -0
  38. data/lib/ruby_llm/streaming.rb +96 -0
  39. data/lib/ruby_llm/version.rb +1 -1
  40. data/lib/ruby_llm.rb +6 -3
  41. data/lib/tasks/browser_helper.rb +97 -0
  42. data/lib/tasks/capability_generator.rb +123 -0
  43. data/lib/tasks/capability_scraper.rb +224 -0
  44. data/lib/tasks/cli_helper.rb +22 -0
  45. data/lib/tasks/code_validator.rb +29 -0
  46. data/lib/tasks/model_updater.rb +66 -0
  47. data/lib/tasks/models.rake +28 -193
  48. data/lib/tasks/vcr.rake +13 -30
  49. metadata +27 -19
  50. data/.github/workflows/cicd.yml +0 -158
  51. data/.github/workflows/docs.yml +0 -53
  52. data/.gitignore +0 -59
  53. data/.overcommit.yml +0 -26
  54. data/.rspec +0 -3
  55. data/.rubocop.yml +0 -10
  56. data/.yardopts +0 -12
  57. data/CONTRIBUTING.md +0 -207
  58. data/Gemfile +0 -33
  59. data/Rakefile +0 -9
  60. data/bin/console +0 -17
  61. data/bin/setup +0 -6
  62. data/ruby_llm.gemspec +0 -44
data/lib/ruby_llm/providers/openai/capabilities.rb CHANGED
@@ -7,229 +7,169 @@ module RubyLLM
  module Capabilities # rubocop:disable Metrics/ModuleLength
  module_function

- # Returns the context window size for the given model ID
- # @param model_id [String] the model identifier
- # @return [Integer] the context window size in tokens
- def context_window_for(model_id)
- case model_id
- when /o1-2024/, /o3-mini/, /o3-mini-2025/ then 200_000
- when /gpt-4o/, /gpt-4o-mini/, /gpt-4-turbo/, /o1-mini/ then 128_000
- when /gpt-4-0[0-9]{3}/ then 8_192
- when /gpt-3.5-turbo$/, /babbage-002/, /davinci-002/, /16k/ then 16_385
+ MODEL_PATTERNS = {
+ dall_e: /^dall-e/,
+ chatgpt4o: /^chatgpt-4o/,
+ gpt4: /^gpt-4(?:-\d{6})?$/,
+ gpt4_turbo: /^gpt-4(?:\.5)?-(?:\d{6}-)?(preview|turbo)/,
+ gpt35_turbo: /^gpt-3\.5-turbo/,
+ gpt4o: /^gpt-4o(?!-(?:mini|audio|realtime|transcribe|tts|search))/,
+ gpt4o_audio: /^gpt-4o-(?:audio)/,
+ gpt4o_mini: /^gpt-4o-mini(?!-(?:audio|realtime|transcribe|tts|search))/,
+ gpt4o_mini_audio: /^gpt-4o-mini-audio/,
+ gpt4o_mini_realtime: /^gpt-4o-mini-realtime/,
+ gpt4o_mini_transcribe: /^gpt-4o-mini-transcribe/,
+ gpt4o_mini_tts: /^gpt-4o-mini-tts/,
+ gpt4o_realtime: /^gpt-4o-realtime/,
+ gpt4o_search: /^gpt-4o-search/,
+ gpt4o_transcribe: /^gpt-4o-transcribe/,
+ o1: /^o1(?!-(?:mini|pro))/,
+ o1_mini: /^o1-mini/,
+ o1_pro: /^o1-pro/,
+ o3_mini: /^o3-mini/,
+ babbage: /^babbage/,
+ davinci: /^davinci/,
+ embedding3_large: /^text-embedding-3-large/,
+ embedding3_small: /^text-embedding-3-small/,
+ embedding_ada: /^text-embedding-ada/,
+ tts1: /^tts-1(?!-hd)/,
+ tts1_hd: /^tts-1-hd/,
+ whisper: /^whisper/,
+ moderation: /^(?:omni|text)-moderation/
+ }.freeze
+
+ def context_window_for(model_id) # rubocop:disable Metrics/MethodLength
+ case model_family(model_id)
+ when 'chatgpt4o', 'gpt4_turbo', 'gpt4o', 'gpt4o_audio', 'gpt4o_mini',
+ 'gpt4o_mini_audio', 'gpt4o_mini_realtime', 'gpt4o_realtime',
+ 'gpt4o_search', 'gpt4o_transcribe', 'gpt4o_mini_search', 'o1_mini' then 128_000
+ when 'gpt4' then 8_192
+ when 'gpt4o_mini_transcribe' then 16_000
+ when 'o1', 'o1_pro', 'o3_mini' then 200_000
+ when 'gpt35_turbo' then 16_385
+ when 'gpt4o_mini_tts', 'tts1', 'tts1_hd', 'whisper', 'moderation',
+ 'embedding3_large', 'embedding3_small', 'embedding_ada' then nil
  else 4_096
  end
  end

- # Returns the maximum output tokens for the given model ID
- # @param model_id [String] the model identifier
- # @return [Integer] the maximum output tokens
- def max_tokens_for(model_id)
- case model_id
- when /o1-2024/, /o3-mini/, /o3-mini-2025/ then 100_000
- when /o1-mini-2024/ then 65_536
- when /gpt-4o/, /gpt-4o-mini/, /gpt-4o-audio/, /gpt-4o-mini-audio/, /babbage-002/, /davinci-002/ then 16_384
- when /gpt-4-0[0-9]{3}/ then 8_192
- else 4_096
+ def max_tokens_for(model_id) # rubocop:disable Metrics/CyclomaticComplexity,Metrics/MethodLength
+ case model_family(model_id)
+ when 'chatgpt4o', 'gpt4o', 'gpt4o_mini', 'gpt4o_mini_search' then 16_384
+ when 'babbage', 'davinci' then 16_384 # rubocop:disable Lint/DuplicateBranch
+ when 'gpt4' then 8_192
+ when 'gpt35_turbo' then 4_096
+ when 'gpt4_turbo', 'gpt4o_realtime', 'gpt4o_mini_realtime' then 4_096 # rubocop:disable Lint/DuplicateBranch
+ when 'gpt4o_mini_transcribe' then 2_000
+ when 'o1', 'o1_pro', 'o3_mini' then 100_000
+ when 'o1_mini' then 65_536
+ when 'gpt4o_mini_tts', 'tts1', 'tts1_hd', 'whisper', 'moderation',
+ 'embedding3_large', 'embedding3_small', 'embedding_ada' then nil
+ else 16_384 # rubocop:disable Lint/DuplicateBranch
  end
  end

- # Returns the input price per million tokens for the given model ID
- # @param model_id [String] the model identifier
- # @return [Float] the price per million tokens for input
- def input_price_for(model_id)
- PRICES.dig(model_family(model_id), :input) || default_input_price
+ def supports_vision?(model_id)
+ case model_family(model_id)
+ when 'chatgpt4o', 'gpt4', 'gpt4_turbo', 'gpt4o', 'gpt4o_mini', 'o1', 'o1_pro',
+ 'moderation', 'gpt4o_search', 'gpt4o_mini_search' then true
+ else false
+ end
  end

- # Returns the output price per million tokens for the given model ID
- # @param model_id [String] the model identifier
- # @return [Float] the price per million tokens for output
- def output_price_for(model_id)
- PRICES.dig(model_family(model_id), :output) || default_output_price
+ def supports_functions?(model_id)
+ case model_family(model_id)
+ when 'gpt4', 'gpt4_turbo', 'gpt4o', 'gpt4o_mini', 'o1', 'o1_pro', 'o3_mini' then true
+ when 'chatgpt4o', 'gpt35_turbo', 'o1_mini', 'gpt4o_mini_tts',
+ 'gpt4o_transcribe', 'gpt4o_search', 'gpt4o_mini_search' then false
+ else false # rubocop:disable Lint/DuplicateBranch
+ end
  end

- # Determines if the model supports vision capabilities
- # @param model_id [String] the model identifier
- # @return [Boolean] true if the model supports vision
- def supports_vision?(model_id) # rubocop:disable Metrics/MethodLength
- supporting_patterns = [
- /^o1$/,
- /^o1-(?!.*mini|.*preview).*$/,
- /gpt-4\.5/,
- /^gpt-4o$/,
- /gpt-4o-2024/,
- /gpt-4o-search/,
- /^gpt-4o-mini$/,
- /gpt-4o-mini-2024/,
- /gpt-4o-mini-search/,
- /chatgpt-4o/,
- /gpt-4-turbo-2024/,
- /computer-use-preview/,
- /omni-moderation/
- ]
- supporting_patterns.any? { |regex| model_id.match?(regex) }
+ def supports_structured_output?(model_id)
+ case model_family(model_id)
+ when 'chatgpt4o', 'gpt4o', 'gpt4o_mini', 'o1', 'o1_pro', 'o3_mini' then true
+ else false
+ end
  end

- # Determines if the model supports function calling
- # @param model_id [String] the model identifier
- # @return [Boolean] true if the model supports functions
- def supports_functions?(model_id) # rubocop:disable Metrics/MethodLength
- supporting_patterns = [
- /^o1$/,
- /gpt-4o/,
- /gpt-4\.5/,
- /chatgpt-4o/,
- /gpt-4-turbo/,
- /computer-use-preview/,
- /o1-preview/,
- /o1-\d{4}-\d{2}-\d{2}/,
- /o1-pro/,
- /o3-mini/
- ]
- supporting_patterns.any? { |regex| model_id.match?(regex) }
+ def supports_json_mode?(model_id)
+ supports_structured_output?(model_id)
  end

- # Determines if the model supports audio input/output
- # @param model_id [String] the model identifier
- # @return [Boolean] true if the model supports audio
- def supports_audio?(model_id)
- model_id.match?(/audio-preview|realtime-preview|whisper|tts/)
+ PRICES = {
+ chatgpt4o: { input: 5.0, output: 15.0 },
+ gpt4: { input: 10.0, output: 30.0 },
+ gpt4_turbo: { input: 10.0, output: 30.0 },
+ gpt45: { input: 75.0, output: 150.0 },
+ gpt35_turbo: { input: 0.5, output: 1.5 },
+ gpt4o: { input: 2.5, output: 10.0 },
+ gpt4o_audio: { input: 2.5, output: 10.0, audio_input: 40.0, audio_output: 80.0 },
+ gpt4o_mini: { input: 0.15, output: 0.6 },
+ gpt4o_mini_audio: { input: 0.15, output: 0.6, audio_input: 10.0, audio_output: 20.0 },
+ gpt4o_mini_realtime: { input: 0.6, output: 2.4 },
+ gpt4o_mini_transcribe: { input: 1.25, output: 5.0, audio_input: 3.0 },
+ gpt4o_mini_tts: { input: 0.6, output: 12.0 },
+ gpt4o_realtime: { input: 5.0, output: 20.0 },
+ gpt4o_search: { input: 2.5, output: 10.0 },
+ gpt4o_transcribe: { input: 2.5, output: 10.0, audio_input: 6.0 },
+ o1: { input: 15.0, output: 60.0 },
+ o1_mini: { input: 1.1, output: 4.4 },
+ o1_pro: { input: 150.0, output: 600.0 },
+ o3_mini: { input: 1.1, output: 4.4 },
+ babbage: { input: 0.4, output: 0.4 },
+ davinci: { input: 2.0, output: 2.0 },
+ embedding3_large: { price: 0.13 },
+ embedding3_small: { price: 0.02 },
+ embedding_ada: { price: 0.10 },
+ tts1: { price: 15.0 },
+ tts1_hd: { price: 30.0 },
+ whisper: { price: 0.006 },
+ moderation: { price: 0.0 }
+ }.freeze
+
+ def model_family(model_id)
+ MODEL_PATTERNS.each do |family, pattern|
+ return family.to_s if model_id.match?(pattern)
+ end
+ 'other'
  end

- # Determines if the model supports JSON mode
- # @param model_id [String] the model identifier
- # @return [Boolean] true if the model supports JSON mode
- def supports_json_mode?(model_id)
- model_id.match?(/gpt-4-\d{4}-preview/) ||
- model_id.include?('turbo') ||
- model_id.match?(/gpt-3.5-turbo-(?!0301|0613)/)
+ def input_price_for(model_id)
+ family = model_family(model_id).to_sym
+ prices = PRICES.fetch(family, { input: default_input_price })
+ prices[:input] || prices[:price] || default_input_price
  end

- # Formats the model ID into a human-readable display name
- # @param model_id [String] the model identifier
- # @return [String] the formatted display name
- def format_display_name(model_id)
- model_id.then { |id| humanize(id) }
- .then { |name| apply_special_formatting(name) }
+ def output_price_for(model_id)
+ family = model_family(model_id).to_sym
+ prices = PRICES.fetch(family, { output: default_output_price })
+ prices[:output] || prices[:price] || default_output_price
  end

- # Determines the type of model
- # @param model_id [String] the model identifier
- # @return [String] the model type (chat, embedding, image, audio, moderation)
  def model_type(model_id)
- case model_id
- when /text-embedding|embedding/ then 'embedding'
- when /dall-e/ then 'image'
- when /tts|whisper/ then 'audio'
- when /omni-moderation|text-moderation/ then 'moderation'
+ case model_family(model_id)
+ when /embedding/ then 'embedding'
+ when /^tts|whisper|gpt4o_(?:mini_)?(?:transcribe|tts)$/ then 'audio'
+ when 'moderation' then 'moderation'
+ when /dall/ then 'image'
  else 'chat'
  end
  end

- # Determines if the model supports structured output
- # @param model_id [String] the model identifier
- # @return [Boolean] true if the model supports structured output
- def supports_structured_output?(model_id)
- model_id.match?(/gpt-4o|o[13]-mini|o1|o3-mini/)
- end
-
- # Determines the model family for pricing and capability lookup
- # @param model_id [String] the model identifier
- # @return [Symbol] the model family identifier
- def model_family(model_id) # rubocop:disable Metrics/AbcSize,Metrics/CyclomaticComplexity,Metrics/MethodLength
- case model_id
- when /o3-mini/ then 'o3_mini'
- when /o1-mini/ then 'o1_mini'
- when /o1/ then 'o1'
- when /gpt-4o-audio/ then 'gpt4o_audio'
- when /gpt-4o-realtime/ then 'gpt4o_realtime'
- when /gpt-4o-mini-audio/ then 'gpt4o_mini_audio'
- when /gpt-4o-mini-realtime/ then 'gpt4o_mini_realtime'
- when /gpt-4o-mini/ then 'gpt4o_mini'
- when /gpt-4o/ then 'gpt4o'
- when /gpt-4-turbo/ then 'gpt4_turbo'
- when /gpt-4/ then 'gpt4'
- when /gpt-3.5-turbo-instruct/ then 'gpt35_instruct'
- when /gpt-3.5/ then 'gpt35'
- when /dall-e-3/ then 'dalle3'
- when /dall-e-2/ then 'dalle2'
- when /text-embedding-3-large/ then 'embedding3_large'
- when /text-embedding-3-small/ then 'embedding3_small'
- when /text-embedding-ada/ then 'embedding2'
- when /tts-1-hd/ then 'tts1_hd'
- when /tts-1/ then 'tts1'
- when /whisper/ then 'whisper1'
- when /omni-moderation|text-moderation/ then 'moderation'
- when /babbage/ then 'babbage'
- when /davinci/ then 'davinci'
- else 'other'
- end
- end
-
- # Pricing information for OpenAI models (per million tokens unless otherwise specified)
- PRICES = {
- o1: { input: 15.0, cached_input: 7.5, output: 60.0 },
- o1_mini: { input: 1.10, cached_input: 0.55, output: 4.40 },
- o3_mini: { input: 1.10, cached_input: 0.55, output: 4.40 },
- gpt4o: { input: 2.50, cached_input: 1.25, output: 10.0 },
- gpt4o_audio: {
- text_input: 2.50,
- audio_input: 40.0,
- text_output: 10.0,
- audio_output: 80.0
- },
- gpt4o_realtime: {
- text_input: 5.0,
- cached_text_input: 2.50,
- audio_input: 40.0,
- cached_audio_input: 2.50,
- text_output: 20.0,
- audio_output: 80.0
- },
- gpt4o_mini: { input: 0.15, cached_input: 0.075, output: 0.60 },
- gpt4o_mini_audio: {
- text_input: 0.15,
- audio_input: 10.0,
- text_output: 0.60,
- audio_output: 20.0
- },
- gpt4o_mini_realtime: {
- text_input: 0.60,
- cached_text_input: 0.30,
- audio_input: 10.0,
- cached_audio_input: 0.30,
- text_output: 2.40,
- audio_output: 20.0
- },
- gpt4_turbo: { input: 10.0, output: 30.0 },
- gpt4: { input: 30.0, output: 60.0 },
- gpt35: { input: 0.50, output: 1.50 },
- gpt35_instruct: { input: 1.50, output: 2.0 },
- embedding3_large: { price: 0.13 },
- embedding3_small: { price: 0.02 },
- embedding2: { price: 0.10 },
- davinci: { input: 2.0, output: 2.0 },
- babbage: { input: 0.40, output: 0.40 },
- tts1: { price: 15.0 }, # per million characters
- tts1_hd: { price: 30.0 }, # per million characters
- whisper1: { price: 0.006 }, # per minute
- moderation: { price: 0.0 } # free
- }.freeze
-
- # Default input price when model-specific pricing is not available
- # @return [Float] the default price per million tokens
  def default_input_price
  0.50
  end

- # Default output price when model-specific pricing is not available
- # @return [Float] the default price per million tokens
  def default_output_price
  1.50
  end

- # Converts a model ID to a human-readable format
- # @param id [String] the model identifier
- # @return [String] the humanized model name
+ def format_display_name(model_id)
+ model_id.then { |id| humanize(id) }
+ .then { |name| apply_special_formatting(name) }
+ end
+
  def humanize(id)
  id.tr('-', ' ')
  .split
@@ -237,30 +177,30 @@ module RubyLLM
  .join(' ')
  end

- # Applies special formatting rules to model names
- # @param name [String] the humanized model name
- # @return [String] the specially formatted model name
- def apply_special_formatting(name) # rubocop:disable Metrics/MethodLength
+ def apply_special_formatting(name)
  name
  .gsub(/(\d{4}) (\d{2}) (\d{2})/, '\1\2\3')
- .gsub(/^Gpt /, 'GPT-')
+ .gsub(/^(?:Gpt|Chatgpt|Tts|Dall E) /) { |m| special_prefix_format(m.strip) }
  .gsub(/^O([13]) /, 'O\1-')
- .gsub(/^O3 Mini/, 'O3-Mini')
- .gsub(/^O1 Mini/, 'O1-Mini')
- .gsub(/^Chatgpt /, 'ChatGPT-')
- .gsub(/^Tts /, 'TTS-')
- .gsub(/^Dall E /, 'DALL-E-')
- .gsub('3.5 ', '3.5-')
- .gsub('4 ', '4-')
- .gsub(/4o (?=Mini|Preview|Turbo|Audio|Realtime)/, '4o-')
+ .gsub(/^O[13] Mini/, '\0'.gsub(' ', '-'))
+ .gsub(/\d\.\d /, '\0'.sub(' ', '-'))
+ .gsub(/4o (?=Mini|Preview|Turbo|Audio|Realtime|Transcribe|Tts)/, '4o-')
  .gsub(/\bHd\b/, 'HD')
- .gsub('Omni Moderation', 'Omni-Moderation')
- .gsub('Text Moderation', 'Text-Moderation')
+ .gsub(/(?:Omni|Text) Moderation/, '\0'.gsub(' ', '-'))
+ .gsub('Text Embedding', 'text-embedding-')
+ end
+
+ def special_prefix_format(prefix)
+ case prefix # rubocop:disable Style/HashLikeCase
+ when 'Gpt' then 'GPT-'
+ when 'Chatgpt' then 'ChatGPT-'
+ when 'Tts' then 'TTS-'
+ when 'Dall E' then 'DALL-E-'
+ end
  end

  def normalize_temperature(temperature, model_id)
- if model_id.match?(/o[13]/)
- # O1/O3 models always use temperature 1.0
+ if model_id.match?(/^o[13]/)
  RubyLLM.logger.debug "Model #{model_id} requires temperature=1.0, ignoring provided value"
  1.0
  else
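
The refactor above replaces scattered per-method regexes with a single MODEL_PATTERNS table that model_family consults in declaration order, so every capability and price lookup now keys off one family string. A minimal sketch of the resulting lookup chain, assuming the module sits at RubyLLM::Providers::OpenAI::Capabilities as in 1.0.x (values read off the diff above):

    caps = RubyLLM::Providers::OpenAI::Capabilities
    caps.model_family('gpt-4o-mini-tts')        # => "gpt4o_mini_tts" (first matching MODEL_PATTERNS entry wins)
    caps.context_window_for('gpt-4o-mini-tts')  # => nil (audio/TTS/embedding families report no context window)
    caps.input_price_for('gpt-4o-mini-tts')     # => 0.6, from PRICES[:gpt4o_mini_tts][:input]
    caps.model_family('some-future-model')      # => "other": falls back to 4_096 tokens and the default prices
    caps.format_display_name('gpt-4o-mini')     # => "GPT-4o-Mini" via humanize + apply_special_formatting

Note that declaration order in MODEL_PATTERNS now does the disambiguation the old nested when branches handled: the negative lookaheads on gpt4o and gpt4o_mini keep the generic families from swallowing their specialized variants.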
data/lib/ruby_llm/providers/openai/streaming.rb CHANGED
@@ -11,19 +11,15 @@ module RubyLLM
  completion_url
  end

- def handle_stream(&block) # rubocop:disable Metrics/MethodLength
- to_json_stream do |data|
- block.call(
- Chunk.new(
- role: :assistant,
- model_id: data['model'],
- content: data.dig('choices', 0, 'delta', 'content'),
- tool_calls: parse_tool_calls(data.dig('choices', 0, 'delta', 'tool_calls'), parse_arguments: false),
- input_tokens: data.dig('usage', 'prompt_tokens'),
- output_tokens: data.dig('usage', 'completion_tokens')
- )
- )
- end
+ def build_chunk(data)
+ Chunk.new(
+ role: :assistant,
+ model_id: data['model'],
+ content: data.dig('choices', 0, 'delta', 'content'),
+ tool_calls: parse_tool_calls(data.dig('choices', 0, 'delta', 'tool_calls'), parse_arguments: false),
+ input_tokens: data.dig('usage', 'prompt_tokens'),
+ output_tokens: data.dig('usage', 'completion_tokens')
+ )
  end
  end
  end
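
With the shared loop extracted into data/lib/ruby_llm/streaming.rb (shown below), the OpenAI module shrinks to the provider-specific piece: turning one parsed SSE event into a Chunk. For reference, a minimal OpenAI-style delta event of the shape build_chunk digs into (values illustrative):

    data = {
      'model'   => 'gpt-4o-mini',
      'choices' => [{ 'delta' => { 'content' => 'Hel' } }],
      'usage'   => nil # token usage typically arrives only on the final event
    }
    # build_chunk(data) => Chunk with role: :assistant, model_id: 'gpt-4o-mini',
    #                      content: 'Hel', and nil token counts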
data/lib/ruby_llm/providers/openai.rb CHANGED
@@ -45,6 +45,10 @@ module RubyLLM
  def slug
  'openai'
  end
+
+ def configuration_requirements
+ %i[openai_api_key]
+ end
  end
  end
  end
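
The new configuration_requirements hook lets RubyLLM verify a provider's credentials are present before dispatching a request. Satisfying it should look like the gem's existing configure block; the key name comes straight from the diff:

    RubyLLM.configure do |config|
      config.openai_api_key = ENV.fetch('OPENAI_API_KEY', nil)
    end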
data/lib/ruby_llm/streaming.rb ADDED
@@ -0,0 +1,96 @@
+ # frozen_string_literal: true
+
+ module RubyLLM
+ # Handles streaming responses from AI providers. Provides a unified way to process
+ # chunked responses, accumulate content, and handle provider-specific streaming formats.
+ # Each provider implements provider-specific parsing while sharing common stream handling
+ # patterns.
+ module Streaming
+ module_function
+
+ def stream_response(payload, &block)
+ accumulator = StreamAccumulator.new
+
+ post stream_url, payload do |req|
+ req.options.on_data = handle_stream do |chunk|
+ accumulator.add chunk
+ block.call chunk
+ end
+ end
+
+ accumulator.to_message
+ end
+
+ def handle_stream(&block)
+ to_json_stream do |data|
+ block.call(build_chunk(data)) if data
+ end
+ end
+
+ private
+
+ def to_json_stream(&block) # rubocop:disable Metrics/MethodLength
+ buffer = String.new
+ parser = EventStreamParser::Parser.new
+
+ proc do |chunk, _bytes, env|
+ RubyLLM.logger.debug "Received chunk: #{chunk}"
+
+ if error_chunk?(chunk)
+ handle_error_chunk(chunk, env)
+ elsif env&.status != 200
+ handle_failed_response(chunk, buffer, env)
+ else
+ yield handle_sse(chunk, parser, env, &block)
+ end
+ end
+ end
+
+ def error_chunk?(chunk)
+ chunk.start_with?('event: error')
+ end
+
+ def handle_error_chunk(chunk, env)
+ error_data = chunk.split("\n")[1].delete_prefix('data: ')
+ status, _message = parse_streaming_error(error_data)
+ error_response = env.merge(body: JSON.parse(error_data), status: status)
+ ErrorMiddleware.parse_error(provider: self, response: error_response)
+ rescue JSON::ParserError => e
+ RubyLLM.logger.debug "Failed to parse error chunk: #{e.message}"
+ end
+
+ def handle_failed_response(chunk, buffer, env)
+ buffer << chunk
+ error_data = JSON.parse(buffer)
+ error_response = env.merge(body: error_data)
+ ErrorMiddleware.parse_error(provider: self, response: error_response)
+ rescue JSON::ParserError
+ RubyLLM.logger.debug "Accumulating error chunk: #{chunk}"
+ end
+
+ def handle_sse(chunk, parser, env, &block)
+ parser.feed(chunk) do |type, data|
+ case type.to_sym
+ when :error
+ handle_error_event(data, env)
+ else
+ yield handle_data(data, &block) unless data == '[DONE]'
+ end
+ end
+ end
+
+ def handle_data(data)
+ JSON.parse(data)
+ rescue JSON::ParserError => e
+ RubyLLM.logger.debug "Failed to parse data chunk: #{e.message}"
+ end
+
+ def handle_error_event(data, env)
+ status, _message = parse_streaming_error(data)
+ error_response = env.merge(body: JSON.parse(data), status: status)
+ ErrorMiddleware.parse_error(provider: self, response: error_response)
+ rescue JSON::ParserError => e
+ RubyLLM.logger.debug "Failed to parse error event: #{e.message}"
+ end
+ end
+ end
data/lib/ruby_llm/version.rb CHANGED
@@ -1,5 +1,5 @@
  # frozen_string_literal: true

  module RubyLLM
- VERSION = '1.0.1'
+ VERSION = '1.1.0rc1'
  end
data/lib/ruby_llm.rb CHANGED
@@ -15,8 +15,10 @@ loader.inflector.inflect(
  'llm' => 'LLM',
  'openai' => 'OpenAI',
  'api' => 'API',
- 'deepseek' => 'DeepSeek'
+ 'deepseek' => 'DeepSeek',
+ 'bedrock' => 'Bedrock'
  )
+ loader.ignore("#{__dir__}/tasks")
  loader.ignore("#{__dir__}/ruby_llm/railtie")
  loader.ignore("#{__dir__}/ruby_llm/active_record")
  loader.setup
@@ -28,8 +30,8 @@ module RubyLLM
  class Error < StandardError; end

  class << self
- def chat(model: nil)
- Chat.new(model: model)
+ def chat(model: nil, provider: nil)
+ Chat.new(model: model, provider: provider)
  end

  def embed(...)
@@ -70,6 +72,7 @@ RubyLLM::Provider.register :openai, RubyLLM::Providers::OpenAI
  RubyLLM::Provider.register :anthropic, RubyLLM::Providers::Anthropic
  RubyLLM::Provider.register :gemini, RubyLLM::Providers::Gemini
  RubyLLM::Provider.register :deepseek, RubyLLM::Providers::DeepSeek
+ RubyLLM::Provider.register :bedrock, RubyLLM::Providers::Bedrock

  if defined?(Rails::Railtie)
  require 'ruby_llm/railtie'
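
Together, the new provider: keyword and the Bedrock registration let callers route a model through a specific backend instead of the registry default. A minimal usage sketch (the model id is illustrative, and Bedrock credentials are configured separately):

    chat = RubyLLM.chat(model: 'claude-3-5-sonnet', provider: :bedrock)
    chat.ask 'Summarize this diff in one sentence.'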