llm_gateway 0.5.0 → 0.7.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +38 -0
  3. data/README.md +350 -43
  4. data/docs/migration_guide_0.6.0.md +386 -0
  5. data/docs/migration_guide_0.7.0.md +193 -0
  6. data/lib/llm_gateway/adapters/adapter.rb +8 -11
  7. data/lib/llm_gateway/adapters/anthropic/input_mapper.rb +24 -0
  8. data/lib/llm_gateway/adapters/anthropic/stream_mapper.rb +61 -11
  9. data/lib/llm_gateway/adapters/anthropic_option_mapper.rb +1 -1
  10. data/lib/llm_gateway/adapters/groq/option_mapper.rb +1 -1
  11. data/lib/llm_gateway/adapters/input_message_sanitizer.rb +98 -7
  12. data/lib/llm_gateway/adapters/normalized_stream_accumulator.rb +132 -39
  13. data/lib/llm_gateway/adapters/openai/chat_completions/option_mapper.rb +1 -1
  14. data/lib/llm_gateway/adapters/openai/chat_completions/stream_mapper.rb +40 -16
  15. data/lib/llm_gateway/adapters/openai/responses/input_mapper.rb +47 -31
  16. data/lib/llm_gateway/adapters/openai/responses/option_mapper.rb +1 -1
  17. data/lib/llm_gateway/adapters/openai/responses/stream_mapper.rb +173 -24
  18. data/lib/llm_gateway/adapters/stream_mapper.rb +9 -2
  19. data/lib/llm_gateway/adapters/structs.rb +140 -55
  20. data/lib/llm_gateway/agents/event.rb +105 -0
  21. data/lib/llm_gateway/agents/file_session_manager.rb +100 -0
  22. data/lib/llm_gateway/agents/harness.rb +176 -0
  23. data/lib/llm_gateway/agents/in_memory_session_manager.rb +222 -0
  24. data/lib/llm_gateway/agents/tools/bash_tool.rb +132 -0
  25. data/lib/llm_gateway/agents/tools/edit_tool.rb +215 -0
  26. data/lib/llm_gateway/agents/tools/read_tool.rb +143 -0
  27. data/lib/llm_gateway/agents/tools/tool_utils.rb +164 -0
  28. data/lib/llm_gateway/agents/tools/write_tool.rb +34 -0
  29. data/lib/llm_gateway/base_client.rb +5 -7
  30. data/lib/llm_gateway/clients/anthropic.rb +10 -9
  31. data/lib/llm_gateway/clients/claude_code/oauth_flow.rb +2 -2
  32. data/lib/llm_gateway/clients/groq.rb +8 -6
  33. data/lib/llm_gateway/clients/openai.rb +22 -20
  34. data/lib/llm_gateway/clients/openai_codex/oauth_flow.rb +4 -4
  35. data/lib/llm_gateway/prompt.rb +107 -52
  36. data/lib/llm_gateway/utils.rb +116 -13
  37. data/lib/llm_gateway/version.rb +1 -1
  38. data/lib/llm_gateway.rb +7 -21
  39. metadata +13 -2
@@ -0,0 +1,143 @@
1
+ require "base64"
2
+ require_relative "tool_utils"
3
+
4
# Tool that reads a file from disk for the model.
#
# Supported images (jpeg, png, gif, webp — detected by sniffing the leading
# bytes, never by file extension) are returned as a base64 attachment. Any
# other file is read as UTF-8 text and returned as a line window selected by
# the optional offset/limit, truncated by ToolUtils.truncate_head, with a
# trailing hint telling the model which offset continues the read.
class ReadTool < LlmGateway::Tool
  # Pi adaptation notes:
  # - Keep offset/limit schema as integer: gruv treats integer and number schemas equivalently for line counts.
  # - Do not add pi's image resize/model-omission behavior: current LLMs allow larger images than pi's conservative limit, and gruv tools do not receive model capability context.
  # - Do not add pi's compact read UI, pluggable operations, AbortSignal handling, or details metadata: those are UI/runtime extension concerns outside this tool contract.
  name "read"
  description "Read the contents of a file. Supports text files and images (jpg, png, gif, webp). Images are sent as attachments. For text files, output is truncated to #{ToolUtils::DEFAULT_MAX_LINES} lines or #{ToolUtils::DEFAULT_MAX_BYTES / 1024}KB (whichever is hit first). Use offset/limit for large files. When you need the full file, continue with offset until complete."
  input_schema({
    type: "object",
    properties: {
      path: { type: "string", description: "Path to the file to read (relative or absolute)" },
      offset: { type: "integer", description: "Line number to start reading from (1-indexed)" },
      limit: { type: "integer", description: "Maximum number of lines to read" }
    },
    required: [ "path" ]
  })

  # Number of leading bytes inspected when sniffing the image type.
  IMAGE_TYPE_SNIFF_BYTES = 4100
  # The 8-byte signature every PNG file starts with.
  PNG_SIGNATURE = [ 0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a ].freeze

  # Reads the requested file.
  #
  # @param input [Hash] tool arguments with symbol or string keys:
  #   path (required), offset (1-indexed first line), limit (max line count)
  # @return [String] the text window (plus a continuation hint when more
  #   content remains), or a human-readable error message
  # @return [Array<Hash>] a text part and an image part when the file is a
  #   supported image
  def execute(input)
    path = input[:path] || input["path"]
    offset = input[:offset] || input["offset"]
    limit = input[:limit] || input["limit"]

    absolute_path = ToolUtils.resolve_read_path(path)

    return "File not found: #{path}" unless File.exist?(absolute_path)
    return "Cannot read directory: #{path}" if File.directory?(absolute_path)
    return "File is not readable: #{path}" unless File.readable?(absolute_path)

    mime_type = detect_supported_image_mime_type_from_file(absolute_path)
    if mime_type
      data = Base64.strict_encode64(File.binread(absolute_path))
      return [
        { type: "text", text: "Read image file [#{mime_type}]" },
        { type: "image", data: data, media_type: mime_type }
      ]
    end

    content = File.read(absolute_path, mode: "r:bom|utf-8")
    all_lines = content.split("\n", -1)
    # String#split returns [] for an empty string even with a negative limit.
    # Treat an empty file as one empty line so reading it yields empty
    # content instead of an "Offset is beyond end of file" error.
    all_lines = [ "" ] if all_lines.empty?
    total_file_lines = all_lines.length

    # Offsets are 1-indexed for the caller; clamp anything below 1 to the first line.
    start_line = [ 0, (offset || 1).to_i - 1 ].max
    return "Offset #{offset} is beyond end of file (#{all_lines.length} lines total)" if start_line >= all_lines.length

    selected_content = if limit
      end_line = [ start_line + limit.to_i, all_lines.length ].min
      all_lines[start_line...end_line].join("\n")
    else
      all_lines[start_line..].join("\n")
    end

    truncation = ToolUtils.truncate_head(selected_content)
    start_display = start_line + 1

    if truncation[:first_line_exceeds_limit]
      # Not even the first selected line fits in the byte budget: return no
      # content and point the model at a shell command that can slice the line.
      first_line_size = ToolUtils.format_size(all_lines[start_line].to_s.bytesize)
      return "[Line #{start_display} is #{first_line_size}, exceeds #{ToolUtils.format_size(ToolUtils::DEFAULT_MAX_BYTES)} limit. Use bash: sed -n '#{start_display}p' #{path} | head -c #{ToolUtils::DEFAULT_MAX_BYTES}]"
    end

    output = truncation[:content]

    if truncation[:truncated]
      end_display = start_display + truncation[:output_lines] - 1
      next_offset = end_display + 1
      suffix = if truncation[:truncated_by] == "lines"
        "[Showing lines #{start_display}-#{end_display} of #{total_file_lines}. Use offset=#{next_offset} to continue.]"
      else
        "[Showing lines #{start_display}-#{end_display} of #{total_file_lines} (#{ToolUtils.format_size(ToolUtils::DEFAULT_MAX_BYTES)} limit). Use offset=#{next_offset} to continue.]"
      end
      output = "#{output}\n\n#{suffix}"
    elsif limit && (start_line + limit.to_i) < all_lines.length
      # The caller's own limit stopped the read before the end of the file.
      next_offset = start_line + limit.to_i + 1
      remaining = all_lines.length - (start_line + limit.to_i)
      output = "#{output}\n\n[#{remaining} more lines in file. Use offset=#{next_offset} to continue.]"
    end

    output
  rescue StandardError => e
    "Error reading file: #{e.message}"
  end

  private

  # Sniffs the first IMAGE_TYPE_SNIFF_BYTES bytes of the file at +path+.
  #
  # @return [String, nil] a supported image MIME type, or nil
  def detect_supported_image_mime_type_from_file(path)
    header = File.binread(path, IMAGE_TYPE_SNIFF_BYTES)
    # With an explicit length, binread returns nil (not "") at end of file,
    # which is what an empty file produces. An empty file is never an image.
    return nil if header.nil?

    detect_supported_image_mime_type(header)
  end

  # Maps a binary header to a supported image MIME type.
  # A PNG containing an acTL chunk before its first IDAT (animated PNG) is
  # deliberately not reported as image/png.
  #
  # @param buffer [String] leading bytes of the file
  # @return [String, nil]
  def detect_supported_image_mime_type(buffer)
    bytes = buffer.bytes
    return "image/jpeg" if jpeg?(bytes)
    return "image/png" if png?(bytes) && !animated_png?(bytes)
    return "image/gif" if ascii_at?(bytes, 0, "GIF")
    return "image/webp" if ascii_at?(bytes, 0, "RIFF") && ascii_at?(bytes, 8, "WEBP")

    nil
  end

  # True when the bytes start with FF D8 FF and the fourth byte is not 0xF7.
  def jpeg?(bytes)
    bytes.length >= 4 && bytes[0] == 0xff && bytes[1] == 0xd8 && bytes[2] == 0xff && bytes[3] != 0xf7
  end

  # True when the PNG signature is followed by a 13-byte chunk named IHDR.
  def png?(bytes)
    starts_with?(bytes, PNG_SIGNATURE) && bytes.length >= 16 && read_uint32_be(bytes, PNG_SIGNATURE.length) == 13 && ascii_at?(bytes, 12, "IHDR")
  end

  # Walks the PNG chunk list inside the sniffed bytes.
  # Returns true on an acTL chunk, false on the first IDAT chunk, and false
  # when the chunk list runs past the sniffed bytes.
  def animated_png?(bytes)
    offset = PNG_SIGNATURE.length
    while offset + 8 <= bytes.length
      chunk_length = read_uint32_be(bytes, offset)
      chunk_type_offset = offset + 4
      return true if ascii_at?(bytes, chunk_type_offset, "acTL")
      return false if ascii_at?(bytes, chunk_type_offset, "IDAT")

      # Skip 4 length bytes + 4 type bytes + the chunk data + 4 trailing bytes.
      next_offset = offset + 8 + chunk_length + 4
      return false if next_offset <= offset || next_offset > bytes.length

      offset = next_offset
    end
    false
  end

  # Reads a big-endian unsigned 32-bit integer at +offset+; missing bytes count as 0.
  def read_uint32_be(bytes, offset)
    ((bytes[offset] || 0) << 24) + ((bytes[offset + 1] || 0) << 16) + ((bytes[offset + 2] || 0) << 8) + (bytes[offset + 3] || 0)
  end

  # True when +bytes+ begins with every byte of +prefix+.
  def starts_with?(bytes, prefix)
    return false if bytes.length < prefix.length

    prefix.each_with_index.all? { |byte, index| bytes[index] == byte }
  end

  # True when the bytes of +text+ appear in +bytes+ starting at +offset+.
  def ascii_at?(bytes, offset, text)
    return false if bytes.length < offset + text.length

    text.bytes.each_with_index.all? { |byte, index| bytes[offset + index] == byte }
  end
end
@@ -0,0 +1,164 @@
1
+ require "pathname"
2
+ require "thread"
3
+
4
# Helpers shared by the agent file tools: per-path mutation locks, path
# normalisation/resolution, human-readable sizes, and head/tail truncation of
# text to a line and byte budget.
module ToolUtils
  # Default maximum number of lines kept by the truncation helpers.
  DEFAULT_MAX_LINES = 2000
  # Default maximum number of bytes kept by the truncation helpers (50KB).
  DEFAULT_MAX_BYTES = 50 * 1024

  # One Mutex per path, created on first use. The registry itself is guarded
  # by @file_mutation_locks_mutex because the default block mutates the hash.
  @file_mutation_locks = Hash.new { |hash, key| hash[key] = Mutex.new }
  @file_mutation_locks_mutex = Mutex.new

  module_function

  # Runs the block while holding the lock associated with +path+.
  #
  # @param path [String] key identifying the file being mutated
  # @return [Object] the block's return value
  def with_file_mutation_lock(path)
    lock = @file_mutation_locks_mutex.synchronize { @file_mutation_locks[path] }
    lock.synchronize { yield }
  end

  # Formats a byte count as "NB", "N.NKB" or "N.NMB".
  #
  # @param bytes [Integer]
  # @return [String]
  def format_size(bytes)
    return "#{bytes}B" if bytes < 1024
    return format("%.1fKB", bytes / 1024.0) if bytes < 1024 * 1024

    format("%.1fMB", bytes / (1024.0 * 1024.0))
  end

  # Normalises a user-supplied path: strips one leading "@", replaces the
  # listed Unicode space characters with a plain space, and expands a
  # leading "~" or "~/" to the home directory.
  #
  # @param file_path [#to_s]
  # @return [String]
  def expand_path(file_path)
    # \A (not ^) so only an "@" at the very start of the string is stripped,
    # never one that merely follows an embedded newline.
    normalized = file_path.to_s.sub(/\A@/, "").gsub(/[\u00A0\u2000-\u200A\u202F\u205F\u3000]/, " ")
    return Dir.home if normalized == "~"
    return File.join(Dir.home, normalized[2..]) if normalized.start_with?("~/")

    normalized
  end

  # Expands +file_path+ and resolves it against +cwd+ unless already absolute.
  #
  # @param file_path [#to_s]
  # @param cwd [String] base directory for relative paths
  # @return [String]
  def resolve_to_cwd(file_path, cwd = Dir.pwd)
    expanded = expand_path(file_path)
    Pathname.new(expanded).absolute? ? expanded : File.expand_path(expanded, cwd)
  end

  # Resolves a path for reading. When the resolved path does not exist, a few
  # spelling variants are probed in order: a narrow no-break space before
  # AM/PM, the NFD-normalised form, curly apostrophes, and NFD combined with
  # curly apostrophes. Falls back to the plain resolved path.
  #
  # @param file_path [#to_s]
  # @param cwd [String] base directory for relative paths
  # @return [String]
  def resolve_read_path(file_path, cwd = Dir.pwd)
    resolved = resolve_to_cwd(file_path, cwd)
    return resolved if File.exist?(resolved)

    am_pm_variant = resolved.gsub(/ (AM|PM)\./i) { "\u202F#{Regexp.last_match(1)}." }
    return am_pm_variant if File.exist?(am_pm_variant)

    nfd_variant = resolved.unicode_normalize(:nfd)
    return nfd_variant if File.exist?(nfd_variant)

    curly_variant = resolved.tr("'", "\u2019")
    return curly_variant if File.exist?(curly_variant)

    nfd_curly_variant = nfd_variant.tr("'", "\u2019")
    return nfd_curly_variant if File.exist?(nfd_curly_variant)

    resolved
  end

  # Keeps the beginning of +content+: whole lines only, at most +max_lines+
  # lines and +max_bytes+ bytes. When the first line alone exceeds
  # +max_bytes+, no content is returned and :first_line_exceeds_limit is set.
  #
  # @param content [String]
  # @param max_lines [Integer]
  # @param max_bytes [Integer]
  # @return [Hash] see #truncation_result
  def truncate_head(content, max_lines: DEFAULT_MAX_LINES, max_bytes: DEFAULT_MAX_BYTES)
    lines = split_lines_for_counting(content)
    total_lines = lines.length
    total_bytes = content.bytesize

    if total_lines <= max_lines && total_bytes <= max_bytes
      return truncation_result(content, false, nil, total_lines, total_bytes, total_lines, total_bytes, false, false, max_lines, max_bytes)
    end

    first_line_bytes = lines.first.to_s.bytesize
    if first_line_bytes > max_bytes
      return truncation_result("", true, "bytes", total_lines, total_bytes, 0, 0, false, true, max_lines, max_bytes)
    end

    out_lines = []
    out_bytes = 0
    truncated_by = "lines"

    lines.each_with_index do |line, index|
      break if index >= max_lines

      # Every line after the first also costs one byte for its joining newline.
      line_bytes = line.bytesize + (index.positive? ? 1 : 0)
      if out_bytes + line_bytes > max_bytes
        truncated_by = "bytes"
        break
      end

      out_lines << line
      out_bytes += line_bytes
    end

    output = out_lines.join("\n")
    truncation_result(output, true, truncated_by, total_lines, total_bytes, out_lines.length, output.bytesize, false, false, max_lines, max_bytes)
  end

  # Keeps the end of +content+: at most +max_lines+ lines and +max_bytes+
  # bytes. When even the last line exceeds +max_bytes+, its tail is kept and
  # :last_line_partial is set.
  #
  # @param content [String]
  # @param max_lines [Integer]
  # @param max_bytes [Integer]
  # @return [Hash] see #truncation_result
  def truncate_tail(content, max_lines: DEFAULT_MAX_LINES, max_bytes: DEFAULT_MAX_BYTES)
    lines = split_lines_for_counting(content)
    total_lines = lines.length
    total_bytes = content.bytesize

    if total_lines <= max_lines && total_bytes <= max_bytes
      return truncation_result(content, false, nil, total_lines, total_bytes, total_lines, total_bytes, false, false, max_lines, max_bytes)
    end

    out_lines = []
    out_bytes = 0
    truncated_by = "lines"
    last_line_partial = false

    # Walk backwards from the last line, prepending while the budget allows.
    (lines.length - 1).downto(0) do |i|
      break if out_lines.length >= max_lines

      line = lines[i]
      line_bytes = line.bytesize + (out_lines.empty? ? 0 : 1)

      if out_bytes + line_bytes > max_bytes
        truncated_by = "bytes"
        if out_lines.empty?
          out_lines.unshift(truncate_string_to_bytes_from_end(line, max_bytes))
          out_bytes = out_lines.first.bytesize
          last_line_partial = true
        end
        break
      end

      out_lines.unshift(line)
      out_bytes += line_bytes
    end

    output = out_lines.join("\n")
    truncation_result(output, true, truncated_by, total_lines, total_bytes, out_lines.length, output.bytesize, last_line_partial, false, max_lines, max_bytes)
  end

  # Splits +content+ into lines for counting: a trailing newline does not
  # start an extra empty line, and empty content has zero lines.
  #
  # @param content [String]
  # @return [Array<String>]
  def split_lines_for_counting(content)
    return [] if content.empty?

    lines = content.split("\n", -1)
    lines.pop if content.end_with?("\n")
    lines
  end

  # Packs the truncation outcome into the hash returned by #truncate_head and
  # #truncate_tail.
  #
  # @return [Hash] keys :content, :truncated, :truncated_by, :total_lines,
  #   :total_bytes, :output_lines, :output_bytes, :last_line_partial,
  #   :first_line_exceeds_limit, :max_lines, :max_bytes
  def truncation_result(content, truncated, truncated_by, total_lines, total_bytes, output_lines, output_bytes, last_line_partial, first_line_exceeds_limit, max_lines, max_bytes)
    {
      content: content,
      truncated: truncated,
      truncated_by: truncated_by,
      total_lines: total_lines,
      total_bytes: total_bytes,
      output_lines: output_lines,
      output_bytes: output_bytes,
      last_line_partial: last_line_partial,
      first_line_exceeds_limit: first_line_exceeds_limit,
      max_lines: max_lines,
      max_bytes: max_bytes
    }
  end

  # Returns the last +max_bytes+ bytes of +str+ as a UTF-8 string, without
  # starting in the middle of a multi-byte character. The result may be
  # shorter than +max_bytes+ because leading continuation bytes are dropped.
  #
  # The bytes are sliced with byteslice and tagged UTF-8 explicitly. Building
  # the tail with Array#pack("C*") would produce an ASCII-8BIT string, for
  # which valid_encoding? is always true, so a validity-based trim would
  # never run and the caller would get a binary string starting mid-character.
  #
  # @param str [String] text treated as UTF-8
  # @param max_bytes [Integer]
  # @return [String] +str+ itself when it already fits
  def truncate_string_to_bytes_from_end(str, max_bytes)
    return str if str.bytesize <= max_bytes
    return "" unless max_bytes.positive?

    tail = str.byteslice(-max_bytes, max_bytes).force_encoding(Encoding::UTF_8)

    # A UTF-8 character has at most three continuation bytes (10xxxxxx), so
    # skipping up to three of them lands on the next character boundary.
    skip = 0
    skip += 1 while skip < 3 && skip < tail.bytesize && (tail.getbyte(skip) & 0xC0) == 0x80

    skip.zero? ? tail : tail.byteslice(skip, tail.bytesize - skip)
  end
end
@@ -0,0 +1,34 @@
1
+ require "fileutils"
2
+ require_relative "tool_utils"
3
+
4
# Tool that writes text content to a file, creating parent directories as
# needed and overwriting any existing file. The write runs under the per-path
# lock from ToolUtils.with_file_mutation_lock.
class WriteTool < LlmGateway::Tool
  # Pi adaptation notes:
  # - Keep Ruby bytesize in the success message rather than pi's JS string length; the byte count is more accurate for UTF-8 content.
  # - Do not add pi's pluggable operations, AbortSignal handling, render previews, or details metadata: those are UI/runtime extension concerns outside this tool contract.
  name "write"
  description "Write content to a file. Creates the file if it doesn't exist, overwrites if it does. Automatically creates parent directories."
  input_schema({
    type: "object",
    properties: {
      path: { type: "string", description: "Path to the file to write (relative or absolute)" },
      content: { type: "string", description: "Content to write to the file" }
    },
    required: [ "path", "content" ]
  })

  # Writes the content to the path.
  #
  # @param input [Hash] tool arguments with symbol or string keys:
  #   path (required) and content (required)
  # @return [String] a success message with the number of bytes written, or
  #   an error message
  def execute(input)
    path = input[:path] || input["path"]
    content = input[:content] || input["content"]

    # Validate before touching the filesystem. File.write(path, nil) would
    # truncate an existing file to zero bytes, and the bytesize call below
    # would then fail, reporting an error only after the data was destroyed.
    return "Missing required parameter: path" if path.nil? || path.to_s.empty?
    return "Missing required parameter: content" if content.nil?

    # Coerce once so the bytes written and the bytes reported always agree,
    # even when a non-String value slips through.
    content = content.to_s

    absolute_path = ToolUtils.resolve_to_cwd(path)

    ToolUtils.with_file_mutation_lock(absolute_path) do
      FileUtils.mkdir_p(File.dirname(absolute_path))
      File.write(absolute_path, content)
    end

    "Successfully wrote #{content.bytesize} bytes to #{path}"
  rescue StandardError => e
    e.message
  end
end
@@ -6,11 +6,9 @@ require "json"
6
6
 
7
7
  module LlmGateway
8
8
  class BaseClient
9
- attr_accessor
10
- attr_reader :api_key, :model_key, :base_endpoint
9
+ attr_reader :api_key, :base_endpoint
11
10
 
12
- def initialize(model_key:, api_key:)
13
- @model_key = model_key
11
+ def initialize(api_key:)
14
12
  @api_key = api_key
15
13
  end
16
14
 
@@ -43,7 +41,7 @@ module LlmGateway
43
41
  request.set_form(form_data, "multipart/form-data")
44
42
 
45
43
  # Headers (excluding Content-Type because set_form already sets it)
46
- multipart_headers = build_headers.reject { |k, _| k.downcase == "content-type" }
44
+ multipart_headers = build_headers.except("content-type", "Content-Type")
47
45
  multipart_headers.each { |key, value| request[key] = value }
48
46
 
49
47
  response = Net::HTTP.start(uri.hostname, uri.port, use_ssl: uri.scheme == "https") do |http|
@@ -114,7 +112,7 @@ module LlmGateway
114
112
  next if data_str == "[DONE]"
115
113
 
116
114
  data = begin
117
- LlmGateway::Utils.deep_symbolize_keys(JSON.parse(data_str))
115
+ JSON.parse(data_str).deep_symbolize_keys
118
116
  rescue JSON::ParserError
119
117
  { raw: data_str }
120
118
  end
@@ -143,7 +141,7 @@ module LlmGateway
143
141
  when 200
144
142
  content_type = response["content-type"]
145
143
  if content_type&.include?("application/json")
146
- LlmGateway::Utils.deep_symbolize_keys(JSON.parse(response.body))
144
+ JSON.parse(response.body).deep_symbolize_keys
147
145
  else
148
146
  response.body
149
147
  end
@@ -9,10 +9,11 @@ module LlmGateway
9
9
  module Clients
10
10
  class Anthropic < BaseClient
11
11
  CLAUDE_CODE_VERSION = "2.1.2"
12
+ DEFAULT_MODEL = "claude-3-7-sonnet-20250219"
12
13
 
13
- def initialize(model_key: "claude-3-7-sonnet-20250219", api_key: ENV["ANTHROPIC_API_KEY"])
14
+ def initialize(api_key: ENV["ANTHROPIC_API_KEY"])
14
15
  @base_endpoint = "https://api.anthropic.com/v1"
15
- super(model_key: model_key, api_key: api_key)
16
+ super(api_key: api_key)
16
17
  end
17
18
 
18
19
  def chat(messages, **kwargs)
@@ -44,28 +45,28 @@ module LlmGateway
44
45
 
45
46
  private
46
47
 
47
- def build_body(messages, tools: nil, system: [], cache_retention: nil, **options)
48
+ def build_body(messages, tools: nil, system: [], cache_retention: nil, model: DEFAULT_MODEL, **options)
48
49
  cache_control = anthropic_cache_control_for(cache_retention)
49
50
 
50
51
  body = {
51
- model: model_key,
52
+ model: model,
52
53
  messages: messages
53
54
  }
54
55
 
55
56
  tools = apply_tools_cache_control(tools, cache_retention)
56
- body.merge!(tools: tools) if LlmGateway::Utils.present?(tools)
57
+ body.merge!(tools: tools) if tools.present?
57
58
 
58
59
  system = prepend_claude_code_identity(system) if claude_code_oauth_api_key?
59
60
  system = apply_system_cache_control(system, cache_retention)
60
61
 
61
- body.merge!(system: system) if LlmGateway::Utils.present?(system)
62
+ body.merge!(system: system) if system.present?
62
63
  body.merge!(cache_control: cache_control) unless cache_control.nil?
63
64
  body.merge!(options)
64
65
  body
65
66
  end
66
67
 
67
68
  def apply_system_cache_control(system, cache_retention)
68
- return system if system.nil? || system.empty? || !system.is_a?(Array)
69
+ return system if system.blank? || !system.is_a?(Array)
69
70
 
70
71
  cache_control = anthropic_cache_control_for(cache_retention)
71
72
  return system if cache_control.nil?
@@ -83,7 +84,7 @@ module LlmGateway
83
84
  end
84
85
 
85
86
  def apply_tools_cache_control(tools, cache_retention)
86
- return tools if tools.nil? || tools.empty? || !tools.is_a?(Array)
87
+ return tools if tools.blank? || !tools.is_a?(Array)
87
88
 
88
89
  cache_control = anthropic_cache_control_for(cache_retention)
89
90
  return tools if cache_control.nil?
@@ -148,7 +149,7 @@ module LlmGateway
148
149
  text: "You are Claude Code, Anthropic's official CLI for Claude."
149
150
  }
150
151
 
151
- if system.nil? || system.empty?
152
+ if system.blank?
152
153
  [ identity ]
153
154
  else
154
155
  [ identity ] + system
@@ -105,7 +105,7 @@ module LlmGateway
105
105
  code = uri.query && URI.decode_www_form(uri.query).to_h["code"]
106
106
  state = uri.query && URI.decode_www_form(uri.query).to_h["state"]
107
107
 
108
- raise ArgumentError, "Callback URL is missing code parameter" if code.nil? || code.empty?
108
+ raise ArgumentError, "Callback URL is missing code parameter" if code.blank?
109
109
 
110
110
  { code: code, state: state }
111
111
  rescue URI::InvalidURIError => e
@@ -116,7 +116,7 @@ module LlmGateway
116
116
 
117
117
  def extract_code_and_state(auth_code_or_callback, state)
118
118
  value = auth_code_or_callback.to_s.strip
119
- raise ArgumentError, "Authorization code is required" if value.empty?
119
+ raise ArgumentError, "Authorization code is required" if value.blank?
120
120
 
121
121
  if looks_like_url?(value)
122
122
  callback = parse_callback(value)
@@ -5,14 +5,16 @@ require_relative "../base_client"
5
5
  module LlmGateway
6
6
  module Clients
7
7
  class Groq < BaseClient
8
- def initialize(model_key: "openai/gpt-oss-120b", api_key: ENV["GROQ_API_KEY"])
8
+ DEFAULT_MODEL = "openai/gpt-oss-120b"
9
+
10
+ def initialize(api_key: ENV["GROQ_API_KEY"])
9
11
  @base_endpoint = "https://api.groq.com/openai/v1"
10
- super(model_key: model_key, api_key: api_key)
12
+ super(api_key: api_key)
11
13
  end
12
14
 
13
- def chat(messages, tools: nil, system: [], **options)
15
+ def chat(messages, tools: nil, system: [], model: DEFAULT_MODEL, **options)
14
16
  body = {
15
- model: model_key,
17
+ model: model,
16
18
  messages: system + messages,
17
19
  tools: tools
18
20
  }
@@ -21,9 +23,9 @@ module LlmGateway
21
23
  post("chat/completions", body)
22
24
  end
23
25
 
24
- def stream(messages, tools: nil, system: [], **options, &block)
26
+ def stream(messages, tools: nil, system: [], model: DEFAULT_MODEL, **options, &block)
25
27
  body = {
26
- model: model_key,
28
+ model: model,
27
29
  messages: system + messages,
28
30
  tools: tools,
29
31
  stream_options: { include_usage: true }
@@ -6,18 +6,20 @@ module LlmGateway
6
6
  module Clients
7
7
  class OpenAI < BaseClient
8
8
  CODEX_BASE_ENDPOINT = "https://chatgpt.com/backend-api/codex"
9
+ DEFAULT_MODEL = "gpt-4o"
10
+ DEFAULT_EMBEDDINGS_MODEL = "text-embedding-3-small"
9
11
 
10
12
  attr_reader :account_id
11
13
 
12
- def initialize(model_key: "gpt-4o", api_key: ENV["OPENAI_API_KEY"], account_id: nil)
14
+ def initialize(api_key: ENV["OPENAI_API_KEY"], account_id: nil)
13
15
  @base_endpoint = "https://api.openai.com/v1"
14
16
  @account_id = account_id
15
- super(model_key: model_key, api_key: api_key)
17
+ super(api_key: api_key)
16
18
  end
17
19
 
18
- def chat(messages, tools: nil, system: [], **options)
20
+ def chat(messages, tools: nil, system: [], model: DEFAULT_MODEL, **options)
19
21
  body = {
20
- model: model_key,
22
+ model: model,
21
23
  messages: system + messages
22
24
  }
23
25
  body[:tools] = tools if tools
@@ -26,9 +28,9 @@ module LlmGateway
26
28
  post("chat/completions", body)
27
29
  end
28
30
 
29
- def stream(messages, tools: nil, system: [], **options, &block)
31
+ def stream(messages, tools: nil, system: [], model: DEFAULT_MODEL, **options, &block)
30
32
  body = {
31
- model: model_key,
33
+ model: model,
32
34
  messages: system + messages
33
35
  }
34
36
  body[:tools] = tools if tools
@@ -38,9 +40,9 @@ module LlmGateway
38
40
  post_stream("chat/completions", body, &block)
39
41
  end
40
42
 
41
- def responses(messages, tools: nil, system: [], **options)
43
+ def responses(messages, tools: nil, system: [], model: DEFAULT_MODEL, **options)
42
44
  body = {
43
- model: model_key,
45
+ model: model,
44
46
  input: messages.flatten
45
47
  }
46
48
  body[:instructions] = system[0][:content] if system.any?
@@ -50,9 +52,9 @@ module LlmGateway
50
52
  post("responses", body)
51
53
  end
52
54
 
53
- def stream_responses(messages, tools: nil, system: [], **options, &block)
55
+ def stream_responses(messages, tools: nil, system: [], model: DEFAULT_MODEL, **options, &block)
54
56
  body = {
55
- model: model_key,
57
+ model: model,
56
58
  input: messages.flatten
57
59
  }
58
60
  body[:instructions] = system[0][:content] if system.any?
@@ -74,8 +76,8 @@ module LlmGateway
74
76
  token_manager.access_token
75
77
  end
76
78
 
77
- def chat_codex(messages, tools: nil, system: [], account_id: nil, **options)
78
- body = build_codex_body(messages, system, tools, **options)
79
+ def chat_codex(messages, tools: nil, system: [], account_id: nil, model: DEFAULT_MODEL, **options)
80
+ body = build_codex_body(messages, system, tools, model: model, **options)
79
81
 
80
82
  completed_response = nil
81
83
  post_codex_stream("responses", body, account_id: account_id) do |raw_sse|
@@ -87,8 +89,8 @@ module LlmGateway
87
89
  completed_response
88
90
  end
89
91
 
90
- def stream_codex(messages, tools: nil, system: [], account_id: nil, **options, &block)
91
- body = build_codex_body(messages, system, tools, **options)
92
+ def stream_codex(messages, tools: nil, system: [], account_id: nil, model: DEFAULT_MODEL, **options, &block)
93
+ body = build_codex_body(messages, system, tools, model: model, **options)
92
94
  post_codex_stream("responses", body, account_id: account_id, &block)
93
95
  end
94
96
 
@@ -96,10 +98,10 @@ module LlmGateway
96
98
  get("files/#{file_id}/content")
97
99
  end
98
100
 
99
- def generate_embeddings(input)
101
+ def generate_embeddings(input, model: DEFAULT_EMBEDDINGS_MODEL)
100
102
  body = {
101
103
  input:,
102
- model: model_key
104
+ model: model
103
105
  }
104
106
  post("embeddings", body)
105
107
  end
@@ -110,12 +112,12 @@ module LlmGateway
110
112
 
111
113
  private
112
114
 
113
- def build_codex_body(messages, system, tools, **options)
115
+ def build_codex_body(messages, system, tools, model:, **options)
114
116
  instructions = Array(system).filter_map { |s| s.is_a?(Hash) ? s[:content] : s }.join("\n")
115
- instructions = "You are a helpful assistant." if instructions.empty?
117
+ instructions = instructions.presence || "You are a helpful assistant."
116
118
 
117
119
  body = {
118
- model: model_key,
120
+ model: model,
119
121
  instructions: instructions,
120
122
  input: messages,
121
123
  store: false,
@@ -194,7 +196,7 @@ module LlmGateway
194
196
  end
195
197
  # If we get here, we didn't handle it specifically
196
198
  fallback_body = response.body.to_s.strip
197
- fallback_message = if fallback_body.empty?
199
+ fallback_message = if fallback_body.blank?
198
200
  "OpenAI request failed with status #{response.code}"
199
201
  else
200
202
  "OpenAI request failed with status #{response.code}: #{fallback_body}"
@@ -96,7 +96,7 @@ module LlmGateway
96
96
  uri = URI.parse(callback_url)
97
97
  params = URI.decode_www_form(uri.query.to_s).to_h
98
98
  code = params["code"]
99
- raise ArgumentError, "Callback URL is missing code parameter" if code.nil? || code.empty?
99
+ raise ArgumentError, "Callback URL is missing code parameter" if code.blank?
100
100
 
101
101
  { code: code, state: params["state"] }
102
102
  rescue URI::InvalidURIError => e
@@ -120,7 +120,7 @@ module LlmGateway
120
120
  input = tty.gets&.strip
121
121
  tty.close
122
122
 
123
- raise "No authorization code provided" if input.nil? || input.empty?
123
+ raise "No authorization code provided" if input.blank?
124
124
 
125
125
  exchange_code(input, flow[:code_verifier], expected_state: flow[:state])
126
126
  end
@@ -183,7 +183,7 @@ module LlmGateway
183
183
  auth = payload[JWT_CLAIM_PATH]
184
184
  account_id = auth&.dig("chatgpt_account_id")
185
185
 
186
- account_id.is_a?(String) && !account_id.empty? ? account_id : nil
186
+ account_id.is_a?(String) ? account_id.presence : nil
187
187
  rescue StandardError
188
188
  nil
189
189
  end
@@ -214,7 +214,7 @@ module LlmGateway
214
214
  end
215
215
 
216
216
  def parse_authorization_input(input, expected_state = nil)
217
- return nil if input.nil? || input.empty?
217
+ return nil if input.blank?
218
218
 
219
219
  value = input.to_s.strip
220
220