openai 0.25.1 → 0.27.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,683 @@
1
+ # frozen_string_literal: true
2
+
3
+ module OpenAI
4
+ module Helpers
5
+ module Streaming
6
+ class ChatCompletionStream
7
+ include OpenAI::Internal::Type::BaseStream
8
+
9
# Wraps a raw chunk stream and prepares the state object that accumulates it.
#
# @param raw_stream [#each] source of raw events; consumed lazily through `iterator`
# @param response_format [Object, nil] passed through to ChatCompletionStreamState
# @param input_tools [Object, nil] passed through to ChatCompletionStreamState
def initialize(raw_stream:, response_format: nil, input_tools: nil)
  @raw_stream = raw_stream
  @state = ChatCompletionStreamState.new(input_tools: input_tools, response_format: response_format)
  # `iterator` memoizes into @iterator on its own; the assignment makes the eager setup explicit.
  @iterator = iterator
end
17
+
18
# Consumes the rest of the stream, then asks the state object for the final completion.
#
# @return [Object] whatever `@state.get_final_completion` returns
def get_final_completion
  until_done # drain every remaining event so the state has seen all chunks
  state = @state
  state.get_final_completion
end
22
+
23
# Concatenates the message content of every choice in the final completion.
# Choices whose message has no content are skipped.
#
# @return [String] joined content; empty when no choice carries content
def get_output_text
  get_final_completion
    .choices
    .map { |choice| choice.message.content }
    .select(&:itself)
    .join
end
34
+
35
# Iterates the stream to exhaustion, discarding every event.
#
# @return [self]
def until_done
  each { |_event| nil }
  self
end
39
+
40
# Delegates to the state object's current snapshot.
#
# @return [Object, nil] value of `@state.current_completion_snapshot`
def current_completion_snapshot
  state = @state
  state.current_completion_snapshot
end
43
+
44
# Builds an enumerable (via `Util.chain_fused`) that yields only the `delta` of
# each ChatContentDeltaEvent produced by the event iterator.
#
# @return [Object] whatever `OpenAI::Internal::Util.chain_fused` returns
def text
  OpenAI::Internal::Util.chain_fused(@iterator) do |sink|
    @iterator.each do |event|
      next unless event.is_a?(ChatContentDeltaEvent)

      sink << event.delta
    end
  end
end
51
+
52
+ private
53
+
54
# Memoized event iterator: every raw event that passes
# `valid_chat_completion_chunk?` is handed to the state object, and each event
# the state returns is yielded downstream.
#
# @return [Object] whatever `OpenAI::Internal::Util.chain_fused` returns (cached in @iterator)
def iterator
  return @iterator if @iterator

  @iterator = OpenAI::Internal::Util.chain_fused(@raw_stream) do |sink|
    @raw_stream.each do |raw_event|
      if valid_chat_completion_chunk?(raw_event)
        @state.handle_chunk(raw_event).each { |event| sink << event }
      end
    end
  end
end
64
+
65
# Tells real chat completion chunks apart from other events in the raw stream.
#
# The raw stream is supposed to contain only objects following the
# ChatCompletionChunk schema, but Azure OpenAI breaks this when its
# Asynchronous Filter is enabled. The "object" property is an easy discriminator:
# - it is "chat.completion.chunk" for a ChatCompletionChunk;
# - it is an empty string for Asynchronous Filter events.
#
# @param sse_event [#object] raw event from the stream
# @return [Boolean]
def valid_chat_completion_chunk?(sse_event)
  expected_object = :"chat.completion.chunk"
  sse_event.object == expected_object
end
73
+ end
74
+
75
+ class ChatCompletionStreamState
76
+ attr_reader :current_completion_snapshot
77
+
78
# Sets up empty accumulation state.
#
# @param response_format [Object, nil] kept as-is in @response_format; additionally
#   kept in @rich_response_format only when it is a Class
# @param input_tools [Object, nil] normalized with Kernel#Array
def initialize(response_format: nil, input_tools: nil)
  @response_format = response_format
  @rich_response_format = (response_format if response_format.is_a?(Class))
  @input_tools = Array(input_tools)
  @choice_event_states = []
  @current_completion_snapshot = nil
end
85
+
86
# Parses the accumulated snapshot into the final completion, using the
# class-based response format (if any) captured at construction.
#
# @return [Object] result of `parse_chat_completion`
def get_final_completion
  snapshot = current_completion_snapshot
  parse_chat_completion(chat_completion: snapshot, response_format: @rich_response_format)
end
92
+
93
# Turns one raw streaming chunk into the list of higher-level events it implies,
# while folding the chunk into the running completion snapshot.
#
# Steps:
# 1. Unwrap the chunk when it arrives wrapped in a ChatChunkEvent.
# 2. Ignore anything that is not an OpenAI::Chat::ChatCompletionChunk.
# 3. Accumulate the chunk into @current_completion_snapshot.
# 4. Build the events for this chunk against the updated snapshot.
#
# @param chunk [Object] a chunk, or a ChatChunkEvent wrapping one
# @return [Array] events for this chunk; empty when the input is not a chunk
def handle_chunk(chunk)
  raw_chunk = chunk.is_a?(ChatChunkEvent) ? chunk.chunk : chunk
  return [] unless raw_chunk.is_a?(OpenAI::Chat::ChatCompletionChunk)

  @current_completion_snapshot = accumulate_chunk(raw_chunk)
  build_events(chunk: raw_chunk, completion_snapshot: @current_completion_snapshot)
end
110
+
111
+ private
112
+
113
# Returns the per-choice event state for `choice.index`, creating it on first use.
#
# @param choice [#index]
# @return [ChoiceEventState]
def get_choice_state(choice)
  slot = choice.index
  existing = @choice_event_states[slot]
  return existing if existing

  @choice_event_states[slot] = ChoiceEventState.new(input_tools: @input_tools)
end
117
+
118
# Folds a chunk into the running snapshot and returns the snapshot.
# The very first chunk is converted into a fresh snapshot instead.
#
# @param chunk [#choices, #usage, #system_fingerprint]
# @return [Object] the (possibly newly created) completion snapshot
def accumulate_chunk(chunk)
  snapshot = @current_completion_snapshot
  return convert_initial_chunk_into_snapshot(chunk) if snapshot.nil?

  chunk.choices.each { |choice| accumulate_choice!(choice, snapshot) }

  # Only overwrite these fields when the chunk actually carries a value.
  usage = chunk.usage
  snapshot.usage = usage if usage
  fingerprint = chunk.system_fingerprint
  snapshot.system_fingerprint = fingerprint if fingerprint

  snapshot
end
134
+
135
# Merges one chunk choice into the snapshot's choice at the same index,
# creating that snapshot choice when it does not exist yet. Afterwards it
# records the finish reason (which may raise via `handle_finish_reason`),
# re-parses tool call arguments and accumulates logprobs.
#
# @param choice [Object] chunk choice (reads index, delta, finish_reason, logprobs)
# @param completion_snapshot [#choices] snapshot being mutated
def accumulate_choice!(choice, completion_snapshot)
  position = choice.index
  choice_snapshot = completion_snapshot.choices[position]

  if choice_snapshot.nil?
    choice_snapshot = create_new_choice_snapshot(choice)
    completion_snapshot.choices[position] = choice_snapshot
  else
    update_existing_choice_snapshot(choice, choice_snapshot)
  end

  finish_reason = choice.finish_reason
  if finish_reason
    choice_snapshot.finish_reason = finish_reason
    handle_finish_reason(finish_reason)
  end

  parse_tool_calls!(choice.delta.tool_calls, choice_snapshot.message.tool_calls)
  accumulate_logprobs!(choice.logprobs, choice_snapshot)
end
154
+
155
# Builds a snapshot choice from a chunk choice: the choice's hash minus its
# :delta key, with the delta's hash stored under :message, coerced to ParsedChoice.
#
# @param choice [#to_h, #delta]
# @return [Object] result of coercing to OpenAI::Models::Chat::ParsedChoice
def create_new_choice_snapshot(choice)
  attributes = choice.to_h.reject { |key, _value| key == :delta }
  attributes[:message] = choice.delta.to_h
  OpenAI::Internal::Type::Converter.coerce(OpenAI::Models::Chat::ParsedChoice, attributes)
end
161
+
162
# Merges a chunk choice's delta into the snapshot choice's message and replaces
# the message with the merged data coerced to ChatCompletionMessage.
#
# @param choice [#delta] chunk choice supplying the delta
# @param choice_snapshot [#message, #message=] snapshot choice being updated
def update_existing_choice_snapshot(choice, choice_snapshot)
  incoming = model_dump(choice.delta)
  current = model_dump(choice_snapshot.message)

  choice_snapshot.message = OpenAI::Internal::Type::Converter.coerce(
    OpenAI::Chat::ChatCompletionMessage,
    accumulate_delta(current, incoming)
  )
end
173
+
174
# Produces the events for one chunk: a ChatChunkEvent first, followed by the
# events of each choice in chunk order.
#
# @param chunk [#choices]
# @param completion_snapshot [Object] snapshot after the chunk was accumulated
# @return [Array]
def build_events(chunk:, completion_snapshot:)
  events = [ChatChunkEvent.new(type: :chunk, chunk: chunk, snapshot: completion_snapshot)]
  chunk.choices.each do |choice|
    events.concat(build_choice_events(choice, completion_snapshot))
  end
  events
end
187
+
188
# Collects every event for a single choice, in this order: content/refusal
# deltas, tool call deltas, logprobs deltas, then "done" events from the
# choice's ChoiceEventState.
#
# @param choice [#index] chunk choice
# @param completion_snapshot [#choices]
# @return [Array]
def build_choice_events(choice, completion_snapshot)
  state = get_choice_state(choice)
  snapshot = completion_snapshot.choices[choice.index]

  events = content_delta_events(choice, snapshot)
  events += tool_call_delta_events(choice, snapshot)
  events += logprobs_delta_events(choice, snapshot)
  events + state.get_done_events(
    choice_chunk: choice,
    choice_snapshot: snapshot,
    response_format: @response_format
  )
end
201
+
202
# Emits a content delta event and/or a refusal delta event when the chunk
# choice carries that field and the snapshot message has a value for it.
#
# @param choice [#delta] chunk choice
# @param choice_snapshot [#message] accumulated choice
# @return [Array] zero, one or two events (content first)
def content_delta_events(choice, choice_snapshot)
  delta = choice.delta
  message = choice_snapshot.message
  events = []

  content_delta = delta.content
  if content_delta && message.content
    events << ChatContentDeltaEvent.new(
      type: :"content.delta",
      delta: content_delta,
      snapshot: message.content,
      parsed: message.parsed
    )
  end

  refusal_delta = delta.refusal
  if refusal_delta && message.refusal
    events << ChatRefusalDeltaEvent.new(
      type: :"refusal.delta",
      delta: refusal_delta,
      snapshot: message.refusal
    )
  end

  events
end
224
+
225
# Emits one arguments-delta event per function tool call delta in the chunk.
#
# A delta is skipped when the snapshot has no tool call at the delta's index,
# when that tool call is not of type :function, or when the delta carries no
# function payload. The missing-entry case used to raise NoMethodError; it is
# now guarded the same way `parse_tool_calls!` guards it.
#
# @param choice [#delta] chunk choice
# @param choice_snapshot [#message] accumulated choice
# @return [Array<ChatFunctionToolCallArgumentsDeltaEvent>]
def tool_call_delta_events(choice, choice_snapshot)
  delta_tool_calls = choice.delta.tool_calls
  return [] unless delta_tool_calls

  snapshot_tool_calls = choice_snapshot.message.tool_calls
  return [] unless snapshot_tool_calls

  delta_tool_calls.filter_map do |tool_call_delta|
    tool_call = snapshot_tool_calls[tool_call_delta.index]
    next unless tool_call&.type == :function && tool_call_delta.function

    function = tool_call.function
    # Not every function object exposes `parsed`; fall back to nil when absent.
    parsed_args = function.parsed if function.respond_to?(:parsed)

    ChatFunctionToolCallArgumentsDeltaEvent.new(
      type: :"tool_calls.function.arguments.delta",
      name: function.name,
      index: tool_call_delta.index,
      arguments: function.arguments,
      parsed: parsed_args,
      arguments_delta: tool_call_delta.function.arguments || ""
    )
  end
end
251
+
252
# Emits logprobs delta events for content and refusal when both the chunk
# choice and the snapshot choice have logprobs with that field present.
#
# @param choice [#logprobs] chunk choice
# @param choice_snapshot [#logprobs] accumulated choice
# @return [Array] zero, one or two events (content first)
def logprobs_delta_events(choice, choice_snapshot)
  chunk_logprobs = choice.logprobs
  snapshot_logprobs = choice_snapshot.logprobs
  return [] unless chunk_logprobs && snapshot_logprobs

  events = []

  if chunk_logprobs.content && snapshot_logprobs.content
    events << ChatLogprobsContentDeltaEvent.new(
      type: :"logprobs.content.delta",
      content: chunk_logprobs.content,
      snapshot: snapshot_logprobs.content
    )
  end

  if chunk_logprobs.refusal && snapshot_logprobs.refusal
    events << ChatLogprobsRefusalDeltaEvent.new(
      type: :"logprobs.refusal.delta",
      refusal: chunk_logprobs.refusal,
      snapshot: snapshot_logprobs.refusal
    )
  end

  events
end
274
+
275
# Raises when a choice finished in a way that makes its output unparseable.
# Does nothing unless `parseable_input?` is truthy.
#
# The length error now carries @current_completion_snapshot. It previously
# referenced @chat_completion, which this class never assigns, so the error
# always received nil.
#
# @param finish_reason [Symbol, nil]
# @raise [LengthFinishReasonError] when finish_reason is :length
# @raise [ContentFilterFinishReasonError] when finish_reason is :content_filter
# @return [nil]
def handle_finish_reason(finish_reason)
  return unless parseable_input?

  case finish_reason
  when :length
    raise LengthFinishReasonError.new(completion: @current_completion_snapshot)
  when :content_filter
    raise ContentFilterFinishReasonError.new
  end
end
285
+
286
# For every tool call touched by the delta, re-parses the accumulated
# arguments of the matching snapshot tool call into `function.parsed`.
# Only function tool calls whose input tool is marked `strict` are parsed;
# arguments that are not valid JSON are left unparsed.
#
# @param delta_tool_calls [Array, nil] tool call deltas from the chunk
# @param snapshot_tool_calls [Array, nil] accumulated tool calls (mutated)
def parse_tool_calls!(delta_tool_calls, snapshot_tool_calls)
  return if !delta_tool_calls || !snapshot_tool_calls

  delta_tool_calls.each do |tool_call_chunk|
    snapshot = snapshot_tool_calls[tool_call_chunk.index]
    next unless snapshot&.type == :function

    function = snapshot.function
    next unless find_input_tool(function.name)&.dig(:function, :strict)

    raw_arguments = function.arguments
    next unless raw_arguments

    begin
      function.parsed = JSON.parse(raw_arguments, symbolize_names: true)
    rescue JSON::ParserError
      nil
    end
  end
end
307
+
308
# Adds a chunk choice's logprobs to the snapshot choice. When the snapshot has
# no logprobs yet, a new Logprobs object is created from the chunk's values;
# otherwise content and refusal entries are appended to the existing lists.
#
# @param choice_logprobs [#content, #refusal, nil] logprobs from the chunk choice
# @param choice_snapshot [#logprobs, #logprobs=] accumulated choice (mutated)
def accumulate_logprobs!(choice_logprobs, choice_snapshot)
  return unless choice_logprobs

  existing = choice_snapshot.logprobs
  if existing.nil?
    choice_snapshot.logprobs = OpenAI::Chat::ChatCompletionChunk::Choice::Logprobs.new(
      content: choice_logprobs.content,
      refusal: choice_logprobs.refusal
    )
  else
    new_content = choice_logprobs.content
    if new_content
      existing.content ||= []
      existing.content.concat(new_content)
    end

    new_refusal = choice_logprobs.refusal
    if new_refusal
      existing.refusal ||= []
      existing.refusal.concat(new_refusal)
    end
  end
end
328
+
329
# Builds the parsed completion from an accumulated completion: each choice is
# checked for an unparseable finish reason (only when `parseable_input?` is
# truthy) and converted with `build_parsed_choice`; the result is coerced to
# ParsedChatCompletion.
#
# @param chat_completion [#choices, #to_h]
# @param response_format [Object, nil] forwarded to `build_parsed_choice`
# @raise [LengthFinishReasonError] when a choice finished with :length
# @raise [ContentFilterFinishReasonError] when a choice finished with :content_filter
# @return [Object] result of coercing to OpenAI::Chat::ParsedChatCompletion
def parse_chat_completion(chat_completion:, response_format:)
  parsed_choices = chat_completion.choices.map do |choice|
    if parseable_input?
      reason = choice.finish_reason
      raise LengthFinishReasonError.new(completion: chat_completion) if reason == :length
      raise ContentFilterFinishReasonError.new if reason == :content_filter
    end

    build_parsed_choice(choice, response_format)
  end

  completion_data = chat_completion.to_h.merge(choices: parsed_choices)
  OpenAI::Internal::Type::Converter.coerce(OpenAI::Chat::ParsedChatCompletion, completion_data)
end
348
+
349
# Dumps a choice to a hash whose :message holds the dumped message, the parsed
# tool calls (nil when there are none), and — when a response format is given
# and the message has content and no refusal — the parsed content.
#
# @param choice [#message]
# @param response_format [Object, nil]
# @return [Hash]
def build_parsed_choice(choice, response_format)
  message = choice.message
  parsed_tool_calls = parse_choice_tool_calls(message.tool_calls)

  choice_data = model_dump(choice)
  message_data = model_dump(message)
  choice_data[:message] = message_data

  # An empty tool call list is normalized to nil.
  message_data[:tool_calls] = (parsed_tool_calls if parsed_tool_calls && !parsed_tool_calls.empty?)

  wants_parsed_content = response_format && message.content && !message.refusal
  message_data[:parsed] = parse_content(response_format, message) if wants_parsed_content

  choice_data
end
364
+
365
# Dumps each tool call to a hash and, for function tool calls, stores the
# parsed arguments under function[:parsed] when parsing produced a value.
#
# @param tool_calls [Array, nil]
# @return [Array<Hash>, nil] nil when `tool_calls` is nil
def parse_choice_tool_calls(tool_calls)
  return unless tool_calls

  tool_calls.map do |tool_call|
    dumped = model_dump(tool_call)
    function = dumped[:function]

    if dumped[:type] == :function && function
      parsed_args = parse_function_tool_arguments(function)
      function[:parsed] = parsed_args if parsed_args
    end

    dumped
  end
end
379
+
380
# Whether structured parsing was requested: truthy when a response format was
# given or at least one input tool is present.
#
# @return [Object] the response format itself when set, otherwise a Boolean
def parseable_input?
  return @response_format if @response_format

  @input_tools.any?
end
383
+
384
# Converts an object to a hash where possible: BaseModel instances via
# `deep_to_h`, anything else responding to `to_h` via `to_h`; other objects
# are returned unchanged.
#
# @param obj [Object]
# @return [Object]
def model_dump(obj)
  return obj.deep_to_h if obj.is_a?(OpenAI::Internal::Type::BaseModel)
  return obj.to_h if obj.respond_to?(:to_h)

  obj
end
393
+
394
# Looks up the first input tool whose `[:function][:name]` equals `name`.
#
# @param name [Object] function name to look for
# @return [Object, nil] the matching tool, or nil when none matches
def find_input_tool(name)
  @input_tools.find do |tool|
    declared_name = tool.dig(:function, :name)
    declared_name == name
  end
end
397
+
398
# Parses the JSON arguments of a dumped function tool call.
#
# Returns nil when there are no arguments, when the matching input tool is
# missing or not marked `strict`, when the JSON is invalid, or when it parses
# to a falsy value. When the tool's `:model` (or, failing that, its function
# `:parameters`) is a Class, the parsed data is coerced to it.
#
# @param function [Hash] dumped function with :name and :arguments
# @return [Object, nil]
def parse_function_tool_arguments(function)
  raw_arguments = function[:arguments]
  return nil unless raw_arguments

  input_tool = find_input_tool(function[:name])
  return nil unless input_tool&.dig(:function, :strict)

  parsed = JSON.parse(raw_arguments, symbolize_names: true)
  return nil unless parsed

  model_class = input_tool[:model] || input_tool.dig(:function, :parameters)
  return parsed unless model_class.is_a?(Class)

  OpenAI::Internal::Type::Converter.coerce(model_class, parsed)
rescue JSON::ParserError
  nil
end
416
+
417
# Parses a message's JSON content.
#
# Returns nil when the message has no content, carries a refusal, holds
# invalid JSON, or parses to a falsy value. When `response_format` is a Class
# the parsed data is coerced to it; otherwise the parsed data is returned as-is.
#
# @param response_format [Object, nil]
# @param message [#content, #refusal]
# @return [Object, nil]
def parse_content(response_format, message)
  content = message.content
  return nil if !content || message.refusal

  parsed = JSON.parse(content, symbolize_names: true)
  return nil unless parsed
  return parsed unless response_format.is_a?(Class)

  OpenAI::Internal::Type::Converter.coerce(response_format, parsed)
rescue JSON::ParserError
  nil
end
431
+
432
# Converts the first chunk of a stream into the initial completion snapshot.
#
# Each choice's delta becomes the choice's message, with :role defaulting to
# :assistant. The snapshot's object type is set to :"chat.completion". The
# chunk's system_fingerprint is carried over; it was previously hard-coded to
# nil, which dropped the value whenever only the first chunk supplied it.
#
# @param chunk [#to_h, #choices] first chunk of the stream
# @return [Object] result of coercing to OpenAI::Chat::ParsedChatCompletion
def convert_initial_chunk_into_snapshot(chunk)
  data = chunk.to_h

  choices = chunk.choices.map do |choice|
    choice_hash = choice.to_h

    # Copy before defaulting the role so the delta's own hash is not mutated.
    message_data = choice.delta.to_h.dup
    message_data[:role] ||= :assistant

    {
      index: choice_hash[:index],
      message: message_data,
      finish_reason: choice_hash[:finish_reason],
      logprobs: choice_hash[:logprobs]
    }
  end

  OpenAI::Internal::Type::Converter.coerce(
    OpenAI::Chat::ParsedChatCompletion,
    {
      id: data[:id],
      object: :"chat.completion",
      created: data[:created],
      model: data[:model],
      choices: choices,
      usage: data[:usage],
      system_fingerprint: data[:system_fingerprint],
      service_tier: data[:service_tier]
    }
  )
end
466
+
467
# Recursively merges `delta` into `acc` (mutating and returning `acc`).
#
# Per key (String keys are symbolized):
# - missing or nil in `acc`, or one of :index/:type/:parsed -> replaced by the delta value
# - String + String, Numeric + Numeric                      -> added together
# - Hash + Hash                                             -> merged recursively
# - Array + Array                                           -> concatenated when every existing
#   element is a String or Numeric; otherwise each delta entry must be a Hash
#   with an Integer `index` and is merged into the element at that index
# - any other combination                                   -> existing value is kept
#
# @param acc [Hash] accumulated data
# @param delta [Hash, nil] incoming data
# @raise [TypeError] when a list delta entry is not a Hash or its index is not an Integer
# @raise [RuntimeError] when a list delta entry has no index
# @return [Hash] `acc`
def accumulate_delta(acc, delta)
  return acc if delta.nil?

  delta.each do |raw_key, incoming|
    key = raw_key.is_a?(String) ? raw_key.to_sym : raw_key
    current = acc.key?(key) ? acc[key] : nil

    # Nothing accumulated yet, or a property that is replaced rather than accumulated.
    if current.nil? || %i[index type parsed].include?(key)
      acc[key] = incoming
      next
    end

    both_strings = current.is_a?(String) && incoming.is_a?(String)
    both_numbers = current.is_a?(Numeric) && incoming.is_a?(Numeric)

    if both_strings || both_numbers
      acc[key] = current + incoming
    elsif current.is_a?(Hash) && incoming.is_a?(Hash)
      acc[key] = accumulate_delta(current, incoming)
    elsif current.is_a?(Array) && incoming.is_a?(Array)
      if current.all? { |item| item.is_a?(String) || item.is_a?(Numeric) }
        current.concat(incoming)
      else
        incoming.each do |entry|
          raise TypeError, "Unexpected list delta entry is not a hash: #{entry}" unless entry.is_a?(Hash)

          index = entry[:index] || entry["index"]
          raise "Expected list delta entry to have an `index` key; #{entry}" if index.nil?
          unless index.is_a?(Integer)
            raise TypeError, "Unexpected, list delta entry `index` value is not an integer; #{index}"
          end

          slot = current[index]
          if slot.nil?
            current[index] = entry
          elsif slot.is_a?(Hash)
            current[index] = accumulate_delta(slot, entry)
          end
        end
      end
    end
  end

  acc
end
531
+ end
532
+
533
+ class ChoiceEventState
534
# Starts tracking "done" events for one choice: nothing has been reported yet
# and no tool call is current.
#
# @param input_tools [Object, nil] normalized with Kernel#Array
def initialize(input_tools:)
  @input_tools = Array(input_tools)

  # One flag per kind of "done" event, so that each is emitted at most once.
  @content_done = false
  @refusal_done = false
  @logprobs_content_done = false
  @logprobs_refusal_done = false

  @current_tool_call_index = nil
  @done_tool_calls = Set.new
end
543
+
544
# Works out which "done" events this chunk triggers for the choice.
#
# - When the snapshot has a finish reason: content-done events, plus a done
#   event for the current tool call if it has not been reported.
# - For each tool call delta whose index differs from the current one:
#   content-done events, plus a done event for the previously current tool call.
#   The delta's index then becomes the current tool call index.
#
# @param choice_chunk [#delta] chunk choice
# @param choice_snapshot [#finish_reason] accumulated choice
# @param response_format [Object, nil] forwarded to `content_done_events`
# @return [Array]
def get_done_events(choice_chunk:, choice_snapshot:, response_format:)
  events = []

  if choice_snapshot.finish_reason
    events.concat(content_done_events(choice_snapshot, response_format))

    pending_index = @current_tool_call_index
    if pending_index && !@done_tool_calls.include?(pending_index)
      finished = tool_done_event(choice_snapshot, pending_index)
      events << finished if finished
    end
  end

  Array(choice_chunk.delta.tool_calls).each do |tool_call|
    next_index = tool_call.index

    if @current_tool_call_index != next_index
      events.concat(content_done_events(choice_snapshot, response_format))

      if @current_tool_call_index
        finished = tool_done_event(choice_snapshot, @current_tool_call_index)
        events << finished if finished
      end
    end

    @current_tool_call_index = next_index
  end

  events
end
571
+
572
+ private
573
+
574
# Emits the content-done and refusal-done events, each at most once per
# choice, followed by any logprobs-done events. Emitting content-done also
# stores the parsed content on the snapshot message.
#
# @param choice_snapshot [#message] accumulated choice (message.parsed is written)
# @param response_format [Object, nil] forwarded to `parse_content`
# @return [Array]
def content_done_events(choice_snapshot, response_format)
  message = choice_snapshot.message
  events = []

  if message.content && !@content_done
    @content_done = true
    parsed = parse_content(message, response_format)
    message.parsed = parsed

    events << ChatContentDoneEvent.new(type: :"content.done", content: message.content, parsed: parsed)
  end

  if message.refusal && !@refusal_done
    @refusal_done = true
    events << ChatRefusalDoneEvent.new(type: :"refusal.done", refusal: message.refusal)
  end

  events + logprobs_done_events(choice_snapshot)
end
599
+
600
# Emits the logprobs content-done and refusal-done events, each at most once
# per choice and only when the corresponding list is non-empty.
#
# @param choice_snapshot [#logprobs] accumulated choice
# @return [Array]
def logprobs_done_events(choice_snapshot)
  logprobs = choice_snapshot.logprobs
  return [] unless logprobs

  events = []

  if logprobs.content&.any? && !@logprobs_content_done
    @logprobs_content_done = true
    events << ChatLogprobsContentDoneEvent.new(type: :"logprobs.content.done", content: logprobs.content)
  end

  if logprobs.refusal&.any? && !@logprobs_refusal_done
    @logprobs_refusal_done = true
    events << ChatLogprobsRefusalDoneEvent.new(type: :"logprobs.refusal.done", refusal: logprobs.refusal)
  end

  events
end
623
+
624
# Builds the arguments-done event for the tool call at `tool_index`.
#
# The index is marked as done on the first call, even when no event can be
# produced; any later call for the same index returns nil. No event is produced
# when the snapshot has no tool call there or the tool call is not a function.
# When the function object accepts it, the parsed arguments are also stored
# on it.
#
# @param choice_snapshot [#message] accumulated choice
# @param tool_index [Integer]
# @return [ChatFunctionToolCallArgumentsDoneEvent, nil]
def tool_done_event(choice_snapshot, tool_index)
  # Set#add? returns nil when the index was already recorded.
  return nil unless @done_tool_calls.add?(tool_index)

  tool_call = choice_snapshot.message.tool_calls&.[](tool_index)
  return nil unless tool_call&.type == :function

  function = tool_call.function
  parsed_args = parse_function_tool_arguments(function)
  function.parsed = parsed_args if function.respond_to?(:parsed=)

  ChatFunctionToolCallArgumentsDoneEvent.new(
    type: :"tool_calls.function.arguments.done",
    index: tool_index,
    name: function.name,
    arguments: function.arguments,
    parsed: parsed_args
  )
end
646
+
647
# Parses a message's JSON content when a response format was requested.
#
# Returns nil without a response format, without content, or on invalid JSON.
# When `response_format` is a Class the parsed data is coerced to it.
#
# @param message [#content]
# @param response_format [Object, nil]
# @return [Object, nil]
def parse_content(message, response_format)
  return nil unless response_format && message.content

  parsed = JSON.parse(message.content, symbolize_names: true)
  return parsed unless response_format.is_a?(Class)

  OpenAI::Internal::Type::Converter.coerce(response_format, parsed)
rescue JSON::ParserError
  nil
end
659
+
660
# Parses the JSON arguments of a function tool call.
#
# Returns nil when there are no arguments, when the matching input tool is
# missing or not marked `strict`, or when the JSON is invalid. When the tool
# has a `:model`, the parsed data is coerced to it.
#
# @param function [#name, #arguments]
# @return [Object, nil]
def parse_function_tool_arguments(function)
  raw_arguments = function.arguments
  return nil unless raw_arguments

  tool = find_input_tool(function.name)
  return nil unless tool&.dig(:function, :strict)

  parsed = JSON.parse(raw_arguments, symbolize_names: true)
  model = tool[:model]
  return parsed unless model

  OpenAI::Internal::Type::Converter.coerce(model, parsed)
rescue JSON::ParserError
  nil
end
676
+
677
# Looks up the first input tool whose `[:function][:name]` equals `name`.
#
# @param name [Object] function name to look for
# @return [Object, nil] the matching tool, or nil when none matches
def find_input_tool(name)
  @input_tools.find do |tool|
    declared_name = tool.dig(:function, :name)
    declared_name == name
  end
end
680
+ end
681
+ end
682
+ end
683
+ end