llm_gateway 0.5.0 → 0.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -22,8 +22,8 @@ module LlmGateway
22
22
  #
23
23
  # Accepted event shapes:
24
24
  #
25
- # { type: :message_start, delta: { id: "...", model: "...", role: "assistant" }, usage_increment: { ... } }
26
- # { type: :message_delta, delta: { stop_reason: "stop" }, usage_increment: { ... } }
25
+ # { type: :message_start, delta: { id: "...", model: "...", role: "assistant", timestamp: 1716650000000 } }
26
+ # { type: :message_delta, delta: { stop_reason: "stop" }, usage: { output: 2 } }
27
27
  # { type: :message_end }
28
28
  #
29
29
  # { type: :text_start, delta: "hi" }
@@ -50,7 +50,16 @@ module LlmGateway
50
50
  # The accumulator creates the public Assistant* event structs, updates its
51
51
  # accumulated message state, then yields the created event to the callback.
52
52
  attr_accessor :blocks, :message_hash, :usage_hash
53
- attr_reader :active_block_type
53
+ attr_reader :active_block_type, :final_message
54
+
55
+ DEFAULT_USAGE = {
56
+ input: 0,
57
+ cache_write: 0,
58
+ cache_read: 0,
59
+ output: 0,
60
+ total: 0,
61
+ raw: {}
62
+ }.freeze
54
63
 
55
64
  BLOCK_EVENT_TRANSITIONS = {
56
65
  text_start: { block_type: :text, phase: :start },
@@ -64,28 +73,32 @@ module LlmGateway
64
73
  reasoning_end: { block_type: :reasoning, phase: :end }
65
74
  }.freeze
66
75
 
67
- def initialize
76
+ def initialize(provider: nil, api: nil)
77
+ @provider = provider
78
+ @api = api
68
79
  @message_hash = {}
69
- @usage_hash = {
70
- input_tokens: 0,
71
- cache_creation_input_tokens: 0,
72
- cache_read_input_tokens: 0,
73
- output_tokens: 0,
74
- reasoning_tokens: 0
75
- }
80
+ @usage_hash = default_usage
76
81
  @blocks = []
77
82
  @next_content_index = 0
78
83
  @active_block_type = nil
79
84
  @active_content_index = nil
85
+ @timestamp = nil
80
86
  end
81
87
 
82
88
  def result
89
+ ensure_timestamp!
90
+
83
91
  message_hash.merge(
92
+ timestamp: @timestamp,
84
93
  usage: usage_hash,
85
94
  content: serialized_blocks
86
95
  )
87
96
  end
88
97
 
98
+ def final_result
99
+ result.merge(provider: @provider, api: @api)
100
+ end
101
+
89
102
  def active_tool?
90
103
  active_block_type == :tool
91
104
  end
@@ -96,11 +109,19 @@ module LlmGateway
96
109
  event_patch = symbolize_keys(event_patch)
97
110
  type = event_patch.fetch(:type).to_sym
98
111
  event_patch = prepare_event_patch(event_patch.merge(type:), type)
112
+ ensure_timestamp!
99
113
 
100
- event = build_event(event_patch)
114
+ if type == :message_end
115
+ @final_message = AssistantMessage.new(final_result)
116
+ block.call(AssistantStreamMessageEndEvent.new(type:, message: final_message)) if block
117
+ return nil
118
+ end
119
+
120
+ event = build_event(event_patch, partial: empty_partial)
101
121
  accumulate(event)
102
122
  content_index = event.content_index if event.respond_to?(:content_index)
103
123
  commit_block_transition(type, content_index)
124
+ event = build_event(event_patch, partial: partial_message)
104
125
  block.call(event) if block
105
126
 
106
127
  nil
@@ -166,16 +187,21 @@ module LlmGateway
166
187
  end
167
188
  end
168
189
 
169
- def build_event(event_patch)
190
+ def build_event(event_patch, partial:)
170
191
  event_patch = symbolize_keys(event_patch)
171
192
  type = event_patch.fetch(:type).to_sym
172
193
 
173
194
  case type
174
- when :message_start, :message_delta, :message_end
195
+ when :message_start, :message_delta
196
+ delta = symbolize_keys(event_patch[:delta] || {})
197
+ raw_usage = event_patch[:usage] || delta.delete(:usage) || {}
198
+ usage = raw_usage.empty? ? {} : normalized_usage(raw_usage)
199
+
175
200
  AssistantStreamMessageEvent.new(
176
201
  type:,
177
- delta: symbolize_keys(event_patch[:delta] || {}),
178
- usage_increment: symbolize_keys(event_patch[:usage_increment] || {})
202
+ delta:,
203
+ usage:,
204
+ partial:
179
205
  )
180
206
  when :tool_start
181
207
  AssistantToolStartEvent.new(
@@ -183,20 +209,23 @@ module LlmGateway
183
209
  content_index: event_patch.fetch(:content_index),
184
210
  delta: string_value(event_patch[:delta]),
185
211
  id: event_patch[:id],
186
- name: event_patch[:name]
212
+ name: event_patch[:name],
213
+ partial:
187
214
  )
188
215
  when :reasoning_start, :reasoning_delta, :reasoning_end
189
216
  AssistantStreamReasoningEvent.new(
190
217
  type:,
191
218
  content_index: event_patch.fetch(:content_index),
192
219
  delta: string_value(event_patch[:delta]),
193
- signature: string_value(event_patch[:signature])
220
+ signature: string_value(event_patch[:signature]),
221
+ partial:
194
222
  )
195
223
  when :text_start, :text_delta, :text_end, :tool_delta, :tool_end
196
224
  AssistantStreamEvent.new(
197
225
  type:,
198
226
  content_index: event_patch.fetch(:content_index),
199
- delta: string_value(event_patch[:delta])
227
+ delta: string_value(event_patch[:delta]),
228
+ partial:
200
229
  )
201
230
  else
202
231
  raise ArgumentError, "Unsupported normalized stream event type: #{type.inspect}"
@@ -204,6 +233,8 @@ module LlmGateway
204
233
  end
205
234
 
206
235
  def accumulate(event)
236
+ @timestamp = event.delta[:timestamp] if event.respond_to?(:delta) && event.delta.is_a?(Hash) && event.delta[:timestamp]
237
+
207
238
  case event.type
208
239
  when :text_start
209
240
  blocks[event.content_index] = {
@@ -224,9 +255,6 @@ module LlmGateway
224
255
  blocks[event.content_index][:input] += event.delta
225
256
  when :message_start
226
257
  message_hash.merge!(event.delta)
227
- usage_hash.each_key do |key|
228
- usage_hash[key] += event.usage_increment.fetch(key, 0)
229
- end
230
258
  when :reasoning_start
231
259
  blocks[event.content_index] = {
232
260
  type: "reasoning",
@@ -240,13 +268,42 @@ module LlmGateway
240
268
  blocks[event.content_index][:signature] += event.signature
241
269
  when :message_delta
242
270
  message_hash.merge!(event.delta)
243
- usage_hash.each_key do |key|
244
- usage_hash[key] += event.usage_increment.fetch(key, 0)
245
- end
246
- when :message_end
271
+ assign_usage(event.usage) unless event.usage.empty?
247
272
  end
248
273
  end
249
274
 
275
+ def empty_partial
276
+ PartialAssistantMessage.new(timestamp: @timestamp)
277
+ end
278
+
279
+ def partial_message
280
+ PartialAssistantMessage.new(partial_result)
281
+ end
282
+
283
+ def partial_result
284
+ ensure_timestamp!
285
+
286
+ message_hash.merge(
287
+ timestamp: @timestamp,
288
+ content: serialized_blocks
289
+ )
290
+ end
291
+
292
+ def assign_usage(usage)
293
+ @usage_hash = normalized_usage(usage)
294
+ end
295
+
296
+ def normalized_usage(usage)
297
+ usage = default_usage.merge(symbolize_keys(usage).slice(*DEFAULT_USAGE.keys))
298
+ usage[:total] = usage[:input] + usage[:cache_write] + usage[:cache_read] + usage[:output]
299
+ usage[:raw] ||= {}
300
+ usage
301
+ end
302
+
303
+ def default_usage
304
+ DEFAULT_USAGE.merge(raw: {})
305
+ end
306
+
250
307
  def serialized_blocks
251
308
  blocks.map do |content_block|
252
309
  next content_block unless content_block[:type] == "tool_use"
@@ -270,6 +327,10 @@ module LlmGateway
270
327
  def string_value(value)
271
328
  value.nil? ? "" : value.to_s
272
329
  end
330
+
331
+ def ensure_timestamp!
332
+ @timestamp ||= (Time.now.to_f * 1000).to_i
333
+ end
273
334
  end
274
335
  end
275
336
  end
@@ -92,9 +92,9 @@ module LlmGateway
92
92
  delta: {
93
93
  id: data[:id],
94
94
  model: data[:model],
95
- role: delta[:role] || "assistant"
96
- }.compact,
97
- usage_increment: {}
95
+ role: delta[:role] || "assistant",
96
+ timestamp: timestamp_milliseconds(data[:created])
97
+ }.compact
98
98
  }
99
99
  ]
100
100
  end
@@ -198,34 +198,58 @@ module LlmGateway
198
198
  *close_active_block_patches(active_block_type:),
199
199
  {
200
200
  type: :message_delta,
201
- delta: { stop_reason: normalize_stop_reason(finish_reason) },
202
- usage_increment: {}
201
+ delta: { stop_reason: normalize_stop_reason(finish_reason) }
203
202
  }
204
203
  ]
205
204
  end
206
205
 
207
206
  def final_usage_patches(data)
207
+ patch = {
208
+ type: :message_delta,
209
+ delta: {}
210
+ }
211
+ patch[:usage] = usage(data) if data.key?(:usage)
212
+
208
213
  [
209
- {
210
- type: accumulator.message_hash.empty? ? :message_start : :message_delta,
211
- delta: {},
212
- usage_increment: usage_increment(data)
213
- }
214
+ patch,
215
+ { type: :message_end }
214
216
  ]
215
217
  end
216
218
 
217
- def usage_increment(data)
219
+ def usage(data)
218
220
  usage = data[:usage] || {}
221
+ cache_read = token_count(
222
+ usage.dig(:prompt_tokens_details, :cached_tokens),
223
+ usage[:prompt_cache_hit_tokens]
224
+ )
225
+ cache_write = token_count(
226
+ usage.dig(:prompt_tokens_details, :cache_write_tokens),
227
+ usage[:cache_write_tokens]
228
+ )
229
+ prompt_tokens = token_count(usage[:prompt_tokens])
230
+ input = [ prompt_tokens - cache_read - cache_write, 0 ].max
231
+ output = token_count(usage[:completion_tokens])
219
232
 
220
233
  {
221
- input_tokens: usage[:prompt_tokens] || 0,
222
- cache_creation_input_tokens: 0,
223
- cache_read_input_tokens: usage.dig(:prompt_tokens_details, :cached_tokens) || 0,
224
- output_tokens: usage[:completion_tokens] || 0,
225
- reasoning_tokens: usage.dig(:completion_tokens_details, :reasoning_tokens) || 0
234
+ input:,
235
+ cache_write:,
236
+ cache_read:,
237
+ output:,
238
+ total: input + cache_write + cache_read + output,
239
+ raw: usage
226
240
  }
227
241
  end
228
242
 
243
+ def token_count(*values)
244
+ values.compact.first.to_i
245
+ end
246
+
247
+ def timestamp_milliseconds(unix_seconds)
248
+ return nil if unix_seconds.nil?
249
+
250
+ (unix_seconds.to_f * 1000).to_i
251
+ end
252
+
229
253
  def normalize_stop_reason(finish_reason)
230
254
  case finish_reason
231
255
  when "tool_calls"
@@ -55,9 +55,9 @@ module LlmGateway
55
55
  delta: {
56
56
  id: response[:id],
57
57
  model: response[:model],
58
- role: "assistant"
59
- }.compact,
60
- usage_increment: {}
58
+ role: "assistant",
59
+ timestamp: timestamp_milliseconds(response[:created_at])
60
+ }.compact
61
61
  }
62
62
  ]
63
63
  end
@@ -72,8 +72,7 @@ module LlmGateway
72
72
  [
73
73
  {
74
74
  type: :message_start,
75
- delta: { role: item[:role] || "assistant" },
76
- usage_increment: {}
75
+ delta: { role: item[:role] || "assistant" }
77
76
  }
78
77
  ]
79
78
  when "function_call"
@@ -106,33 +105,55 @@ module LlmGateway
106
105
 
107
106
  def response_completed_patches(response)
108
107
  response ||= {}
108
+ patch = {
109
+ type: :message_delta,
110
+ delta: {
111
+ id: response[:id],
112
+ model: response[:model],
113
+ role: "assistant",
114
+ timestamp: timestamp_milliseconds(response[:created_at]),
115
+ stop_reason: stop_reason_for(response)
116
+ }.compact
117
+ }
118
+ patch[:usage] = usage(response) if response.key?(:usage)
109
119
 
110
120
  [
111
- {
112
- type: accumulator.message_hash.empty? ? :message_start : :message_delta,
113
- delta: {
114
- id: response[:id],
115
- model: response[:model],
116
- role: "assistant",
117
- stop_reason: stop_reason_for(response)
118
- }.compact,
119
- usage_increment: usage_increment(response)
120
- }
121
+ patch,
122
+ { type: :message_end }
121
123
  ]
122
124
  end
123
125
 
124
- def usage_increment(response)
126
+ def usage(response)
125
127
  usage = response[:usage] || {}
128
+ cache_read = token_count(usage.dig(:input_tokens_details, :cached_tokens))
129
+ cache_write = token_count(
130
+ usage.dig(:input_tokens_details, :cache_write_tokens),
131
+ usage[:cache_write_tokens]
132
+ )
133
+ input_tokens = token_count(usage[:input_tokens])
134
+ input = [ input_tokens - cache_read - cache_write, 0 ].max
135
+ output = token_count(usage[:output_tokens])
126
136
 
127
137
  {
128
- input_tokens: usage[:input_tokens] || 0,
129
- cache_creation_input_tokens: 0,
130
- cache_read_input_tokens: usage.dig(:input_tokens_details, :cached_tokens) || 0,
131
- output_tokens: usage[:output_tokens] || 0,
132
- reasoning_tokens: usage.dig(:output_tokens_details, :reasoning_tokens) || 0
138
+ input:,
139
+ cache_write:,
140
+ cache_read:,
141
+ output:,
142
+ total: input + cache_write + cache_read + output,
143
+ raw: usage
133
144
  }
134
145
  end
135
146
 
147
+ def token_count(*values)
148
+ values.compact.first.to_i
149
+ end
150
+
151
+ def timestamp_milliseconds(unix_seconds)
152
+ return nil if unix_seconds.nil?
153
+
154
+ (unix_seconds.to_f * 1000).to_i
155
+ end
156
+
136
157
  def stop_reason_for(response)
137
158
  output = response[:output] || []
138
159
  last_item = output.last || {}
@@ -5,14 +5,21 @@ require_relative "normalized_stream_accumulator"
5
5
  module LlmGateway
6
6
  module Adapters
7
7
  class StreamMapper
8
+ def initialize(provider:, api:)
9
+ @provider = provider
10
+ @api = api
11
+ end
12
+
8
13
  def result
9
- accumulator.result
14
+ accumulator.final_message
10
15
  end
11
16
 
12
17
  private
13
18
 
19
+ attr_reader :provider, :api
20
+
14
21
  def accumulator
15
- @accumulator ||= LlmGateway::Adapters::NormalizedStreamAccumulator.new
22
+ @accumulator ||= LlmGateway::Adapters::NormalizedStreamAccumulator.new(provider:, api:)
16
23
  end
17
24
 
18
25
  def push_patches(patches, &block)
@@ -9,35 +9,6 @@ class BaseStruct < Dry::Struct
9
9
  transform_keys(&:to_sym)
10
10
  end
11
11
 
12
- class AssistantStreamEvent < BaseStruct
13
- EventType = Types::Coercible::Symbol.enum(:text_start, :text_delta, :text_end, :tool_start, :tool_delta, :tool_end, :reasoning_start, :reasoning_delta, :reasoning_end)
14
-
15
- attribute :type, EventType
16
- attribute :delta, Types::Coercible::String.default { "" }
17
- attribute :content_index, Types::Integer
18
- end
19
-
20
-
21
- class AssistantToolStartEvent < AssistantStreamEvent
22
- attribute :id, Types::String
23
- attribute :name, Types::String
24
- attribute :content_index, Types::Integer
25
- end
26
-
27
-
28
- class AssistantStreamReasoningEvent < AssistantStreamEvent
29
- attribute :signature, Types::Coercible::String.default { "" }
30
- attribute :content_index, Types::Integer
31
- end
32
-
33
- class AssistantStreamMessageEvent < BaseStruct
34
- EventType = Types::Coercible::Symbol.enum(:message_start, :message_delta, :message_end)
35
-
36
- attribute :type, EventType
37
- attribute :delta, Types::Coercible::Hash.default { {} }
38
- attribute :usage_increment, Types::Coercible::Hash.default { {} }
39
- end
40
-
41
12
  class TextContent < BaseStruct
42
13
  attribute :type, Types::String.enum("text")
43
14
  attribute :text, Types::String
@@ -87,12 +58,101 @@ class ToolResult < BaseStruct
87
58
  attribute :content, Types::String
88
59
  end
89
60
 
90
- class AssistantMessage < BaseStruct
61
+ class PartialAssistantMessage < BaseStruct
91
62
  ContentBlock =
92
63
  Types.Instance(TextContent) |
93
64
  Types.Instance(ReasoningContent) |
94
65
  Types.Instance(ToolCall)
95
66
 
67
+ attribute? :id, Types::String.optional
68
+ attribute? :model, Types::String.optional
69
+ attribute? :role, Types::String.enum("assistant").optional
70
+ attribute :timestamp, Types::Integer
71
+ attribute? :stop_reason, Types::String.enum("stop", "length", "tool_use", "toolUse", "error", "aborted").optional
72
+ attribute? :content, Types::Array.of(ContentBlock).optional
73
+
74
+ def self.new(attributes = {})
75
+ attrs = attributes.to_h.transform_keys(&:to_sym)
76
+ attrs[:content] = Array(attrs[:content]).map { |block| build_content_block(block) } if attrs.key?(:content)
77
+ super(attrs)
78
+ end
79
+
80
+ def self.build_content_block(block)
81
+ return block if block.is_a?(TextContent) || block.is_a?(ReasoningContent) || block.is_a?(ToolCall)
82
+
83
+ case block[:type] || block["type"]
84
+ when "text"
85
+ TextContent.new(block)
86
+ when "reasoning"
87
+ ReasoningContent.new(block)
88
+ when "thinking"
89
+ ReasoningContent.new(
90
+ type: "reasoning",
91
+ reasoning: block[:thinking] || block["thinking"] || block[:reasoning] || block["reasoning"],
92
+ signature: block[:signature] || block["signature"]
93
+ )
94
+ when "tool_use"
95
+ ToolCall.new(block)
96
+ else
97
+ raise ArgumentError, "Unsupported content block type: #{block[:type] || block['type']}"
98
+ end
99
+ end
100
+
101
+ private_class_method :build_content_block
102
+ end
103
+
104
+ class AssistantStreamEvent < BaseStruct
105
+ EventType = Types::Coercible::Symbol.enum(:text_start, :text_delta, :text_end, :tool_start, :tool_delta, :tool_end, :reasoning_start, :reasoning_delta, :reasoning_end)
106
+
107
+ attribute :type, EventType
108
+ attribute :delta, Types::Coercible::String.default { "" }
109
+ attribute :content_index, Types::Integer
110
+ attribute :partial, Types.Instance(PartialAssistantMessage)
111
+
112
+ def content
113
+ case type
114
+ when :text_end
115
+ finalized_content_block&.text
116
+ when :reasoning_end
117
+ finalized_content_block&.reasoning
118
+ when :tool_end
119
+ finalized_content_block
120
+ end
121
+ end
122
+
123
+ def text
124
+ content if type == :text_end
125
+ end
126
+
127
+ def reasoning
128
+ content if type == :reasoning_end
129
+ end
130
+
131
+ def tool_call
132
+ finalized_content_block if type == :tool_end
133
+ end
134
+
135
+ alias tool tool_call
136
+
137
+ private
138
+
139
+ def finalized_content_block
140
+ partial.content&.[](content_index)
141
+ end
142
+ end
143
+
144
+ class AssistantToolStartEvent < AssistantStreamEvent
145
+ attribute :id, Types::String
146
+ attribute :name, Types::String
147
+ attribute :content_index, Types::Integer
148
+ end
149
+
150
+ class AssistantStreamReasoningEvent < AssistantStreamEvent
151
+ attribute :signature, Types::Coercible::String.default { "" }
152
+ attribute :content_index, Types::Integer
153
+ end
154
+
155
+ class AssistantMessage < PartialAssistantMessage
96
156
  attribute :id, Types::String
97
157
  attribute :model, Types::String
98
158
  attribute :usage, Types::Hash
@@ -103,12 +163,6 @@ class AssistantMessage < BaseStruct
103
163
  attribute? :error_message, Types::String.optional
104
164
  attribute :content, Types::Array.of(ContentBlock)
105
165
 
106
- def self.new(attributes)
107
- attrs = attributes.to_h.transform_keys(&:to_sym)
108
- attrs[:content] = Array(attrs[:content]).map { |block| build_content_block(block) }
109
- super(attrs)
110
- end
111
-
112
166
  def to_h
113
167
  result = {
114
168
  id: id,
@@ -120,26 +174,22 @@ class AssistantMessage < BaseStruct
120
174
  api: api,
121
175
  content: content.map(&:to_h)
122
176
  }
177
+ result[:timestamp] = timestamp unless timestamp.nil?
123
178
  result[:error_message] = error_message unless error_message.nil?
124
179
  result
125
180
  end
181
+ end
126
182
 
127
- def self.build_content_block(block)
128
- return block if block.is_a?(TextContent) || block.is_a?(ReasoningContent) || block.is_a?(ToolCall)
183
+ class AssistantStreamMessageEvent < BaseStruct
184
+ EventType = Types::Coercible::Symbol.enum(:message_start, :message_delta)
129
185
 
130
- case block[:type] || block["type"]
131
- when "text"
132
- TextContent.new(block)
133
- when "reasoning"
134
- ReasoningContent.new(block)
135
- when "thinking"
136
- ReasoningContent.new(type: "reasoning", reasoning: block[:thinking] || block["thinking"] || block[:reasoning] || block["reasoning"], signature: block[:signature] || block["signature"])
137
- when "tool_use"
138
- ToolCall.new(block)
139
- else
140
- raise ArgumentError, "Unsupported content block type: #{block[:type] || block['type']}"
141
- end
142
- end
186
+ attribute :type, EventType
187
+ attribute :delta, Types::Coercible::Hash.default { {} }
188
+ attribute :usage, Types::Coercible::Hash.default { {} }
189
+ attribute :partial, Types.Instance(PartialAssistantMessage)
190
+ end
143
191
 
144
- private_class_method :build_content_block
192
+ class AssistantStreamMessageEndEvent < BaseStruct
193
+ attribute :type, Types::Coercible::Symbol.enum(:message_end)
194
+ attribute :message, Types.Instance(AssistantMessage)
145
195
  end
@@ -6,11 +6,9 @@ require "json"
6
6
 
7
7
  module LlmGateway
8
8
  class BaseClient
9
- attr_accessor
10
- attr_reader :api_key, :model_key, :base_endpoint
9
+ attr_reader :api_key, :base_endpoint
11
10
 
12
- def initialize(model_key:, api_key:)
13
- @model_key = model_key
11
+ def initialize(api_key:)
14
12
  @api_key = api_key
15
13
  end
16
14
 
@@ -9,10 +9,11 @@ module LlmGateway
9
9
  module Clients
10
10
  class Anthropic < BaseClient
11
11
  CLAUDE_CODE_VERSION = "2.1.2"
12
+ DEFAULT_MODEL = "claude-3-7-sonnet-20250219"
12
13
 
13
- def initialize(model_key: "claude-3-7-sonnet-20250219", api_key: ENV["ANTHROPIC_API_KEY"])
14
+ def initialize(api_key: ENV["ANTHROPIC_API_KEY"])
14
15
  @base_endpoint = "https://api.anthropic.com/v1"
15
- super(model_key: model_key, api_key: api_key)
16
+ super(api_key: api_key)
16
17
  end
17
18
 
18
19
  def chat(messages, **kwargs)
@@ -44,11 +45,11 @@ module LlmGateway
44
45
 
45
46
  private
46
47
 
47
- def build_body(messages, tools: nil, system: [], cache_retention: nil, **options)
48
+ def build_body(messages, tools: nil, system: [], cache_retention: nil, model: DEFAULT_MODEL, **options)
48
49
  cache_control = anthropic_cache_control_for(cache_retention)
49
50
 
50
51
  body = {
51
- model: model_key,
52
+ model: model,
52
53
  messages: messages
53
54
  }
54
55
 
@@ -5,14 +5,16 @@ require_relative "../base_client"
5
5
  module LlmGateway
6
6
  module Clients
7
7
  class Groq < BaseClient
8
- def initialize(model_key: "openai/gpt-oss-120b", api_key: ENV["GROQ_API_KEY"])
8
+ DEFAULT_MODEL = "openai/gpt-oss-120b"
9
+
10
+ def initialize(api_key: ENV["GROQ_API_KEY"])
9
11
  @base_endpoint = "https://api.groq.com/openai/v1"
10
- super(model_key: model_key, api_key: api_key)
12
+ super(api_key: api_key)
11
13
  end
12
14
 
13
- def chat(messages, tools: nil, system: [], **options)
15
+ def chat(messages, tools: nil, system: [], model: DEFAULT_MODEL, **options)
14
16
  body = {
15
- model: model_key,
17
+ model: model,
16
18
  messages: system + messages,
17
19
  tools: tools
18
20
  }
@@ -21,9 +23,9 @@ module LlmGateway
21
23
  post("chat/completions", body)
22
24
  end
23
25
 
24
- def stream(messages, tools: nil, system: [], **options, &block)
26
+ def stream(messages, tools: nil, system: [], model: DEFAULT_MODEL, **options, &block)
25
27
  body = {
26
- model: model_key,
28
+ model: model,
27
29
  messages: system + messages,
28
30
  tools: tools,
29
31
  stream_options: { include_usage: true }