ruby_llm 1.0.0 → 1.1.0rc1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. checksums.yaml +4 -4
  2. data/README.md +58 -19
  3. data/lib/ruby_llm/active_record/acts_as.rb +46 -7
  4. data/lib/ruby_llm/aliases.json +65 -0
  5. data/lib/ruby_llm/aliases.rb +56 -0
  6. data/lib/ruby_llm/chat.rb +11 -10
  7. data/lib/ruby_llm/configuration.rb +4 -0
  8. data/lib/ruby_llm/error.rb +15 -4
  9. data/lib/ruby_llm/models.json +1489 -283
  10. data/lib/ruby_llm/models.rb +57 -22
  11. data/lib/ruby_llm/provider.rb +44 -41
  12. data/lib/ruby_llm/providers/anthropic/capabilities.rb +8 -9
  13. data/lib/ruby_llm/providers/anthropic/chat.rb +31 -4
  14. data/lib/ruby_llm/providers/anthropic/streaming.rb +12 -6
  15. data/lib/ruby_llm/providers/anthropic.rb +4 -0
  16. data/lib/ruby_llm/providers/bedrock/capabilities.rb +168 -0
  17. data/lib/ruby_llm/providers/bedrock/chat.rb +108 -0
  18. data/lib/ruby_llm/providers/bedrock/models.rb +84 -0
  19. data/lib/ruby_llm/providers/bedrock/signing.rb +831 -0
  20. data/lib/ruby_llm/providers/bedrock/streaming/base.rb +46 -0
  21. data/lib/ruby_llm/providers/bedrock/streaming/content_extraction.rb +63 -0
  22. data/lib/ruby_llm/providers/bedrock/streaming/message_processing.rb +79 -0
  23. data/lib/ruby_llm/providers/bedrock/streaming/payload_processing.rb +90 -0
  24. data/lib/ruby_llm/providers/bedrock/streaming/prelude_handling.rb +91 -0
  25. data/lib/ruby_llm/providers/bedrock/streaming.rb +36 -0
  26. data/lib/ruby_llm/providers/bedrock.rb +83 -0
  27. data/lib/ruby_llm/providers/deepseek/chat.rb +17 -0
  28. data/lib/ruby_llm/providers/deepseek.rb +5 -0
  29. data/lib/ruby_llm/providers/gemini/capabilities.rb +50 -34
  30. data/lib/ruby_llm/providers/gemini/chat.rb +8 -15
  31. data/lib/ruby_llm/providers/gemini/images.rb +5 -10
  32. data/lib/ruby_llm/providers/gemini/models.rb +0 -8
  33. data/lib/ruby_llm/providers/gemini/streaming.rb +35 -76
  34. data/lib/ruby_llm/providers/gemini/tools.rb +12 -12
  35. data/lib/ruby_llm/providers/gemini.rb +4 -0
  36. data/lib/ruby_llm/providers/openai/capabilities.rb +154 -177
  37. data/lib/ruby_llm/providers/openai/streaming.rb +9 -13
  38. data/lib/ruby_llm/providers/openai.rb +4 -0
  39. data/lib/ruby_llm/streaming.rb +96 -0
  40. data/lib/ruby_llm/tool.rb +15 -7
  41. data/lib/ruby_llm/version.rb +1 -1
  42. data/lib/ruby_llm.rb +8 -3
  43. data/lib/tasks/browser_helper.rb +97 -0
  44. data/lib/tasks/capability_generator.rb +123 -0
  45. data/lib/tasks/capability_scraper.rb +224 -0
  46. data/lib/tasks/cli_helper.rb +22 -0
  47. data/lib/tasks/code_validator.rb +29 -0
  48. data/lib/tasks/model_updater.rb +66 -0
  49. data/lib/tasks/models.rake +28 -197
  50. data/lib/tasks/vcr.rake +97 -0
  51. metadata +42 -19
  52. data/.github/workflows/cicd.yml +0 -109
  53. data/.github/workflows/docs.yml +0 -53
  54. data/.gitignore +0 -58
  55. data/.overcommit.yml +0 -26
  56. data/.rspec +0 -3
  57. data/.rspec_status +0 -50
  58. data/.rubocop.yml +0 -10
  59. data/.yardopts +0 -12
  60. data/Gemfile +0 -32
  61. data/Rakefile +0 -9
  62. data/bin/console +0 -17
  63. data/bin/setup +0 -6
  64. data/ruby_llm.gemspec +0 -43
@@ -7,202 +7,169 @@ module RubyLLM
7
7
  module Capabilities # rubocop:disable Metrics/ModuleLength
8
8
  module_function
9
9
 
10
- # Returns the context window size for the given model ID
11
- # @param model_id [String] the model identifier
12
- # @return [Integer] the context window size in tokens
13
- def context_window_for(model_id)
14
- case model_id
15
- when /o1-2024/, /o3-mini/, /o3-mini-2025/ then 200_000
16
- when /gpt-4o/, /gpt-4o-mini/, /gpt-4-turbo/, /o1-mini/ then 128_000
17
- when /gpt-4-0[0-9]{3}/ then 8_192
18
- when /gpt-3.5-turbo$/, /babbage-002/, /davinci-002/, /16k/ then 16_385
10
+ MODEL_PATTERNS = {
11
+ dall_e: /^dall-e/,
12
+ chatgpt4o: /^chatgpt-4o/,
13
+ gpt4: /^gpt-4(?:-\d{6})?$/,
14
+ gpt4_turbo: /^gpt-4(?:\.5)?-(?:\d{6}-)?(preview|turbo)/,
15
+ gpt35_turbo: /^gpt-3\.5-turbo/,
16
+ gpt4o: /^gpt-4o(?!-(?:mini|audio|realtime|transcribe|tts|search))/,
17
+ gpt4o_audio: /^gpt-4o-(?:audio)/,
18
+ gpt4o_mini: /^gpt-4o-mini(?!-(?:audio|realtime|transcribe|tts|search))/,
19
+ gpt4o_mini_audio: /^gpt-4o-mini-audio/,
20
+ gpt4o_mini_realtime: /^gpt-4o-mini-realtime/,
21
+ gpt4o_mini_transcribe: /^gpt-4o-mini-transcribe/,
22
+ gpt4o_mini_tts: /^gpt-4o-mini-tts/,
23
+ gpt4o_realtime: /^gpt-4o-realtime/,
24
+ gpt4o_search: /^gpt-4o-search/,
25
+ gpt4o_transcribe: /^gpt-4o-transcribe/,
26
+ o1: /^o1(?!-(?:mini|pro))/,
27
+ o1_mini: /^o1-mini/,
28
+ o1_pro: /^o1-pro/,
29
+ o3_mini: /^o3-mini/,
30
+ babbage: /^babbage/,
31
+ davinci: /^davinci/,
32
+ embedding3_large: /^text-embedding-3-large/,
33
+ embedding3_small: /^text-embedding-3-small/,
34
+ embedding_ada: /^text-embedding-ada/,
35
+ tts1: /^tts-1(?!-hd)/,
36
+ tts1_hd: /^tts-1-hd/,
37
+ whisper: /^whisper/,
38
+ moderation: /^(?:omni|text)-moderation/
39
+ }.freeze
40
+
41
+ def context_window_for(model_id) # rubocop:disable Metrics/MethodLength
42
+ case model_family(model_id)
43
+ when 'chatgpt4o', 'gpt4_turbo', 'gpt4o', 'gpt4o_audio', 'gpt4o_mini',
44
+ 'gpt4o_mini_audio', 'gpt4o_mini_realtime', 'gpt4o_realtime',
45
+ 'gpt4o_search', 'gpt4o_transcribe', 'gpt4o_mini_search', 'o1_mini' then 128_000
46
+ when 'gpt4' then 8_192
47
+ when 'gpt4o_mini_transcribe' then 16_000
48
+ when 'o1', 'o1_pro', 'o3_mini' then 200_000
49
+ when 'gpt35_turbo' then 16_385
50
+ when 'gpt4o_mini_tts', 'tts1', 'tts1_hd', 'whisper', 'moderation',
51
+ 'embedding3_large', 'embedding3_small', 'embedding_ada' then nil
19
52
  else 4_096
20
53
  end
21
54
  end
22
55
 
23
- # Returns the maximum output tokens for the given model ID
24
- # @param model_id [String] the model identifier
25
- # @return [Integer] the maximum output tokens
26
- def max_tokens_for(model_id)
27
- case model_id
28
- when /o1-2024/, /o3-mini/, /o3-mini-2025/ then 100_000
29
- when /o1-mini-2024/ then 65_536
30
- when /gpt-4o/, /gpt-4o-mini/, /gpt-4o-audio/, /gpt-4o-mini-audio/, /babbage-002/, /davinci-002/ then 16_384
31
- when /gpt-4-0[0-9]{3}/ then 8_192
32
- else 4_096
56
+ def max_tokens_for(model_id) # rubocop:disable Metrics/CyclomaticComplexity,Metrics/MethodLength
57
+ case model_family(model_id)
58
+ when 'chatgpt4o', 'gpt4o', 'gpt4o_mini', 'gpt4o_mini_search' then 16_384
59
+ when 'babbage', 'davinci' then 16_384 # rubocop:disable Lint/DuplicateBranch
60
+ when 'gpt4' then 8_192
61
+ when 'gpt35_turbo' then 4_096
62
+ when 'gpt4_turbo', 'gpt4o_realtime', 'gpt4o_mini_realtime' then 4_096 # rubocop:disable Lint/DuplicateBranch
63
+ when 'gpt4o_mini_transcribe' then 2_000
64
+ when 'o1', 'o1_pro', 'o3_mini' then 100_000
65
+ when 'o1_mini' then 65_536
66
+ when 'gpt4o_mini_tts', 'tts1', 'tts1_hd', 'whisper', 'moderation',
67
+ 'embedding3_large', 'embedding3_small', 'embedding_ada' then nil
68
+ else 16_384 # rubocop:disable Lint/DuplicateBranch
33
69
  end
34
70
  end
35
71
 
36
- # Returns the input price per million tokens for the given model ID
37
- # @param model_id [String] the model identifier
38
- # @return [Float] the price per million tokens for input
39
- def input_price_for(model_id)
40
- PRICES.dig(model_family(model_id), :input) || default_input_price
72
+ def supports_vision?(model_id)
73
+ case model_family(model_id)
74
+ when 'chatgpt4o', 'gpt4', 'gpt4_turbo', 'gpt4o', 'gpt4o_mini', 'o1', 'o1_pro',
75
+ 'moderation', 'gpt4o_search', 'gpt4o_mini_search' then true
76
+ else false
77
+ end
41
78
  end
42
79
 
43
- # Returns the output price per million tokens for the given model ID
44
- # @param model_id [String] the model identifier
45
- # @return [Float] the price per million tokens for output
46
- def output_price_for(model_id)
47
- PRICES.dig(model_family(model_id), :output) || default_output_price
80
+ def supports_functions?(model_id)
81
+ case model_family(model_id)
82
+ when 'gpt4', 'gpt4_turbo', 'gpt4o', 'gpt4o_mini', 'o1', 'o1_pro', 'o3_mini' then true
83
+ when 'chatgpt4o', 'gpt35_turbo', 'o1_mini', 'gpt4o_mini_tts',
84
+ 'gpt4o_transcribe', 'gpt4o_search', 'gpt4o_mini_search' then false
85
+ else false # rubocop:disable Lint/DuplicateBranch
86
+ end
48
87
  end
49
88
 
50
- # Determines if the model supports vision capabilities
51
- # @param model_id [String] the model identifier
52
- # @return [Boolean] true if the model supports vision
53
- def supports_vision?(model_id)
54
- model_id.match?(/gpt-4o|o1/) || model_id.match?(/gpt-4-(?!0314|0613)/)
89
+ def supports_structured_output?(model_id)
90
+ case model_family(model_id)
91
+ when 'chatgpt4o', 'gpt4o', 'gpt4o_mini', 'o1', 'o1_pro', 'o3_mini' then true
92
+ else false
93
+ end
55
94
  end
56
95
 
57
- # Determines if the model supports function calling
58
- # @param model_id [String] the model identifier
59
- # @return [Boolean] true if the model supports functions
60
- def supports_functions?(model_id)
61
- !model_id.include?('instruct')
96
+ def supports_json_mode?(model_id)
97
+ supports_structured_output?(model_id)
62
98
  end
63
99
 
64
- # Determines if the model supports audio input/output
65
- # @param model_id [String] the model identifier
66
- # @return [Boolean] true if the model supports audio
67
- def supports_audio?(model_id)
68
- model_id.match?(/audio-preview|realtime-preview|whisper|tts/)
100
+ PRICES = {
101
+ chatgpt4o: { input: 5.0, output: 15.0 },
102
+ gpt4: { input: 10.0, output: 30.0 },
103
+ gpt4_turbo: { input: 10.0, output: 30.0 },
104
+ gpt45: { input: 75.0, output: 150.0 },
105
+ gpt35_turbo: { input: 0.5, output: 1.5 },
106
+ gpt4o: { input: 2.5, output: 10.0 },
107
+ gpt4o_audio: { input: 2.5, output: 10.0, audio_input: 40.0, audio_output: 80.0 },
108
+ gpt4o_mini: { input: 0.15, output: 0.6 },
109
+ gpt4o_mini_audio: { input: 0.15, output: 0.6, audio_input: 10.0, audio_output: 20.0 },
110
+ gpt4o_mini_realtime: { input: 0.6, output: 2.4 },
111
+ gpt4o_mini_transcribe: { input: 1.25, output: 5.0, audio_input: 3.0 },
112
+ gpt4o_mini_tts: { input: 0.6, output: 12.0 },
113
+ gpt4o_realtime: { input: 5.0, output: 20.0 },
114
+ gpt4o_search: { input: 2.5, output: 10.0 },
115
+ gpt4o_transcribe: { input: 2.5, output: 10.0, audio_input: 6.0 },
116
+ o1: { input: 15.0, output: 60.0 },
117
+ o1_mini: { input: 1.1, output: 4.4 },
118
+ o1_pro: { input: 150.0, output: 600.0 },
119
+ o3_mini: { input: 1.1, output: 4.4 },
120
+ babbage: { input: 0.4, output: 0.4 },
121
+ davinci: { input: 2.0, output: 2.0 },
122
+ embedding3_large: { price: 0.13 },
123
+ embedding3_small: { price: 0.02 },
124
+ embedding_ada: { price: 0.10 },
125
+ tts1: { price: 15.0 },
126
+ tts1_hd: { price: 30.0 },
127
+ whisper: { price: 0.006 },
128
+ moderation: { price: 0.0 }
129
+ }.freeze
130
+
131
+ def model_family(model_id)
132
+ MODEL_PATTERNS.each do |family, pattern|
133
+ return family.to_s if model_id.match?(pattern)
134
+ end
135
+ 'other'
69
136
  end
70
137
 
71
- # Determines if the model supports JSON mode
72
- # @param model_id [String] the model identifier
73
- # @return [Boolean] true if the model supports JSON mode
74
- def supports_json_mode?(model_id)
75
- model_id.match?(/gpt-4-\d{4}-preview/) ||
76
- model_id.include?('turbo') ||
77
- model_id.match?(/gpt-3.5-turbo-(?!0301|0613)/)
138
+ def input_price_for(model_id)
139
+ family = model_family(model_id).to_sym
140
+ prices = PRICES.fetch(family, { input: default_input_price })
141
+ prices[:input] || prices[:price] || default_input_price
78
142
  end
79
143
 
80
- # Formats the model ID into a human-readable display name
81
- # @param model_id [String] the model identifier
82
- # @return [String] the formatted display name
83
- def format_display_name(model_id)
84
- model_id.then { |id| humanize(id) }
85
- .then { |name| apply_special_formatting(name) }
144
+ def output_price_for(model_id)
145
+ family = model_family(model_id).to_sym
146
+ prices = PRICES.fetch(family, { output: default_output_price })
147
+ prices[:output] || prices[:price] || default_output_price
86
148
  end
87
149
 
88
- # Determines the type of model
89
- # @param model_id [String] the model identifier
90
- # @return [String] the model type (chat, embedding, image, audio, moderation)
91
150
  def model_type(model_id)
92
- case model_id
93
- when /text-embedding|embedding/ then 'embedding'
94
- when /dall-e/ then 'image'
95
- when /tts|whisper/ then 'audio'
96
- when /omni-moderation|text-moderation/ then 'moderation'
151
+ case model_family(model_id)
152
+ when /embedding/ then 'embedding'
153
+ when /^tts|whisper|gpt4o_(?:mini_)?(?:transcribe|tts)$/ then 'audio'
154
+ when 'moderation' then 'moderation'
155
+ when /dall/ then 'image'
97
156
  else 'chat'
98
157
  end
99
158
  end
100
159
 
101
- # Determines if the model supports structured output
102
- # @param model_id [String] the model identifier
103
- # @return [Boolean] true if the model supports structured output
104
- def supports_structured_output?(model_id)
105
- model_id.match?(/gpt-4o|o[13]-mini|o1|o3-mini/)
106
- end
107
-
108
- # Determines the model family for pricing and capability lookup
109
- # @param model_id [String] the model identifier
110
- # @return [Symbol] the model family identifier
111
- def model_family(model_id) # rubocop:disable Metrics/AbcSize,Metrics/CyclomaticComplexity,Metrics/MethodLength
112
- case model_id
113
- when /o3-mini/ then 'o3_mini'
114
- when /o1-mini/ then 'o1_mini'
115
- when /o1/ then 'o1'
116
- when /gpt-4o-audio/ then 'gpt4o_audio'
117
- when /gpt-4o-realtime/ then 'gpt4o_realtime'
118
- when /gpt-4o-mini-audio/ then 'gpt4o_mini_audio'
119
- when /gpt-4o-mini-realtime/ then 'gpt4o_mini_realtime'
120
- when /gpt-4o-mini/ then 'gpt4o_mini'
121
- when /gpt-4o/ then 'gpt4o'
122
- when /gpt-4-turbo/ then 'gpt4_turbo'
123
- when /gpt-4/ then 'gpt4'
124
- when /gpt-3.5-turbo-instruct/ then 'gpt35_instruct'
125
- when /gpt-3.5/ then 'gpt35'
126
- when /dall-e-3/ then 'dalle3'
127
- when /dall-e-2/ then 'dalle2'
128
- when /text-embedding-3-large/ then 'embedding3_large'
129
- when /text-embedding-3-small/ then 'embedding3_small'
130
- when /text-embedding-ada/ then 'embedding2'
131
- when /tts-1-hd/ then 'tts1_hd'
132
- when /tts-1/ then 'tts1'
133
- when /whisper/ then 'whisper1'
134
- when /omni-moderation|text-moderation/ then 'moderation'
135
- when /babbage/ then 'babbage'
136
- when /davinci/ then 'davinci'
137
- else 'other'
138
- end
139
- end
140
-
141
- # Pricing information for OpenAI models (per million tokens unless otherwise specified)
142
- PRICES = {
143
- o1: { input: 15.0, cached_input: 7.5, output: 60.0 },
144
- o1_mini: { input: 1.10, cached_input: 0.55, output: 4.40 },
145
- o3_mini: { input: 1.10, cached_input: 0.55, output: 4.40 },
146
- gpt4o: { input: 2.50, cached_input: 1.25, output: 10.0 },
147
- gpt4o_audio: {
148
- text_input: 2.50,
149
- audio_input: 40.0,
150
- text_output: 10.0,
151
- audio_output: 80.0
152
- },
153
- gpt4o_realtime: {
154
- text_input: 5.0,
155
- cached_text_input: 2.50,
156
- audio_input: 40.0,
157
- cached_audio_input: 2.50,
158
- text_output: 20.0,
159
- audio_output: 80.0
160
- },
161
- gpt4o_mini: { input: 0.15, cached_input: 0.075, output: 0.60 },
162
- gpt4o_mini_audio: {
163
- text_input: 0.15,
164
- audio_input: 10.0,
165
- text_output: 0.60,
166
- audio_output: 20.0
167
- },
168
- gpt4o_mini_realtime: {
169
- text_input: 0.60,
170
- cached_text_input: 0.30,
171
- audio_input: 10.0,
172
- cached_audio_input: 0.30,
173
- text_output: 2.40,
174
- audio_output: 20.0
175
- },
176
- gpt4_turbo: { input: 10.0, output: 30.0 },
177
- gpt4: { input: 30.0, output: 60.0 },
178
- gpt35: { input: 0.50, output: 1.50 },
179
- gpt35_instruct: { input: 1.50, output: 2.0 },
180
- embedding3_large: { price: 0.13 },
181
- embedding3_small: { price: 0.02 },
182
- embedding2: { price: 0.10 },
183
- davinci: { input: 2.0, output: 2.0 },
184
- babbage: { input: 0.40, output: 0.40 },
185
- tts1: { price: 15.0 }, # per million characters
186
- tts1_hd: { price: 30.0 }, # per million characters
187
- whisper1: { price: 0.006 }, # per minute
188
- moderation: { price: 0.0 } # free
189
- }.freeze
190
-
191
- # Default input price when model-specific pricing is not available
192
- # @return [Float] the default price per million tokens
193
160
  def default_input_price
194
161
  0.50
195
162
  end
196
163
 
197
- # Default output price when model-specific pricing is not available
198
- # @return [Float] the default price per million tokens
199
164
  def default_output_price
200
165
  1.50
201
166
  end
202
167
 
203
- # Converts a model ID to a human-readable format
204
- # @param id [String] the model identifier
205
- # @return [String] the humanized model name
168
+ def format_display_name(model_id)
169
+ model_id.then { |id| humanize(id) }
170
+ .then { |name| apply_special_formatting(name) }
171
+ end
172
+
206
173
  def humanize(id)
207
174
  id.tr('-', ' ')
208
175
  .split
@@ -210,25 +177,35 @@ module RubyLLM
210
177
  .join(' ')
211
178
  end
212
179
 
213
- # Applies special formatting rules to model names
214
- # @param name [String] the humanized model name
215
- # @return [String] the specially formatted model name
216
- def apply_special_formatting(name) # rubocop:disable Metrics/MethodLength
180
+ def apply_special_formatting(name)
217
181
  name
218
182
  .gsub(/(\d{4}) (\d{2}) (\d{2})/, '\1\2\3')
219
- .gsub(/^Gpt /, 'GPT-')
183
+ .gsub(/^(?:Gpt|Chatgpt|Tts|Dall E) /) { |m| special_prefix_format(m.strip) }
220
184
  .gsub(/^O([13]) /, 'O\1-')
221
- .gsub(/^O3 Mini/, 'O3-Mini')
222
- .gsub(/^O1 Mini/, 'O1-Mini')
223
- .gsub(/^Chatgpt /, 'ChatGPT-')
224
- .gsub(/^Tts /, 'TTS-')
225
- .gsub(/^Dall E /, 'DALL-E-')
226
- .gsub('3.5 ', '3.5-')
227
- .gsub('4 ', '4-')
228
- .gsub(/4o (?=Mini|Preview|Turbo|Audio|Realtime)/, '4o-')
185
+ .gsub(/^O[13] Mini/, '\0'.gsub(' ', '-'))
186
+ .gsub(/\d\.\d /, '\0'.sub(' ', '-'))
187
+ .gsub(/4o (?=Mini|Preview|Turbo|Audio|Realtime|Transcribe|Tts)/, '4o-')
229
188
  .gsub(/\bHd\b/, 'HD')
230
- .gsub('Omni Moderation', 'Omni-Moderation')
231
- .gsub('Text Moderation', 'Text-Moderation')
189
+ .gsub(/(?:Omni|Text) Moderation/, '\0'.gsub(' ', '-'))
190
+ .gsub('Text Embedding', 'text-embedding-')
191
+ end
192
+
193
+ def special_prefix_format(prefix)
194
+ case prefix # rubocop:disable Style/HashLikeCase
195
+ when 'Gpt' then 'GPT-'
196
+ when 'Chatgpt' then 'ChatGPT-'
197
+ when 'Tts' then 'TTS-'
198
+ when 'Dall E' then 'DALL-E-'
199
+ end
200
+ end
201
+
202
+ def normalize_temperature(temperature, model_id)
203
+ if model_id.match?(/^o[13]/)
204
+ RubyLLM.logger.debug "Model #{model_id} requires temperature=1.0, ignoring provided value"
205
+ 1.0
206
+ else
207
+ temperature
208
+ end
232
209
  end
233
210
  end
234
211
  end
@@ -11,19 +11,15 @@ module RubyLLM
11
11
  completion_url
12
12
  end
13
13
 
14
- def handle_stream(&block) # rubocop:disable Metrics/MethodLength
15
- to_json_stream do |data|
16
- block.call(
17
- Chunk.new(
18
- role: :assistant,
19
- model_id: data['model'],
20
- content: data.dig('choices', 0, 'delta', 'content'),
21
- tool_calls: parse_tool_calls(data.dig('choices', 0, 'delta', 'tool_calls'), parse_arguments: false),
22
- input_tokens: data.dig('usage', 'prompt_tokens'),
23
- output_tokens: data.dig('usage', 'completion_tokens')
24
- )
25
- )
26
- end
14
+ def build_chunk(data)
15
+ Chunk.new(
16
+ role: :assistant,
17
+ model_id: data['model'],
18
+ content: data.dig('choices', 0, 'delta', 'content'),
19
+ tool_calls: parse_tool_calls(data.dig('choices', 0, 'delta', 'tool_calls'), parse_arguments: false),
20
+ input_tokens: data.dig('usage', 'prompt_tokens'),
21
+ output_tokens: data.dig('usage', 'completion_tokens')
22
+ )
27
23
  end
28
24
  end
29
25
  end
@@ -45,6 +45,10 @@ module RubyLLM
45
45
  def slug
46
46
  'openai'
47
47
  end
48
+
49
+ def configuration_requirements
50
+ %i[openai_api_key]
51
+ end
48
52
  end
49
53
  end
50
54
  end
@@ -0,0 +1,96 @@
1
+ # frozen_string_literal: true
2
+
3
+ module RubyLLM
4
+ # Handles streaming responses from AI providers. Provides a unified way to process
5
+ # chunked responses, accumulate content, and handle provider-specific streaming formats.
6
+ # Each provider implements provider-specific parsing while sharing common stream handling
7
+ # patterns.
8
+ module Streaming
9
+ module_function
10
+
11
+ def stream_response(payload, &block)
12
+ accumulator = StreamAccumulator.new
13
+
14
+ post stream_url, payload do |req|
15
+ req.options.on_data = handle_stream do |chunk|
16
+ accumulator.add chunk
17
+ block.call chunk
18
+ end
19
+ end
20
+
21
+ accumulator.to_message
22
+ end
23
+
24
+ def handle_stream(&block)
25
+ to_json_stream do |data|
26
+ block.call(build_chunk(data)) if data
27
+ end
28
+ end
29
+
30
+ private
31
+
32
+ def to_json_stream(&block) # rubocop:disable Metrics/MethodLength
33
+ buffer = String.new
34
+ parser = EventStreamParser::Parser.new
35
+
36
+ proc do |chunk, _bytes, env|
37
+ RubyLLM.logger.debug "Received chunk: #{chunk}"
38
+
39
+ if error_chunk?(chunk)
40
+ handle_error_chunk(chunk, env)
41
+ elsif env&.status != 200
42
+ handle_failed_response(chunk, buffer, env)
43
+ else
44
+ yield handle_sse(chunk, parser, env, &block)
45
+ end
46
+ end
47
+ end
48
+
49
+ def error_chunk?(chunk)
50
+ chunk.start_with?('event: error')
51
+ end
52
+
53
+ def handle_error_chunk(chunk, env)
54
+ error_data = chunk.split("\n")[1].delete_prefix('data: ')
55
+ status, _message = parse_streaming_error(error_data)
56
+ error_response = env.merge(body: JSON.parse(error_data), status: status)
57
+ ErrorMiddleware.parse_error(provider: self, response: error_response)
58
+ rescue JSON::ParserError => e
59
+ RubyLLM.logger.debug "Failed to parse error chunk: #{e.message}"
60
+ end
61
+
62
+ def handle_failed_response(chunk, buffer, env)
63
+ buffer << chunk
64
+ error_data = JSON.parse(buffer)
65
+ error_response = env.merge(body: error_data)
66
+ ErrorMiddleware.parse_error(provider: self, response: error_response)
67
+ rescue JSON::ParserError
68
+ RubyLLM.logger.debug "Accumulating error chunk: #{chunk}"
69
+ end
70
+
71
+ def handle_sse(chunk, parser, env, &block)
72
+ parser.feed(chunk) do |type, data|
73
+ case type.to_sym
74
+ when :error
75
+ handle_error_event(data, env)
76
+ else
77
+ yield handle_data(data, &block) unless data == '[DONE]'
78
+ end
79
+ end
80
+ end
81
+
82
+ def handle_data(data)
83
+ JSON.parse(data)
84
+ rescue JSON::ParserError => e
85
+ RubyLLM.logger.debug "Failed to parse data chunk: #{e.message}"
86
+ end
87
+
88
+ def handle_error_event(data, env)
89
+ status, _message = parse_streaming_error(data)
90
+ error_response = env.merge(body: JSON.parse(data), status: status)
91
+ ErrorMiddleware.parse_error(provider: self, response: error_response)
92
+ rescue JSON::ParserError => e
93
+ RubyLLM.logger.debug "Failed to parse error event: #{e.message}"
94
+ end
95
+ end
96
+ end
data/lib/ruby_llm/tool.rb CHANGED
@@ -18,14 +18,19 @@ module RubyLLM
18
18
  # interface for defining parameters and implementing tool behavior.
19
19
  #
20
20
  # Example:
21
- # class Calculator < RubyLLM::Tool
22
- # description "Performs arithmetic calculations"
23
- # param :expression, type: :string, desc: "Math expression to evaluate"
21
+ # require 'tzinfo'
24
22
  #
25
- # def execute(expression:)
26
- # eval(expression).to_s
27
- # end
28
- # end
23
+ # class TimeInfo < RubyLLM::Tool
24
+ # description 'Gets the current time in various timezones'
25
+ # param :timezone, desc: "Timezone name (e.g., 'UTC', 'America/New_York')"
26
+ #
27
+ # def execute(timezone:)
28
+ # time = TZInfo::Timezone.get(timezone).now.strftime('%Y-%m-%d %H:%M:%S')
29
+ # "Current time in #{timezone}: #{time}"
30
+ # rescue StandardError => e
31
+ # { error: e.message }
32
+ # end
33
+ # end
29
34
  class Tool
30
35
  class << self
31
36
  def description(text = nil)
@@ -45,6 +50,9 @@ module RubyLLM
45
50
 
46
51
  def name
47
52
  self.class.name
53
+ .unicode_normalize(:nfkd)
54
+ .encode('ASCII', replace: '')
55
+ .gsub(/[^a-zA-Z0-9_-]/, '-')
48
56
  .gsub(/([A-Z]+)([A-Z][a-z])/, '\1_\2')
49
57
  .gsub(/([a-z\d])([A-Z])/, '\1_\2')
50
58
  .downcase
@@ -1,5 +1,5 @@
1
1
  # frozen_string_literal: true
2
2
 
3
3
  module RubyLLM
4
- VERSION = '1.0.0'
4
+ VERSION = '1.1.0rc1'
5
5
  end
data/lib/ruby_llm.rb CHANGED
@@ -15,8 +15,12 @@ loader.inflector.inflect(
15
15
  'llm' => 'LLM',
16
16
  'openai' => 'OpenAI',
17
17
  'api' => 'API',
18
- 'deepseek' => 'DeepSeek'
18
+ 'deepseek' => 'DeepSeek',
19
+ 'bedrock' => 'Bedrock'
19
20
  )
21
+ loader.ignore("#{__dir__}/tasks")
22
+ loader.ignore("#{__dir__}/ruby_llm/railtie")
23
+ loader.ignore("#{__dir__}/ruby_llm/active_record")
20
24
  loader.setup
21
25
 
22
26
  # A delightful Ruby interface to modern AI language models.
@@ -26,8 +30,8 @@ module RubyLLM
26
30
  class Error < StandardError; end
27
31
 
28
32
  class << self
29
- def chat(model: nil)
30
- Chat.new(model: model)
33
+ def chat(model: nil, provider: nil)
34
+ Chat.new(model: model, provider: provider)
31
35
  end
32
36
 
33
37
  def embed(...)
@@ -68,6 +72,7 @@ RubyLLM::Provider.register :openai, RubyLLM::Providers::OpenAI
68
72
  RubyLLM::Provider.register :anthropic, RubyLLM::Providers::Anthropic
69
73
  RubyLLM::Provider.register :gemini, RubyLLM::Providers::Gemini
70
74
  RubyLLM::Provider.register :deepseek, RubyLLM::Providers::DeepSeek
75
+ RubyLLM::Provider.register :bedrock, RubyLLM::Providers::Bedrock
71
76
 
72
77
  if defined?(Rails::Railtie)
73
78
  require 'ruby_llm/railtie'