omniai-anthropic 1.6.3 → 1.8.1

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 41fa9f2696b2457c4fdadf958b6e89c91555b4fad07b424455def24687fc77e5
- data.tar.gz: 8e274f3203deddafc11be3b34d4947a5c95e0591180a605e09a54c4d3a4e8fe5
+ metadata.gz: 692299f3cf52f0f131202156f1fc5d0f973ec0488eff51f1e0186d3e88befa5d
+ data.tar.gz: 2d77787b55f8aa59acc9a33c92627451089fc63c988723963f0bf967dc30df40
  SHA512:
- metadata.gz: 895ef4d62346311990d7ccb6d096321071414b492a8e7f2ed52bcfa49017da89b6db31ab66e5cee299dded336e28ef41e50e70478d3977422a0a1f5751ea3326
- data.tar.gz: 852a864614cf0c6495820e830ef6159b8d5856afccdaad9dd664571ddcb18d24b9e7c24d80d15c1d085c850424c2213b54d60031aebd8d393c9210bcfd470bc7
+ metadata.gz: 97364e1f0f92ac5984dadaa4ff3e19b71463f2bc8acccc939b2f215e3d28daad8f4032d0b251f9253932134dc4c6808e71c80ab796f5c325c6108e4d76535019
+ data.tar.gz: 5697f26b5817a949aa4d5cf9e760de598a48c4ed298f9822bd659738ed3e0631b1cbaa98d8bad019e66501f6b331896e637a9f13608a07aa433be64749e54fff
data/README.md CHANGED
@@ -40,7 +40,7 @@ A chat completion is generated by passing in prompts using any a variety of form

  ```ruby
  completion = client.chat('Tell me a joke!')
- completion.choice.message.content # 'Why did the chicken cross the road? To get to the other side.'
+ completion.text # 'Why did the chicken cross the road? To get to the other side.'
  ```

  ```ruby
@@ -48,7 +48,7 @@ completion = client.chat do |prompt|
    prompt.system('You are a helpful assistant.')
    prompt.user('What is the capital of Canada?')
  end
- completion.choice.message.content # 'The capital of Canada is Ottawa.'
+ completion.text # 'The capital of Canada is Ottawa.'
  ```

  #### Model
@@ -57,7 +57,7 @@ completion.choice.message.content # 'The capital of Canada is Ottawa.'

  ```ruby
  completion = client.chat('Provide code for fibonacci', model: OmniAI::Anthropic::Chat::Model::CLAUDE_SONNET)
- completion.choice.message.content # 'def fibonacci(n)...end'
+ completion.text # 'def fibonacci(n)...end'
  ```

  [Anthropic API Reference `model`](https://docs.anthropic.com/en/api/messages)
@@ -68,7 +68,7 @@ completion.choice.message.content # 'def fibonacci(n)...end'

  ```ruby
  completion = client.chat('Pick a number between 1 and 5', temperature: 1.0)
- completion.choice.message.content # '3'
+ completion.text # '3'
  ```

  [Anthropic API Reference `temperature`](https://docs.anthropic.com/en/api/messages)
@@ -79,7 +79,7 @@ completion.choice.message.content # '3'

  ```ruby
  stream = proc do |chunk|
-   print(chunk.choice.delta.content) # 'Better', 'three', 'hours', ...
+   print(chunk.text) # 'Better', 'three', 'hours', ...
  end
  client.chat('Be poetic.', stream:)
  ```
@@ -94,7 +94,7 @@ client.chat('Be poetic.', stream:)
  completion = client.chat(format: :json) do |prompt|
    prompt.system(OmniAI::Chat::JSON_PROMPT)
    prompt.user('What is the name of the drummer for the Beatles?')
- JSON.parse(completion.choice.message.content) # { "name": "Ringo" }
+ JSON.parse(completion.text) # { "name": "Ringo" }
  ```

  [Anthropic API Reference `control-output-format`](https://docs.anthropic.com/en/docs/control-output-format)

data/lib/omniai/anthropic/chat/choice_serializer.rb ADDED
@@ -0,0 +1,25 @@
+ # frozen_string_literal: true
+
+ module OmniAI
+   module Anthropic
+     class Chat
+       # Overrides choice serialize / deserialize.
+       module ChoiceSerializer
+         # @param choice [OmniAI::Chat::Choice]
+         # @param context [Context]
+         # @return [Hash]
+         def self.serialize(choice, context:)
+           choice.message.serialize(context:)
+         end
+
+         # @param data [Hash]
+         # @param context [Context]
+         # @return [OmniAI::Chat::Choice]
+         def self.deserialize(data, context:)
+           message = OmniAI::Chat::Message.deserialize(data, context:)
+           OmniAI::Chat::Choice.new(message:)
+         end
+       end
+     end
+   end
+ end

data/lib/omniai/anthropic/chat/content_serializer.rb ADDED
@@ -0,0 +1,20 @@
+ # frozen_string_literal: true
+
+ module OmniAI
+   module Anthropic
+     class Chat
+       # Overrides content serialize / deserialize.
+       module ContentSerializer
+         # @param data [Hash]
+         # @param context [Context]
+         # @return [OmniAI::Chat::Text, OmniAI::Chat::ToolCall]
+         def self.deserialize(data, context:)
+           case data['type']
+           when 'text' then OmniAI::Chat::Text.deserialize(data, context:)
+           when 'tool_use' then OmniAI::Chat::ToolCall.deserialize(data, context:)
+           end
+         end
+       end
+     end
+   end
+ end

data/lib/omniai/anthropic/chat/function_serializer.rb ADDED
@@ -0,0 +1,27 @@
+ # frozen_string_literal: true
+
+ module OmniAI
+   module Anthropic
+     class Chat
+       # Overrides function serialize / deserialize.
+       module FunctionSerializer
+         # @param function [OmniAI::Chat::Function]
+         # @return [Hash]
+         def self.serialize(function, *)
+           {
+             name: function.name,
+             input: function.arguments,
+           }
+         end
+
+         # @param data [Hash]
+         # @return [OmniAI::Chat::Function]
+         def self.deserialize(data, *)
+           name = data['name']
+           arguments = data['input']
+           OmniAI::Chat::Function.new(name:, arguments:)
+         end
+       end
+     end
+   end
+ end
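
Note: Anthropic's tool-use blocks carry call arguments under `input`, whereas OmniAI models them as `arguments`; this serializer is the bridge between the two. A minimal round-trip sketch (the weather values are invented for illustration):

```ruby
# Hypothetical values; only the serializer above is exercised.
function = OmniAI::Chat::Function.new(name: 'weather', arguments: { 'location' => 'Ottawa' })

OmniAI::Anthropic::Chat::FunctionSerializer.serialize(function)
# => { name: 'weather', input: { 'location' => 'Ottawa' } }

OmniAI::Anthropic::Chat::FunctionSerializer.deserialize('name' => 'weather', 'input' => { 'location' => 'Ottawa' })
# => an OmniAI::Chat::Function with name 'weather' and arguments { 'location' => 'Ottawa' }
```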

data/lib/omniai/anthropic/chat/media_serializer.rb ADDED
@@ -0,0 +1,23 @@
+ # frozen_string_literal: true
+
+ module OmniAI
+   module Anthropic
+     class Chat
+       # Overrides media serialize / deserialize.
+       module MediaSerializer
+         # @param media [OmniAI::Chat::Media]
+         # @return [Hash]
+         def self.serialize(media, *)
+           {
+             type: media.kind, # i.e. 'image' / 'video' / 'audio' / ...
+             source: {
+               type: 'base64',
+               media_type: media.type, # i.e. 'image/jpeg' / 'video/ogg' / 'audio/mpeg' / ...
+               data: media.data,
+             },
+           }
+         end
+       end
+     end
+   end
+ end
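
Note: Anthropic expects inline attachments as base64 `source` blocks. Assuming a media object that responds to `#kind`, `#type`, and `#data` (as the serializer above requires), the resulting payload looks roughly like:

```ruby
# Sketch only; `media` stands in for an OmniAI file or URL attachment.
OmniAI::Anthropic::Chat::MediaSerializer.serialize(media)
# => {
#   type: 'image',
#   source: { type: 'base64', media_type: 'image/jpeg', data: '<base64 data>' },
# }
```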

data/lib/omniai/anthropic/chat/message_serializer.rb ADDED
@@ -0,0 +1,49 @@
+ # frozen_string_literal: true
+
+ module OmniAI
+   module Anthropic
+     class Chat
+       # Overrides message serialize / deserialize.
+       module MessageSerializer
+         # @param message [OmniAI::Chat::Message]
+         # @param context [OmniAI::Context]
+         # @return [Hash]
+         def self.serialize(message, context:)
+           role = message.role
+           parts = arrayify(message.content) + arrayify(message.tool_call_list)
+           content = parts.map do |part|
+             case part
+             when String then { type: 'text', text: part }
+             else part.serialize(context:)
+             end
+           end
+
+           { role:, content: }
+         end
+
+         # @param data [Hash]
+         # @param context [OmniAI::Context]
+         # @return [OmniAI::Chat::Message]
+         def self.deserialize(data, context:)
+           role = data['role']
+           parts = arrayify(data['content']).map do |content|
+             ContentSerializer.deserialize(content, context:)
+           end
+
+           tool_call_list = parts.select { |part| part.is_a?(OmniAI::Chat::ToolCall) }
+           content = parts.reject { |part| part.is_a?(OmniAI::Chat::ToolCall) }
+
+           OmniAI::Chat::Message.new(content:, role:, tool_call_list:)
+         end
+
+         # @param content [Object]
+         # @return [Array<Object>]
+         def self.arrayify(content)
+           return [] if content.nil?
+
+           content.is_a?(Array) ? content : [content]
+         end
+       end
+     end
+   end
+ end
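
Note: plain string content is wrapped into Anthropic's content-parts array, while tool calls and media go through their own registered serializers. A rough sketch for a text-only message (assuming `OmniAI::Chat::Message.new` accepts `role:` and `content:` keywords, as it is called in the `chat.rb` changes below):

```ruby
message = OmniAI::Chat::Message.new(role: 'user', content: 'What is the capital of Canada?')

OmniAI::Anthropic::Chat::MessageSerializer.serialize(message, context: OmniAI::Anthropic::Chat::CONTEXT)
# => { role: 'user', content: [{ type: 'text', text: 'What is the capital of Canada?' }] }
```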

data/lib/omniai/anthropic/chat/payload_serializer.rb ADDED
@@ -0,0 +1,30 @@
+ # frozen_string_literal: true
+
+ module OmniAI
+   module Anthropic
+     class Chat
+       # Overrides payload serialize / deserialize.
+       module PayloadSerializer
+         # @param payload [OmniAI::Chat::Payload]
+         # @param context [OmniAI::Context]
+         # @return [Hash]
+         def self.serialize(payload, context:)
+           usage = payload.usage.serialize(context:)
+           choice = payload.choice.serialize(context:)
+
+           choice.merge({ usage: })
+         end
+
+         # @param data [Hash]
+         # @param context [OmniAI::Context]
+         # @return [OmniAI::Chat::Payload]
+         def self.deserialize(data, context:)
+           usage = OmniAI::Chat::Usage.deserialize(data['usage'], context:) if data['usage']
+           choice = OmniAI::Chat::Choice.deserialize(data, context:)
+
+           OmniAI::Chat::Payload.new(choices: [choice], usage:)
+         end
+       end
+     end
+   end
+ end

data/lib/omniai/anthropic/chat/stream.rb ADDED
@@ -0,0 +1,94 @@
+ # frozen_string_literal: true
+
+ module OmniAI
+   module Anthropic
+     class Chat
+       # A stream given when streaming.
+       class Stream < OmniAI::Chat::Stream
+         module Type
+           PING = 'ping'
+           MESSAGE_START = 'message_start'
+           MESSAGE_STOP = 'message_stop'
+           MESSAGE_DELTA = 'message_delta'
+           CONTENT_BLOCK_START = 'content_block_start'
+           CONTENT_BLOCK_STOP = 'content_block_stop'
+           CONTENT_BLOCK_DELTA = 'content_block_delta'
+         end
+
+         # Process the stream into chunks by event.
+         class Builder
+           # @return [OmniAI::Chat::Payload, nil]
+           def payload(context:)
+             return unless @content
+
+             OmniAI::Chat::Payload.deserialize(@message.merge({
+               'content' => @content,
+             }), context:)
+           end
+
+           # Handler for Type::MESSAGE_START
+           #
+           # @param data [Hash]
+           def message_start(data)
+             @message = data['message']
+           end
+
+           # Handler for Type::MESSAGE_STOP
+           #
+           # @param _data [Hash]
+           def message_stop(_data)
+             @message = nil
+           end
+
+           # Handler for Type::CONTENT_BLOCK_START
+           #
+           # @param _data [Hash]
+           def content_block_start(_data)
+             @content = nil
+           end
+
+           # Handler for Type::CONTENT_BLOCK_STOP
+           #
+           # @param _data [Hash]
+           def content_block_stop(_data)
+             @content = nil
+           end
+
+           # Handler for Type::CONTENT_BLOCK_DELTA
+           #
+           # @param data [Hash]
+           def content_block_delta(data)
+             @content = [{ 'type' => 'text', 'text' => data['delta']['text'] }]
+           end
+         end
+
+         protected
+
+         def builder
+           @builder ||= Builder.new
+         end
+
+         # @param type [String]
+         # @param data [String]
+         # @param id [String]
+         def process!(type, data, id, &block)
+           log(type, data, id)
+
+           data = JSON.parse(data)
+
+           case type
+           when Type::MESSAGE_START then builder.message_start(data)
+           when Type::CONTENT_BLOCK_START then builder.content_block_start(data)
+           when Type::CONTENT_BLOCK_STOP then builder.content_block_stop(data)
+           when Type::MESSAGE_STOP then builder.message_stop(data)
+           when Type::CONTENT_BLOCK_DELTA
+             builder.content_block_delta(data)
+
+             payload = builder.payload(context: @context)
+             block.call(payload) if payload
+           end
+         end
+       end
+     end
+   end
+ end
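
Note: the builder mirrors Anthropic's SSE event sequence (`message_start`, `content_block_start`, repeated `content_block_delta`, `content_block_stop`, `message_stop`) and yields one payload per text delta. A hedged sketch driving the builder by hand (event hashes abbreviated; real events carry more fields):

```ruby
builder = OmniAI::Anthropic::Chat::Stream::Builder.new

builder.message_start('message' => { 'role' => 'assistant', 'model' => 'claude-3-5-sonnet-20240620' })
builder.content_block_start({})
builder.content_block_delta('delta' => { 'type' => 'text_delta', 'text' => 'Better' })

builder.payload(context: OmniAI::Anthropic::Chat::CONTEXT)
# => an OmniAI::Chat::Payload whose message content is the text 'Better'
```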

data/lib/omniai/anthropic/chat/text_serializer.rb ADDED
@@ -0,0 +1,22 @@
+ # frozen_string_literal: true
+
+ module OmniAI
+   module Anthropic
+     class Chat
+       # Overrides text serialize / deserialize.
+       module TextSerializer
+         # @param text [OmniAI::Chat::Text]
+         # @return [Hash]
+         def self.serialize(text, *)
+           { type: 'text', text: text.text }
+         end
+
+         # @param data [Hash]
+         # @return [OmniAI::Chat::Text]
+         def self.deserialize(data, *)
+           OmniAI::Chat::Text.new(data['text'])
+         end
+       end
+     end
+   end
+ end

data/lib/omniai/anthropic/chat/tool_call_result_serializer.rb ADDED
@@ -0,0 +1,29 @@
+ # frozen_string_literal: true
+
+ module OmniAI
+   module Anthropic
+     class Chat
+       # Overrides tool-call response serialize / deserialize.
+       module ToolCallResultSerializer
+         # @param tool_call_result [OmniAI::Chat::ToolCallResult]
+         # @return [Hash]
+         def self.serialize(tool_call_result, *)
+           {
+             type: 'tool_result',
+             tool_use_id: tool_call_result.tool_call_id,
+             content: tool_call_result.content,
+           }
+         end
+
+         # @param data [Hash]
+         # @return [OmniAI::Chat::ToolCallResult]
+         def self.deserialize(data, *)
+           tool_call_id = data['tool_use_id']
+           content = data['content']
+
+           OmniAI::Chat::ToolCallResult.new(content:, tool_call_id:)
+         end
+       end
+     end
+   end
+ end
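
Note: this is the shape Anthropic expects when a tool result is sent back: a `tool_result` block keyed by the originating `tool_use` id. Roughly (id and content invented for illustration):

```ruby
result = OmniAI::Chat::ToolCallResult.new(tool_call_id: 'toolu_123', content: '15 degrees and sunny')

OmniAI::Anthropic::Chat::ToolCallResultSerializer.serialize(result)
# => { type: 'tool_result', tool_use_id: 'toolu_123', content: '15 degrees and sunny' }
```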

data/lib/omniai/anthropic/chat/tool_call_serializer.rb ADDED
@@ -0,0 +1,29 @@
+ # frozen_string_literal: true
+
+ module OmniAI
+   module Anthropic
+     class Chat
+       # Overrides tool-call serialize / deserialize.
+       module ToolCallSerializer
+         # @param tool_call [OmniAI::Chat::ToolCall]
+         # @param context [OmniAI::Context]
+         # @return [Hash]
+         def self.serialize(tool_call, context:)
+           function = tool_call.function.serialize(context:)
+           {
+             id: tool_call.id,
+             type: 'tool_use',
+           }.merge(function)
+         end
+
+         # @param data [Hash]
+         # @param context [OmniAI::Context]
+         # @return [OmniAI::Chat::ToolCall]
+         def self.deserialize(data, context:)
+           function = OmniAI::Chat::Function.deserialize(data, context:)
+           OmniAI::Chat::ToolCall.new(id: data['id'], function:)
+         end
+       end
+     end
+   end
+ end

data/lib/omniai/anthropic/chat/tool_serializer.rb ADDED
@@ -0,0 +1,20 @@
+ # frozen_string_literal: true
+
+ module OmniAI
+   module Anthropic
+     class Chat
+       # Overrides tool serialize / deserialize.
+       module ToolSerializer
+         # @param tool [OmniAI::Tool]
+         # @return [Hash]
+         def self.serialize(tool, *)
+           {
+             name: tool.name,
+             description: tool.description,
+             input_schema: tool.parameters.is_a?(Tool::Parameters) ? tool.parameters.serialize : tool.parameters,
+           }.compact
+         end
+       end
+     end
+   end
+ end
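
Note: Anthropic's `tools` parameter takes `name`, `description`, and a JSON-Schema `input_schema`; the serializer above emits exactly that, passing a raw schema hash through unchanged. A sketch of the output (schema contents invented):

```ruby
OmniAI::Anthropic::Chat::ToolSerializer.serialize(tool)
# => {
#   name: 'weather',
#   description: 'Look up the current weather for a location.',
#   input_schema: {
#     type: 'object',
#     properties: { location: { type: 'string' } },
#     required: ['location'],
#   },
# }
```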

data/lib/omniai/anthropic/chat.rb CHANGED
@@ -10,7 +10,7 @@ module OmniAI
      # prompt.system('You are an expert in the field of AI.')
      # prompt.user('What are the biggest risks of AI?')
      # end
-     # completion.choice.message.content # '...'
+     # completion.text # '...'
      class Chat < OmniAI::Chat
        module Model
          CLAUDE_INSTANT_1_0 = 'claude-instant-1.2'
@@ -27,26 +27,30 @@ module OmniAI

        DEFAULT_MODEL = Model::CLAUDE_SONNET

-       # @param [Media]
-       # @return [Hash]
-       # @example
-       #   media = Media.new(...)
-       #   MEDIA_SERIALIZER.call(media)
-       MEDIA_SERIALIZER = lambda do |media, *|
-         {
-           type: media.kind, # i.e. 'image' / 'video' / 'audio' / ...
-           source: {
-             type: 'base64',
-             media_type: media.type, # i.e. 'image/jpeg' / 'video/ogg' / 'audio/mpeg' / ...
-             data: media.data,
-           },
-         }
-       end
-
        # @return [Context]
        CONTEXT = Context.build do |context|
-         context.serializers[:file] = MEDIA_SERIALIZER
-         context.serializers[:url] = MEDIA_SERIALIZER
+         context.serializers[:tool] = ToolSerializer.method(:serialize)
+
+         context.serializers[:file] = MediaSerializer.method(:serialize)
+         context.serializers[:url] = MediaSerializer.method(:serialize)
+
+         context.serializers[:choice] = ChoiceSerializer.method(:serialize)
+         context.deserializers[:choice] = ChoiceSerializer.method(:deserialize)
+
+         context.serializers[:tool_call] = ToolCallSerializer.method(:serialize)
+         context.deserializers[:tool_call] = ToolCallSerializer.method(:deserialize)
+
+         context.serializers[:tool_call_result] = ToolCallResultSerializer.method(:serialize)
+         context.deserializers[:tool_call_result] = ToolCallResultSerializer.method(:deserialize)
+
+         context.serializers[:function] = FunctionSerializer.method(:serialize)
+         context.deserializers[:function] = FunctionSerializer.method(:deserialize)
+
+         context.serializers[:message] = MessageSerializer.method(:serialize)
+         context.deserializers[:message] = MessageSerializer.method(:deserialize)
+
+         context.deserializers[:content] = ContentSerializer.method(:deserialize)
+         context.deserializers[:payload] = PayloadSerializer.method(:deserialize)
        end

        # @return [Hash]
@@ -63,14 +67,16 @@ module OmniAI

        # @return [Array<Hash>]
        def messages
-         messages = @prompt.messages.filter(&:user?)
-         messages.map { |message| message.serialize(context: CONTEXT) }
+         messages = @prompt.messages.reject(&:system?)
+         messages.map { |message| message.serialize(context:) }
        end

        # @return [String, nil]
        def system
          messages = @prompt.messages.filter(&:system?)
-         messages.map(&:content).join("\n\n") if messages.any?
+         return if messages.empty?
+
+         messages.filter(&:text?).map(&:text).join("\n\n")
        end

        # @return [String]
@@ -78,17 +84,27 @@ module OmniAI
          "/#{Client::VERSION}/messages"
        end

+       protected
+
+       # @return [Context]
+       def context
+         CONTEXT
+       end
+
+       # @return [Array<Message>]
+       def build_tool_call_messages(tool_call_list)
+         content = tool_call_list.map do |tool_call|
+           ToolCallResult.new(tool_call_id: tool_call.id, content: execute_tool_call(tool_call))
+         end
+
+         [Message.new(role: OmniAI::Chat::Role::USER, content:)]
+       end
+
        private

        # @return [Array<Hash>, nil]
        def tools_payload
-         @tools&.map do |tool|
-           {
-             name: tool.name,
-             description: tool.description,
-             input_schema: tool.parameters&.prepare,
-           }.compact
-         end
+         @tools.map { |tool| tool.serialize(context:) } if @tools&.any?
        end
      end
    end
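
Note: combined with the serializers above, `build_tool_call_messages` returns tool results to Anthropic as a single `user` message whose content is a list of `tool_result` blocks. A hedged sketch of the serialized shape (values invented; assumes the result part dispatches through the registered `tool_call_result` serializer):

```ruby
result = OmniAI::Chat::ToolCallResult.new(tool_call_id: 'toolu_123', content: '15 degrees and sunny')
message = OmniAI::Chat::Message.new(role: OmniAI::Chat::Role::USER, content: [result])

OmniAI::Anthropic::Chat::MessageSerializer.serialize(message, context: OmniAI::Anthropic::Chat::CONTEXT)
# => {
#   role: 'user',
#   content: [{ type: 'tool_result', tool_use_id: 'toolu_123', content: '15 degrees and sunny' }],
# }
```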

data/lib/omniai/anthropic/version.rb CHANGED
@@ -2,6 +2,6 @@

  module OmniAI
    module Anthropic
-     VERSION = '1.6.3'
+     VERSION = '1.8.1'
    end
  end

metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: omniai-anthropic
  version: !ruby/object:Gem::Version
-   version: 1.6.3
+   version: 1.8.1
  platform: ruby
  authors:
  - Kevin Sylvestre
  autorequire:
  bindir: exe
  cert_chain: []
- date: 2024-07-19 00:00:00.000000000 Z
+ date: 2024-08-19 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
    name: event_stream_parser
@@ -63,8 +63,17 @@ files:
  - README.md
  - lib/omniai/anthropic.rb
  - lib/omniai/anthropic/chat.rb
- - lib/omniai/anthropic/chat/response/completion.rb
- - lib/omniai/anthropic/chat/response/stream.rb
+ - lib/omniai/anthropic/chat/choice_serializer.rb
+ - lib/omniai/anthropic/chat/content_serializer.rb
+ - lib/omniai/anthropic/chat/function_serializer.rb
+ - lib/omniai/anthropic/chat/media_serializer.rb
+ - lib/omniai/anthropic/chat/message_serializer.rb
+ - lib/omniai/anthropic/chat/payload_serializer.rb
+ - lib/omniai/anthropic/chat/stream.rb
+ - lib/omniai/anthropic/chat/text_serializer.rb
+ - lib/omniai/anthropic/chat/tool_call_result_serializer.rb
+ - lib/omniai/anthropic/chat/tool_call_serializer.rb
+ - lib/omniai/anthropic/chat/tool_serializer.rb
  - lib/omniai/anthropic/client.rb
  - lib/omniai/anthropic/config.rb
  - lib/omniai/anthropic/version.rb

data/lib/omniai/anthropic/chat/response/completion.rb DELETED
@@ -1,29 +0,0 @@
- # frozen_string_literal: true
-
- module OmniAI
-   module Anthropic
-     class Chat
-       module Response
-         # A completion returned by the API.
-         class Completion < OmniAI::Chat::Response::Completion
-           # @return [Array<OmniAI::Chat::Response::MessageChoice>]
-           def choices
-             @choices ||= begin
-               role = @data['role']
-
-               @data['content'].map do |data, index|
-                 OmniAI::Chat::Response::MessageChoice.new(data: {
-                   'index' => index,
-                   'message' => {
-                     'role' => role,
-                     'content' => data['text'],
-                   },
-                 })
-               end
-             end
-           end
-         end
-       end
-     end
-   end
- end

data/lib/omniai/anthropic/chat/response/stream.rb DELETED
@@ -1,111 +0,0 @@
- # frozen_string_literal: true
-
- module OmniAI
-   module Anthropic
-     class Chat
-       module Response
-         # A stream given when streaming.
-         class Stream < OmniAI::Chat::Response::Stream
-           module Type
-             PING = 'ping'
-             MESSAGE_START = 'message_start'
-             MESSAGE_STOP = 'message_stop'
-             MESSAGE_DELTA = 'message_delta'
-             CONTENT_BLOCK_START = 'content_block_start'
-             CONTENT_BLOCK_STOP = 'content_block_stop'
-             CONTENT_BLOCK_DELTA = 'content_block_delta'
-           end
-
-           # Process the stream into chunks by event.
-           class Builder
-             attr_reader :id, :model, :role, :content, :index
-
-             # @return [OmniAI::Chat::Chunk]
-             def chunk
-               OmniAI::Chat::Response::Chunk.new(data: {
-                 'id' => @id,
-                 'model' => @model,
-                 'choices' => [{
-                   'index' => @index,
-                   'delta' => {
-                     'role' => @role,
-                     'content' => @content,
-                   },
-                 }],
-               })
-             end
-
-             # Handler for Type::MESSAGE_START
-             #
-             # @param data [Hash]
-             def message_start(data)
-               @id = data['id']
-               @model = data['model']
-               @role = data['role']
-             end
-
-             # Handler for Type::MESSAGE_STOP
-             #
-             # @param _data [Hash]
-             def message_stop(_data)
-               @id = nil
-               @model = nil
-               @role = nil
-             end
-
-             # Handler for Type::CONTENT_BLOCK_START
-             #
-             # @param data [Hash]
-             def content_block_start(data)
-               @index = data['index']
-             end
-
-             # Handler for Type::CONTENT_BLOCK_STOP
-             #
-             # @param _data [Hash]
-             def content_block_stop(_data)
-               @index = nil
-             end
-
-             # Handler for Type::CONTENT_BLOCK_DELTA
-             #
-             # @param data [Hash]
-             def content_block_delta(data)
-               return unless data['delta']['type'].eql?('text_delta')
-
-               @content = data['delta']['text']
-             end
-           end
-
-           # @yield [OmniAI::Chat::Chunk]
-           def stream!(&block)
-             builder = Builder.new
-
-             @response.body.each do |chunk|
-               @parser.feed(chunk) do |type, data|
-                 process(type:, data: JSON.parse(data), builder:, &block)
-               end
-             end
-           end
-
-           private
-
-           # @param type [String]
-           # @param data [Hash]
-           # @param builder [Builder]
-           def process(type:, data:, builder:, &)
-             case type
-             when Type::MESSAGE_START then builder.message_start(data)
-             when Type::CONTENT_BLOCK_START then builder.content_block_start(data)
-             when Type::CONTENT_BLOCK_STOP then builder.content_block_stop(data)
-             when Type::MESSAGE_STOP then builder.message_stop(data)
-             when Type::CONTENT_BLOCK_DELTA
-               builder.content_block_delta(data)
-               yield(builder.chunk)
-             end
-           end
-         end
-       end
-     end
-   end
- end