ruby-pi 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,314 @@
1
+ # frozen_string_literal: true
2
+
3
+ # lib/ruby_pi/llm/anthropic.rb
4
+ #
5
+ # LLM provider for Anthropic Claude. Implements the BaseProvider interface using
6
+ # the Anthropic Messages API for both synchronous and streaming completions,
7
+ # including tool_use block support.
8
+
9
module RubyPi
  module LLM
    # Anthropic Claude provider implementation. Communicates with the Anthropic
    # Messages API to generate text completions, handle tool_use blocks, and
    # stream responses via Server-Sent Events.
    #
    # @example Basic usage
    #   provider = RubyPi::LLM::Anthropic.new(
    #     model: "claude-sonnet-4-20250514",
    #     api_key: ENV["ANTHROPIC_API_KEY"]
    #   )
    #   response = provider.complete(messages: [{ role: "user", content: "Hello!" }])
    #   puts response.content
    class Anthropic < BaseProvider
      # Base URL for the Anthropic Messages API.
      BASE_URL = "https://api.anthropic.com"

      # Anthropic API version header value.
      API_VERSION = "2023-06-01"

      # Default maximum tokens for a response.
      DEFAULT_MAX_TOKENS = 4096

      # Creates a new Anthropic provider instance.
      #
      # @param model [String, nil] the Claude model identifier
      #   (e.g., "claude-sonnet-4-20250514"); falls back to global config
      # @param api_key [String, nil] Anthropic API key (falls back to global config)
      # @param max_tokens [Integer] maximum tokens to generate (default: 4096)
      # @param options [Hash] additional options passed to BaseProvider
      def initialize(model: nil, api_key: nil, max_tokens: DEFAULT_MAX_TOKENS, **options)
        super(**options)
        config = RubyPi.configuration
        @model = model || config.default_anthropic_model
        @api_key = api_key || config.anthropic_api_key
        @max_tokens = max_tokens
      end

      # Returns the Claude model identifier.
      #
      # @return [String]
      def model_name
        @model
      end

      # Returns :anthropic as the provider identifier.
      #
      # @return [Symbol]
      def provider_name
        :anthropic
      end

      private

      # Performs the completion request against the Anthropic API.
      #
      # @param messages [Array<Hash>] conversation messages
      # @param tools [Array<Hash>] tool definitions
      # @param stream [Boolean] whether to use streaming
      # @yield [event] streaming events if stream is true
      # @return [RubyPi::LLM::Response]
      def perform_complete(messages:, tools:, stream:, &block)
        body = build_request_body(messages, tools, stream)

        if stream && block_given?
          perform_streaming_request(body, &block)
        else
          perform_standard_request(body)
        end
      end

      # Builds the Anthropic API request body from messages and tools.
      # Anthropic takes the system prompt as a top-level `system` field rather
      # than as a conversation message, so it is extracted here.
      #
      # NOTE(review): `content.to_s` flattens any structured content (e.g.
      # content-block arrays) into a string — assumes callers only pass plain
      # text content; confirm against the message pipeline.
      #
      # @param messages [Array<Hash>] conversation messages
      # @param tools [Array<Hash>] tool definitions
      # @param stream [Boolean] whether streaming is enabled
      # @return [Hash] the request body
      def build_request_body(messages, tools, stream)
        # Separate system message from conversation messages
        system_message = nil
        conversation = []

        messages.each do |msg|
          role = (msg[:role] || msg["role"]).to_s
          content = msg[:content] || msg["content"]

          if role == "system"
            system_message = content.to_s
          else
            conversation << { role: role, content: content.to_s }
          end
        end

        body = {
          model: @model,
          max_tokens: @max_tokens,
          messages: conversation
        }

        body[:system] = system_message if system_message
        body[:stream] = true if stream

        unless tools.empty?
          body[:tools] = tools.map { |t| format_tool(t) }
        end

        body
      end

      # Converts a tool definition to Anthropic's tool format.
      # Accepts either an object responding to #to_anthropic_format
      # (e.g. RubyPi::Tools::Definition) or a plain Hash.
      #
      # @param tool [RubyPi::Tools::Definition, Hash] tool definition
      # @return [Hash] Anthropic tool definition
      def format_tool(tool)
        return tool.to_anthropic_format if tool.respond_to?(:to_anthropic_format)

        {
          name: tool[:name] || tool["name"],
          description: tool[:description] || tool["description"] || "",
          input_schema: tool[:parameters] || tool["parameters"] || { type: "object", properties: {} }
        }
      end

      # Executes a standard (non-streaming) request to the Anthropic API.
      #
      # @param body [Hash] the request body
      # @return [RubyPi::LLM::Response]
      # @raise [RubyPi::ApiError] and friends via handle_error_response
      def perform_standard_request(body)
        conn = build_connection(
          base_url: BASE_URL,
          headers: default_headers
        )

        response = conn.post("/v1/messages") do |req|
          req.headers["Content-Type"] = "application/json"
          req.body = JSON.generate(body)
        end

        handle_error_response(response) unless response.success?
        parse_response(JSON.parse(response.body))
      end

      # Executes a streaming request to the Anthropic API, yielding events.
      #
      # NOTE(review): the SSE body is parsed only after the HTTP call returns,
      # so events are yielded in order but not truly incrementally — confirm
      # whether the Faraday adapter buffers the full response.
      #
      # @param body [Hash] the request body
      # @yield [event] StreamEvent objects
      # @return [RubyPi::LLM::Response] final aggregated response
      def perform_streaming_request(body, &block)
        conn = build_connection(
          base_url: BASE_URL,
          headers: default_headers
        )

        accumulated_text = +""
        accumulated_tool_calls = []
        current_tool_call = nil
        current_tool_json = +""
        usage_data = {}
        finish_reason = nil

        response = conn.post("/v1/messages") do |req|
          req.headers["Content-Type"] = "application/json"
          req.body = JSON.generate(body)
        end

        handle_error_response(response) unless response.success?

        # Parse SSE events from the response body
        parse_sse_events(response.body) do |data|
          event_type = data["type"]

          case event_type
          when "content_block_start"
            content_block = data["content_block"] || {}
            if content_block["type"] == "tool_use"
              current_tool_call = {
                id: content_block["id"],
                name: content_block["name"]
              }
              current_tool_json = +""
            end

          when "content_block_delta"
            delta = data["delta"] || {}
            if delta["type"] == "text_delta"
              text = delta["text"] || ""
              accumulated_text << text
              block.call(StreamEvent.new(type: :text_delta, data: text))
            elsif delta["type"] == "input_json_delta"
              json_chunk = delta["partial_json"] || ""
              current_tool_json << json_chunk
              block.call(StreamEvent.new(type: :tool_call_delta, data: {
                id: current_tool_call&.dig(:id),
                partial_json: json_chunk
              }))
            end

          when "content_block_stop"
            if current_tool_call
              # FIX: a malformed accumulated payload used to raise
              # JSON::ParserError here and abort the whole stream; degrade to
              # empty arguments instead, matching the non-streaming path's
              # tolerance for a missing "input".
              arguments = parse_tool_arguments(current_tool_json)
              accumulated_tool_calls << ToolCall.new(
                id: current_tool_call[:id],
                name: current_tool_call[:name],
                arguments: arguments
              )
              current_tool_call = nil
              current_tool_json = +""
            end

          when "message_delta"
            delta = data["delta"] || {}
            finish_reason = delta["stop_reason"]
            if data.key?("usage")
              usage_info = data["usage"]
              usage_data[:completion_tokens] = usage_info["output_tokens"]
            end

          when "message_start"
            if data.dig("message", "usage")
              usage_info = data["message"]["usage"]
              usage_data[:prompt_tokens] = usage_info["input_tokens"]
            end
          end
        end

        # Signal completion
        block.call(StreamEvent.new(type: :done))

        # Calculate total tokens
        if usage_data[:prompt_tokens] && usage_data[:completion_tokens]
          usage_data[:total_tokens] = usage_data[:prompt_tokens] + usage_data[:completion_tokens]
        end

        Response.new(
          content: accumulated_text.empty? ? nil : accumulated_text,
          tool_calls: accumulated_tool_calls,
          usage: usage_data,
          finish_reason: normalize_finish_reason(finish_reason)
        )
      end

      # Parses accumulated tool-call argument JSON, returning an empty hash
      # for blank or malformed payloads rather than raising mid-stream.
      #
      # @param json [String] accumulated input_json_delta chunks
      # @return [Hash] parsed arguments, or {} when unparseable
      def parse_tool_arguments(json)
        return {} if json.empty?

        JSON.parse(json)
      rescue JSON::ParserError
        {}
      end

      # Returns the default HTTP headers required by the Anthropic API.
      #
      # @return [Hash] headers hash
      def default_headers
        {
          "x-api-key" => @api_key.to_s,
          "anthropic-version" => API_VERSION
        }
      end

      # Parses an Anthropic API response hash into a normalized Response object.
      #
      # @param data [Hash] parsed JSON response from Anthropic
      # @return [RubyPi::LLM::Response]
      def parse_response(data)
        content = nil
        tool_calls = []

        # Anthropic returns an array of typed content blocks; concatenate text
        # blocks and collect tool_use blocks.
        (data["content"] || []).each do |content_block|
          case content_block["type"]
          when "text"
            content = (content || +"") << content_block["text"]
          when "tool_use"
            tool_calls << ToolCall.new(
              id: content_block["id"],
              name: content_block["name"],
              arguments: content_block["input"] || {}
            )
          end
        end

        # Extract usage
        usage = {}
        if data.key?("usage")
          usage_info = data["usage"]
          usage = {
            prompt_tokens: usage_info["input_tokens"],
            completion_tokens: usage_info["output_tokens"],
            total_tokens: (usage_info["input_tokens"] || 0) + (usage_info["output_tokens"] || 0)
          }
        end

        Response.new(
          content: content,
          tool_calls: tool_calls,
          usage: usage,
          finish_reason: normalize_finish_reason(data["stop_reason"])
        )
      end

      # Normalizes Anthropic-specific finish reasons to common values.
      #
      # @param reason [String, nil] Anthropic stop reason
      # @return [String, nil] normalized finish reason
      def normalize_finish_reason(reason)
        case reason
        when "end_turn" then "stop"
        when "tool_use" then "tool_calls"
        when "max_tokens" then "max_tokens"
        else reason
        end
      end
    end
  end
end
@@ -0,0 +1,220 @@
1
+ # frozen_string_literal: true
2
+
3
+ # lib/ruby_pi/llm/base_provider.rb
4
+ #
5
+ # Abstract base class for all LLM providers. Implements shared concerns such as
6
+ # retry logic with exponential backoff and a consistent public interface. Concrete
7
+ # providers (Gemini, Anthropic, OpenAI) must subclass this and implement the
8
+ # abstract methods.
9
+
10
module RubyPi
  module LLM
    # Abstract base class that defines the contract every LLM provider must
    # fulfill. Provides built-in retry logic with exponential backoff for
    # transient errors and a unified #complete interface for both synchronous
    # and streaming completions.
    #
    # Subclasses MUST implement:
    # - #perform_complete(messages:, tools:, stream:, &block)
    # - #model_name
    # - #provider_name
    #
    # @example Subclass implementation
    #   class MyProvider < RubyPi::LLM::BaseProvider
    #     def model_name = "my-model"
    #     def provider_name = :my_provider
    #
    #     private
    #     def perform_complete(messages:, tools:, stream:, &block)
    #       # Make HTTP request and return RubyPi::LLM::Response
    #     end
    #   end
    class BaseProvider
      # @return [Integer] maximum number of retry attempts
      attr_reader :max_retries

      # @return [Float] base delay in seconds for exponential backoff
      attr_reader :retry_base_delay

      # @return [Float] maximum delay in seconds between retries
      attr_reader :retry_max_delay

      # Initializes the base provider with retry configuration.
      #
      # @param max_retries [Integer, nil] override max retries (defaults to global config)
      # @param retry_base_delay [Float, nil] override base delay (defaults to global config)
      # @param retry_max_delay [Float, nil] override max delay (defaults to global config)
      def initialize(max_retries: nil, retry_base_delay: nil, retry_max_delay: nil)
        config = RubyPi.configuration
        @max_retries = max_retries || config.max_retries
        @retry_base_delay = retry_base_delay || config.retry_base_delay
        @retry_max_delay = retry_max_delay || config.retry_max_delay
      end

      # Sends a completion request to the LLM provider with automatic retry
      # logic for transient errors. When stream is true and a block is given,
      # yields StreamEvent objects incrementally as they arrive.
      #
      # @param messages [Array<Hash>] conversation messages, each with :role and :content
      # @param tools [Array<Hash>] tool/function definitions for the model
      # @param stream [Boolean] whether to enable streaming mode
      # @yield [event] yields StreamEvent objects when streaming
      # @yieldparam event [RubyPi::LLM::StreamEvent] a stream event
      # @return [RubyPi::LLM::Response] the normalized response
      # @raise [RubyPi::AuthenticationError] on 401/403 responses
      # @raise [RubyPi::RateLimitError] on 429 responses (after retries exhausted)
      # @raise [RubyPi::ApiError] on other HTTP errors (after retries exhausted)
      # @raise [RubyPi::TimeoutError] on request timeouts
      def complete(messages:, tools: [], stream: false, &block)
        attempt = 0

        begin
          attempt += 1
          perform_complete(messages: messages, tools: tools, stream: stream, &block)
        rescue RubyPi::AuthenticationError
          # Authentication errors are not retryable — raise immediately
          raise
        rescue RubyPi::RateLimitError, RubyPi::ApiError, RubyPi::TimeoutError => e
          raise unless attempt < @max_retries

          delay = retry_delay_for(e, attempt)
          log_retry(attempt, delay, e)
          sleep(delay)
          retry
        end
      end

      # Returns the model name used by this provider instance.
      # Subclasses MUST override this method.
      #
      # @return [String] the model identifier
      # @raise [RubyPi::NotImplementedError] if not overridden
      def model_name
        raise RubyPi::NotImplementedError, :model_name
      end

      # Returns the provider identifier.
      # Subclasses MUST override this method.
      #
      # @return [Symbol] the provider identifier (e.g., :gemini, :anthropic, :openai)
      # @raise [RubyPi::NotImplementedError] if not overridden
      def provider_name
        raise RubyPi::NotImplementedError, :provider_name
      end

      private

      # Performs the actual completion request. Subclasses MUST implement this
      # method with provider-specific HTTP logic.
      #
      # @param messages [Array<Hash>] conversation messages
      # @param tools [Array<Hash>] tool definitions
      # @param stream [Boolean] streaming mode flag
      # @yield [event] optional block for streaming events
      # @return [RubyPi::LLM::Response]
      def perform_complete(messages:, tools:, stream:, &block)
        raise RubyPi::NotImplementedError, :perform_complete
      end

      # Determines the delay before the next retry. Honors a server-provided
      # Retry-After hint when the error carries one (RateLimitError populated
      # by #handle_error_response); otherwise falls back to exponential
      # backoff. Always capped at retry_max_delay.
      #
      # FIX: previously the server's retry-after header was parsed into the
      # error but never consulted, so rate-limited requests were retried too
      # early.
      #
      # @param error [Exception] the error that triggered the retry
      # @param attempt [Integer] the current attempt number (1-based)
      # @return [Float] delay in seconds
      def retry_delay_for(error, attempt)
        if error.respond_to?(:retry_after) && error.retry_after
          [error.retry_after, @retry_max_delay].min
        else
          calculate_backoff(attempt)
        end
      end

      # Calculates the backoff delay for a given retry attempt using
      # exponential backoff with jitter.
      #
      # @param attempt [Integer] the current attempt number (1-based)
      # @return [Float] delay in seconds
      def calculate_backoff(attempt)
        base = @retry_base_delay * (2**(attempt - 1))
        jitter = rand * @retry_base_delay * 0.5
        [base + jitter, @retry_max_delay].min
      end

      # Logs a retry attempt if a logger is configured.
      #
      # @param attempt [Integer] current attempt number
      # @param delay [Float] delay before next retry
      # @param error [Exception] the error that triggered the retry
      # @return [void]
      def log_retry(attempt, delay, error)
        logger = RubyPi.configuration.logger
        return unless logger

        logger.warn(
          "[RubyPi::#{provider_name}] Retry #{attempt}/#{@max_retries} " \
          "after #{delay.round(2)}s — #{error.class}: #{error.message}"
        )
      end

      # Builds a Faraday connection with standard timeout settings.
      #
      # @param base_url [String] the base URL for the API
      # @param headers [Hash] default headers for all requests
      # @return [Faraday::Connection]
      def build_connection(base_url:, headers: {})
        config = RubyPi.configuration

        Faraday.new(url: base_url) do |conn|
          conn.headers.update(headers)
          conn.options.timeout = config.request_timeout
          conn.options.open_timeout = config.open_timeout
          conn.adapter :net_http
        end
      end

      # Handles HTTP error responses by raising the appropriate RubyPi error.
      #
      # @param response [Faraday::Response] the HTTP response
      # @raise [RubyPi::AuthenticationError] on 401 or 403
      # @raise [RubyPi::RateLimitError] on 429
      # @raise [RubyPi::ApiError] on other error status codes
      def handle_error_response(response)
        case response.status
        when 401, 403
          raise RubyPi::AuthenticationError.new(
            "#{provider_name} authentication failed (HTTP #{response.status})",
            response_body: response.body
          )
        when 429
          retry_after = response.headers["retry-after"]&.to_f
          raise RubyPi::RateLimitError.new(
            "#{provider_name} rate limit exceeded (HTTP 429)",
            retry_after: retry_after,
            response_body: response.body
          )
        else
          raise RubyPi::ApiError.new(
            "#{provider_name} API error (HTTP #{response.status})",
            status_code: response.status,
            response_body: response.body
          )
        end
      end

      # Processes a streaming response body line by line, parsing SSE events.
      # Yields parsed data hashes to the provided block. Malformed JSON data
      # lines and the OpenAI-style "[DONE]" sentinel are skipped.
      #
      # @param response_body [String] the raw SSE response body
      # @yield [data] parsed SSE event data
      # @yieldparam data [Hash] a parsed JSON event payload
      # @return [void]
      def parse_sse_events(response_body, &block)
        response_body.each_line do |line|
          line = line.strip
          next if line.empty?
          next unless line.start_with?("data: ")

          data_str = line.delete_prefix("data: ")
          next if data_str == "[DONE]"

          begin
            data = JSON.parse(data_str)
            block.call(data)
          rescue JSON::ParserError
            # Skip malformed SSE data lines
            next
          end
        end
      end
    end
  end
end
@@ -0,0 +1,96 @@
1
+ # frozen_string_literal: true
2
+
3
+ # lib/ruby_pi/llm/fallback.rb
4
+ #
5
+ # Provides automatic failover between LLM providers. Wraps a primary provider
6
+ # with one or more fallback providers. If the primary fails with a retryable
7
+ # error, the Fallback wrapper automatically routes the request to the next
8
+ # available provider.
9
+
10
module RubyPi
  module LLM
    # A resilient provider wrapper that tries a primary provider first and
    # automatically falls back to an alternative provider on failure. Both
    # providers must conform to the BaseProvider interface.
    #
    # Authentication errors are NOT retried with the fallback since they
    # indicate a configuration problem rather than a transient failure.
    #
    # @example Setting up a fallback chain
    #   primary = RubyPi::LLM.model(:gemini, "gemini-2.0-flash")
    #   backup = RubyPi::LLM.model(:openai, "gpt-4o")
    #   provider = RubyPi::LLM::Fallback.new(primary: primary, fallback: backup)
    #
    #   # If Gemini fails, automatically retries with OpenAI
    #   response = provider.complete(messages: messages)
    class Fallback < BaseProvider
      # @return [RubyPi::LLM::BaseProvider] the primary provider
      attr_reader :primary

      # @return [RubyPi::LLM::BaseProvider] the fallback provider
      attr_reader :fallback

      # Creates a new Fallback wrapper with a primary and fallback provider.
      #
      # @param primary [RubyPi::LLM::BaseProvider] the preferred provider
      # @param fallback [RubyPi::LLM::BaseProvider] the backup provider
      # @param options [Hash] additional options passed to BaseProvider
      def initialize(primary:, fallback:, **options)
        super(**options)
        @primary = primary
        @fallback = fallback
      end

      # Returns the model name of the primary provider.
      #
      # @return [String]
      def model_name
        primary.model_name
      end

      # Returns :fallback as the provider identifier.
      #
      # @return [Symbol]
      def provider_name
        :fallback
      end

      private

      # Attempts the completion with the primary provider. Any RubyPi::Error
      # other than an authentication failure triggers one retry against the
      # fallback provider; authentication errors propagate immediately since
      # they indicate misconfiguration.
      #
      # @param messages [Array<Hash>] conversation messages
      # @param tools [Array<Hash>] tool definitions
      # @param stream [Boolean] streaming mode flag
      # @yield [event] optional block for streaming events
      # @return [RubyPi::LLM::Response]
      def perform_complete(messages:, tools:, stream:, &block)
        primary.complete(messages: messages, tools: tools, stream: stream, &block)
      rescue RubyPi::Error => e
        # Configuration errors should not trigger fallback
        raise if e.is_a?(RubyPi::AuthenticationError)

        log_fallback(e)
        fallback.complete(messages: messages, tools: tools, stream: stream, &block)
      end

      # Logs the fallback event if a logger is configured.
      #
      # @param error [Exception] the error that triggered the fallback
      # @return [void]
      def log_fallback(error)
        logger = RubyPi.configuration.logger
        return unless logger

        source = "#{primary.provider_name}/#{primary.model_name}"
        target = "#{fallback.provider_name}/#{fallback.model_name}"
        logger.warn(
          "[RubyPi::Fallback] Primary provider (#{source}) " \
          "failed with #{error.class}: #{error.message}. " \
          "Falling back to #{target}."
        )
      end
    end
  end
end