durable-llm 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +5 -0
- data/Gemfile +20 -0
- data/Gemfile.lock +102 -0
- data/LICENSE.txt +21 -0
- data/README.md +96 -0
- data/Rakefile +16 -0
- data/enterprise/GET_ENTERPRISE.md +40 -0
- data/lib/durable/llm/cli.rb +122 -0
- data/lib/durable/llm/client.rb +32 -0
- data/lib/durable/llm/configuration.rb +63 -0
- data/lib/durable/llm/errors.rb +33 -0
- data/lib/durable/llm/providers/anthropic.rb +164 -0
- data/lib/durable/llm/providers/base.rb +49 -0
- data/lib/durable/llm/providers/groq.rb +177 -0
- data/lib/durable/llm/providers/huggingface.rb +93 -0
- data/lib/durable/llm/providers/openai.rb +212 -0
- data/lib/durable/llm/providers.rb +43 -0
- data/lib/durable/llm/version.rb +7 -0
- data/lib/durable/llm.rb +23 -0
- data/sig/durable/llm.rbs +6 -0
- metadata +208 -0
@@ -0,0 +1,164 @@
|
|
1
|
+
|
2
|
+
require 'faraday'
|
3
|
+
require 'json'
|
4
|
+
require 'durable/llm/errors'
|
5
|
+
require 'durable/llm/providers/base'
|
6
|
+
|
7
|
+
module Durable
|
8
|
+
module Llm
|
9
|
+
module Providers
|
10
|
+
# Provider adapter for the Anthropic Messages API.
#
# Implements the Durable::Llm::Providers::Base interface (#completion,
# #stream, #models) and provides wrapper classes that normalize
# Anthropic response payloads into the library's choice/message shape.
class Anthropic < Durable::Llm::Providers::Base
  BASE_URL = 'https://api.anthropic.com'

  # Version header Anthropic requires on every request.
  API_VERSION = '2023-06-01'

  # Resolve the API key from library configuration, falling back to the
  # ANTHROPIC_API_KEY environment variable.
  def default_api_key
    Durable::Llm.configuration.anthropic&.api_key || ENV['ANTHROPIC_API_KEY']
  end

  attr_accessor :api_key

  def initialize(api_key: nil)
    @api_key = api_key || default_api_key
    @conn = Faraday.new(url: BASE_URL) do |faraday|
      faraday.request :json
      faraday.response :json
      faraday.adapter Faraday.default_adapter
    end
  end

  # Perform a non-streaming completion request.
  #
  # options - Hash request body (model, messages, max_tokens, ...).
  #
  # Returns an AnthropicResponse; raises a Durable::Llm error on
  # non-2xx statuses (see #handle_response).
  def completion(options)
    # BUGFIX: Anthropic's endpoint is /v1/messages. /v1/chat/completions
    # is an OpenAI path and does not exist on api.anthropic.com.
    response = @conn.post('/v1/messages') do |req|
      req.headers['x-api-key'] = @api_key
      req.headers['anthropic-version'] = API_VERSION
      req.body = options
    end

    handle_response(response)
  end

  def models
    self.class.models
  end

  # Static model list; this client has no Anthropic model-listing call.
  def self.models
    ['claude-3-opus', 'claude-3-sonnet', 'claude-3-haiku', 'claude-2.1', 'claude-2.0', 'claude-instant-1.2']
  end

  def self.stream?
    true
  end

  # Perform a streaming completion request. Yields an
  # AnthropicStreamResponse for each server-sent-event chunk that
  # starts with "data: ".
  def stream(options, &block)
    options[:stream] = true
    # BUGFIX: same endpoint correction as #completion.
    response = @conn.post('/v1/messages') do |req|
      req.headers['x-api-key'] = @api_key
      req.headers['anthropic-version'] = API_VERSION
      req.headers['Accept'] = 'text/event-stream'
      req.body = options
      req.options.on_data = proc do |chunk, _size, _total|
        next if chunk.strip.empty?

        yield AnthropicStreamResponse.new(chunk) if chunk.start_with?('data: ')
      end
    end

    handle_response(response)
  end

  private

  # Map HTTP statuses to library errors, or wrap a 2xx body in an
  # AnthropicResponse. Assumes the :json response middleware has parsed
  # the body into a Hash.
  def handle_response(response)
    case response.status
    when 200..299
      AnthropicResponse.new(response.body)
    when 401
      raise Durable::Llm::AuthenticationError, response.body.dig('error', 'message')
    when 429
      raise Durable::Llm::RateLimitError, response.body.dig('error', 'message')
    when 400..499
      raise Durable::Llm::InvalidRequestError, response.body.dig('error', 'message')
    when 500..599
      raise Durable::Llm::ServerError, response.body.dig('error', 'message')
    else
      raise Durable::Llm::APIError, "Unexpected response code: #{response.status}"
    end
  end

  # Wraps a parsed (Hash) Messages API response body.
  class AnthropicResponse
    attr_reader :raw_response

    def initialize(response)
      @raw_response = response
    end

    # Anthropic returns a single "content" array; present it as a
    # one-element choices list for interface parity with other providers.
    def choices
      [@raw_response['content']].map { |content| AnthropicChoice.new(content) }
    end

    def to_s
      choices.map(&:to_s).join(' ')
    end
  end

  # One pseudo-choice built from the response's content blocks.
  class AnthropicChoice
    attr_reader :message

    def initialize(content)
      @message = AnthropicMessage.new(content)
    end

    def to_s
      @message.to_s
    end
  end

  # Flattens content blocks into joined "type" and "text" strings.
  class AnthropicMessage
    attr_reader :role, :content

    def initialize(content)
      @role = [content].flatten.map { |block| block['type'] }.join(' ')
      @content = [content].flatten.map { |block| block['text'] }.join(' ')
    end

    def to_s
      @content
    end
  end

  # Wraps one SSE "data: ..." fragment of a streamed response.
  class AnthropicStreamResponse
    attr_reader :choices

    def initialize(fragment)
      parsed = JSON.parse(fragment.split('data: ').last)
      # BUGFIX: some SSE event types (message_start, ping, ...) carry no
      # "delta" key; fall back to an empty hash so #to_s yields '' instead
      # of raising NoMethodError on nil.
      @choices = [AnthropicStreamChoice.new(parsed['delta'] || {})]
    end

    def to_s
      @choices.map(&:to_s).join(' ')
    end
  end

  class AnthropicStreamChoice
    attr_reader :delta

    def initialize(delta)
      @delta = AnthropicStreamDelta.new(delta)
    end

    def to_s
      @delta.to_s
    end
  end

  class AnthropicStreamDelta
    attr_reader :type, :text

    def initialize(delta)
      @type = delta['type']
      @text = delta['text']
    end

    def to_s
      @text || ''
    end
  end
end
|
162
|
+
end
|
163
|
+
end
|
164
|
+
end
|
@@ -0,0 +1,49 @@
|
|
1
|
+
module Durable
|
2
|
+
module Llm
|
3
|
+
module Providers
|
4
|
+
# Abstract superclass for LLM provider adapters.
#
# Concrete providers must implement #default_api_key, #completion,
# #models and #handle_response, and may additionally support #stream
# and #embedding. The shared constructor resolves an API key from the
# keyword argument or the subclass's default.
class Base
  attr_accessor :api_key

  def initialize(api_key: nil)
    @api_key = api_key || default_api_key
  end

  # Subclass hook: return the provider's API key from configuration or
  # the environment.
  def default_api_key
    raise NotImplementedError, "Subclasses must implement default_api_key"
  end

  # Subclass hook: perform a completion request.
  def completion(options)
    raise NotImplementedError, "Subclasses must implement completion"
  end

  # Subclass hook: list available model identifiers.
  def models
    raise NotImplementedError, "Subclasses must implement models"
  end

  # Instance-level streaming capability, delegated to the class.
  def stream?
    self.class.stream?
  end

  # Subclass hook: perform a streaming completion request.
  def stream(options, &block)
    raise NotImplementedError, "Subclasses must implement stream"
  end

  # Subclass hook: compute embeddings for the given input.
  def embedding(model:, input:, **options)
    raise NotImplementedError, "Subclasses must implement embedding"
  end

  class << self
    # Default model list for providers that expose none.
    def models
      []
    end

    # Providers that support streaming override this to return true.
    def stream?
      false
    end
  end

  private

  # Subclass hook: translate an HTTP response into a wrapped response
  # object or raise the appropriate Durable::Llm error.
  def handle_response(response)
    raise NotImplementedError, "Subclasses must implement handle_response"
  end
end
|
47
|
+
end
|
48
|
+
end
|
49
|
+
end
|
@@ -0,0 +1,177 @@
|
|
1
|
+
require 'faraday'
|
2
|
+
require 'json'
|
3
|
+
require 'durable/llm/errors'
|
4
|
+
require 'durable/llm/providers/base'
|
5
|
+
|
6
|
+
module Durable
|
7
|
+
module Llm
|
8
|
+
module Providers
|
9
|
+
# Provider adapter for the Groq API (OpenAI-compatible endpoints under
# https://api.groq.com/openai/v1).
class Groq < Durable::Llm::Providers::Base
  BASE_URL = 'https://api.groq.com/openai/v1'

  # Resolve the API key from library configuration, falling back to the
  # GROQ_API_KEY environment variable.
  def default_api_key
    Durable::Llm.configuration.groq&.api_key || ENV['GROQ_API_KEY']
  end

  attr_accessor :api_key

  # Build a JSON-speaking Faraday connection for the Groq API.
  def self.conn
    Faraday.new(url: BASE_URL) do |faraday|
      faraday.request :json
      faraday.response :json
      faraday.adapter Faraday.default_adapter
    end
  end

  def conn
    # Memoized per instance: self.class.conn constructs a brand-new
    # Faraday connection on every call.
    @conn ||= self.class.conn
  end

  def initialize(api_key: nil)
    @api_key = api_key || default_api_key
  end

  # POST a chat completion request.
  #
  # options - Hash request body (model, messages, ...).
  #
  # Returns a GroqResponse; raises a Durable::Llm error on non-2xx
  # statuses (see #handle_response).
  def completion(options)
    response = conn.post('chat/completions') do |req|
      req.headers['Authorization'] = "Bearer #{@api_key}"
      req.body = options
    end

    handle_response(response)
  end

  # POST an embeddings request for the given model and input.
  def embedding(model:, input:, **options)
    response = conn.post('embeddings') do |req|
      req.headers['Authorization'] = "Bearer #{@api_key}"
      req.body = { model: model, input: input, **options }
    end

    handle_response(response)
  end

  # Fetch model ids available to this API key.
  # NOTE: performs a network request.
  def models
    response = conn.get('models') do |req|
      req.headers['Authorization'] = "Bearer #{@api_key}"
    end

    resp = handle_response(response).to_h

    resp['data'].map { |model| model['id'] }
  end

  # Class-level variant; instantiates a client with the default key.
  def self.models
    Groq.new.models
  end

  def self.stream?
    false
  end

  private

  # Map HTTP statuses to library errors, or wrap a 2xx body in a
  # GroqResponse. Assumes the :json response middleware has parsed the
  # body into a Hash.
  def handle_response(response)
    case response.status
    when 200..299
      GroqResponse.new(response.body)
    when 401
      raise Durable::Llm::AuthenticationError, response.body['error']['message']
    when 429
      raise Durable::Llm::RateLimitError, response.body['error']['message']
    when 400..499
      raise Durable::Llm::InvalidRequestError, response.body['error']['message']
    when 500..599
      raise Durable::Llm::ServerError, response.body['error']['message']
    else
      raise Durable::Llm::APIError, "Unexpected response code: #{response.status}"
    end
  end

  # Wraps a parsed (Hash) completion response body.
  class GroqResponse
    attr_reader :raw_response

    def initialize(response)
      @raw_response = response
    end

    def choices
      @raw_response['choices'].map { |choice| GroqChoice.new(choice) }
    end

    def to_s
      choices.map(&:to_s).join(' ')
    end

    def to_h
      @raw_response.dup
    end
  end

  # One entry of a response's "choices" array.
  class GroqChoice
    attr_reader :message, :finish_reason

    def initialize(choice)
      @message = GroqMessage.new(choice['message'])
      @finish_reason = choice['finish_reason']
    end

    def to_s
      @message.to_s
    end
  end

  class GroqMessage
    attr_reader :role, :content

    def initialize(message)
      @role = message['role']
      @content = message['content']
    end

    def to_s
      @content
    end
  end

  # Wraps one SSE "data: ..." fragment of a streamed response.
  class GroqStreamResponse
    attr_reader :choices

    def initialize(fragment)
      # BUGFIX: removed stray debug `puts` that printed every raw stream
      # fragment to stdout.
      parsed = JSON.parse(fragment.split('data: ').last.strip)
      @choices = parsed['choices'].map { |choice| GroqStreamChoice.new(choice) }
    end

    def to_s
      @choices.map(&:to_s).join(' ')
    end
  end

  class GroqStreamChoice
    attr_reader :delta, :finish_reason

    def initialize(choice)
      @delta = GroqStreamDelta.new(choice['delta'])
      @finish_reason = choice['finish_reason']
    end

    def to_s
      @delta.to_s
    end
  end

  class GroqStreamDelta
    attr_reader :role, :content

    def initialize(delta)
      @role = delta['role']
      @content = delta['content']
    end

    def to_s
      @content || ''
    end
  end
end
|
175
|
+
end
|
176
|
+
end
|
177
|
+
end
|
@@ -0,0 +1,93 @@
|
|
1
|
+
require 'faraday'
|
2
|
+
require 'json'
|
3
|
+
require 'durable/llm/errors'
|
4
|
+
require 'durable/llm/providers/base'
|
5
|
+
|
6
|
+
module Durable
|
7
|
+
module Llm
|
8
|
+
module Providers
|
9
|
+
# Provider adapter for the Hugging Face Inference API.
class Huggingface < Durable::Llm::Providers::Base
  BASE_URL = 'https://api-inference.huggingface.co/models'

  # Resolve the API key from library configuration, falling back to the
  # HUGGINGFACE_API_KEY environment variable.
  def default_api_key
    Durable::Llm.configuration.huggingface&.api_key || ENV['HUGGINGFACE_API_KEY']
  end

  attr_accessor :api_key

  def initialize(api_key: nil)
    @api_key = api_key || default_api_key
    @conn = Faraday.new(url: BASE_URL) do |faraday|
      faraday.request :json
      faraday.response :json
      faraday.adapter Faraday.default_adapter
    end
  end

  # POST an inference request for the model named in options
  # (:model or 'model' key; defaults to 'gpt2'). The model key is
  # removed from the body before sending.
  #
  # Returns a HuggingfaceResponse; raises a Durable::Llm error on
  # non-2xx statuses (see #handle_response).
  def completion(options)
    model = options.delete(:model) || options.delete('model') || 'gpt2'
    # BUGFIX: a leading "/" makes Faraday treat the path as absolute from
    # the host root, silently dropping the "/models" prefix baked into
    # BASE_URL (the request went to /<model> instead of /models/<model>).
    # Spell out the full absolute path instead.
    response = @conn.post("/models/#{model}") do |req|
      req.headers['Authorization'] = "Bearer #{@api_key}"
      req.body = options
    end

    handle_response(response)
  end

  def models
    self.class.models
  end

  # Static model list placeholder. could use expansion
  def self.models
    ['gpt2', 'bert-base-uncased', 'distilbert-base-uncased']
  end

  private

  # Map HTTP statuses to library errors, or wrap a 2xx body in a
  # HuggingfaceResponse. Hugging Face error bodies carry a top-level
  # 'error' string.
  def handle_response(response)
    case response.status
    when 200..299
      HuggingfaceResponse.new(response.body)
    when 401
      raise Durable::Llm::AuthenticationError, response.body['error']
    when 429
      raise Durable::Llm::RateLimitError, response.body['error']
    when 400..499
      raise Durable::Llm::InvalidRequestError, response.body['error']
    when 500..599
      raise Durable::Llm::ServerError, response.body['error']
    else
      raise Durable::Llm::APIError, "Unexpected response code: #{response.status}"
    end
  end

  # Wraps a parsed inference response (an Array of generation hashes);
  # only the first generation is exposed as a choice.
  class HuggingfaceResponse
    attr_reader :raw_response

    def initialize(response)
      @raw_response = response
    end

    def choices
      [@raw_response.first].map { |choice| HuggingfaceChoice.new(choice) }
    end

    def to_s
      choices.map(&:to_s).join(' ')
    end
  end

  # One generation; exposes its 'generated_text'.
  class HuggingfaceChoice
    attr_reader :text

    def initialize(choice)
      @text = choice['generated_text']
    end

    def to_s
      @text
    end
  end
end
|
91
|
+
end
|
92
|
+
end
|
93
|
+
end
|
@@ -0,0 +1,212 @@
|
|
1
|
+
require 'faraday'
|
2
|
+
require 'json'
|
3
|
+
require 'durable/llm/errors'
|
4
|
+
require 'durable/llm/providers/base'
|
5
|
+
|
6
|
+
module Durable
|
7
|
+
module Llm
|
8
|
+
module Providers
|
9
|
+
# Provider adapter for the OpenAI API (chat completions, embeddings,
# model listing, and SSE streaming).
class OpenAI < Durable::Llm::Providers::Base
  BASE_URL = 'https://api.openai.com/v1'

  # Resolve the API key from library configuration, falling back to the
  # OPENAI_API_KEY environment variable.
  def default_api_key
    Durable::Llm.configuration.openai&.api_key || ENV['OPENAI_API_KEY']
  end

  attr_accessor :api_key, :organization

  def initialize(api_key: nil, organization: nil)
    @api_key = api_key || default_api_key
    @organization = organization || ENV['OPENAI_ORGANIZATION']
    @conn = Faraday.new(url: BASE_URL) do |faraday|
      faraday.request :json
      faraday.response :json
      faraday.adapter Faraday.default_adapter
    end
  end

  # POST a chat completion request.
  #
  # options - Hash request body (model, messages, ...).
  #
  # Returns an OpenAIResponse; raises a Durable::Llm error on non-2xx
  # statuses (see #handle_response).
  def completion(options)
    response = @conn.post('chat/completions') do |req|
      req.headers['Authorization'] = "Bearer #{@api_key}"
      req.headers['OpenAI-Organization'] = @organization if @organization
      req.body = options
    end

    handle_response(response)
  end

  # POST an embeddings request; returns an OpenAIEmbeddingResponse.
  def embedding(model:, input:, **options)
    response = @conn.post('embeddings') do |req|
      req.headers['Authorization'] = "Bearer #{@api_key}"
      req.headers['OpenAI-Organization'] = @organization if @organization
      req.body = { model: model, input: input, **options }
    end

    handle_response(response, OpenAIEmbeddingResponse)
  end

  # Fetch model ids available to this API key.
  # NOTE: performs a network request.
  def models
    response = @conn.get('models') do |req|
      req.headers['Authorization'] = "Bearer #{@api_key}"
      req.headers['OpenAI-Organization'] = @organization if @organization
    end

    handle_response(response).data.map { |model| model['id'] }
  end

  # Class-level variant; instantiates a client with the default key.
  def self.models
    self.new.models
  end

  def self.stream?
    true
  end

  # POST a streaming chat completion; yields an OpenAIStreamResponse
  # per received SSE chunk.
  def stream(options, &block)
    options[:stream] = true
    response = @conn.post('chat/completions') do |req|
      req.headers['Authorization'] = "Bearer #{@api_key}"
      req.headers['OpenAI-Organization'] = @organization if @organization
      req.headers['Accept'] = 'text/event-stream'
      req.body = options
      req.options.on_data = proc do |chunk, _size, _total|
        next if chunk.strip.empty?

        yield OpenAIStreamResponse.new(chunk)
      end
    end

    handle_response(response)
  end

  private

  # Map HTTP statuses to library errors, or wrap a 2xx body in
  # response_class (OpenAIResponse by default).
  def handle_response(response, response_class = OpenAIResponse)
    case response.status
    when 200..299
      response_class.new(response.body)
    when 401
      raise Durable::Llm::AuthenticationError, parse_error_message(response)
    when 429
      raise Durable::Llm::RateLimitError, parse_error_message(response)
    when 400..499
      raise Durable::Llm::InvalidRequestError, parse_error_message(response)
    when 500..599
      raise Durable::Llm::ServerError, parse_error_message(response)
    else
      raise Durable::Llm::APIError, "Unexpected response code: #{response.status}"
    end
  end

  # Build "<status> Error: <message>" from an error response.
  def parse_error_message(response)
    # BUGFIX: the :json response middleware usually hands us an
    # already-parsed Hash; JSON.parse only applies to raw String bodies.
    body =
      if response.body.is_a?(Hash)
        response.body
      else
        JSON.parse(response.body) rescue nil
      end
    message = body&.dig('error', 'message') || response.body
    "#{response.status} Error: #{message}"
  end

  # Wraps a parsed (Hash) completion/listing response body.
  class OpenAIResponse
    attr_reader :raw_response

    def initialize(response)
      @raw_response = response
    end

    def choices
      @raw_response['choices'].map { |choice| OpenAIChoice.new(choice) }
    end

    # Raw 'data' array (used by the model-listing endpoint).
    def data
      @raw_response['data']
    end

    def embedding
      @raw_response['embedding']
    end

    def to_s
      choices.map(&:to_s).join(' ')
    end
  end

  # One entry of a response's "choices" array.
  class OpenAIChoice
    attr_reader :message, :finish_reason

    def initialize(choice)
      @message = OpenAIMessage.new(choice['message'])
      @finish_reason = choice['finish_reason']
    end

    def to_s
      @message.to_s
    end
  end

  class OpenAIMessage
    attr_reader :role, :content

    def initialize(message)
      @role = message['role']
      @content = message['content']
    end

    def to_s
      @content
    end
  end

  # Wraps one SSE chunk of a streamed response. A chunk is one or more
  # lines of the form "data: {json}", terminated by a "data: [DONE]"
  # sentinel on the final chunk.
  class OpenAIStreamResponse
    attr_reader :choices

    def initialize(fragment)
      # BUGFIX: the raw lines carry a "data: " prefix and may include the
      # "[DONE]" sentinel or blank separators; JSON.parse on them raised
      # JSON::ParserError. Strip the prefix and skip non-JSON lines.
      parsed = fragment.split("\n")
                       .map { |line| line.sub(/\Adata:\s*/, '').strip }
                       .reject { |line| line.empty? || line == '[DONE]' }
                       .map { |line| JSON.parse(line) }

      @choices = parsed.map { |event| OpenAIStreamChoice.new(event['choices']) }
    end

    def to_s
      @choices.map(&:to_s).join('')
    end
  end

  # Wraps an embeddings response; exposes the first embedding vector.
  class OpenAIEmbeddingResponse
    attr_reader :embedding

    def initialize(data)
      @embedding = data.dig('data', 0, 'embedding')
    end

    def to_a
      @embedding
    end
  end

  # One streamed choice; accepts either a choice hash or the raw
  # one-element 'choices' array.
  class OpenAIStreamChoice
    attr_reader :delta, :finish_reason

    def initialize(choice)
      @choice = [choice].flatten.first
      @delta = OpenAIStreamDelta.new(@choice['delta'])
      @finish_reason = @choice['finish_reason']
    end

    def to_s
      @delta.to_s
    end
  end

  class OpenAIStreamDelta
    attr_reader :role, :content

    def initialize(delta)
      @role = delta['role']
      @content = delta['content']
    end

    def to_s
      @content || ''
    end
  end
end
|
210
|
+
end
|
211
|
+
end
|
212
|
+
end
|