llm.rb 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/LICENSE.txt +21 -0
- data/README.md +146 -0
- data/lib/llm/conversation.rb +38 -0
- data/lib/llm/core_ext/ostruct.rb +37 -0
- data/lib/llm/error.rb +28 -0
- data/lib/llm/file.rb +66 -0
- data/lib/llm/http_client.rb +29 -0
- data/lib/llm/lazy_conversation.rb +39 -0
- data/lib/llm/message.rb +55 -0
- data/lib/llm/message_queue.rb +47 -0
- data/lib/llm/provider.rb +114 -0
- data/lib/llm/providers/anthropic/error_handler.rb +32 -0
- data/lib/llm/providers/anthropic/format.rb +31 -0
- data/lib/llm/providers/anthropic/response_parser.rb +29 -0
- data/lib/llm/providers/anthropic.rb +63 -0
- data/lib/llm/providers/gemini/error_handler.rb +43 -0
- data/lib/llm/providers/gemini/format.rb +31 -0
- data/lib/llm/providers/gemini/response_parser.rb +31 -0
- data/lib/llm/providers/gemini.rb +64 -0
- data/lib/llm/providers/ollama/error_handler.rb +32 -0
- data/lib/llm/providers/ollama/format.rb +28 -0
- data/lib/llm/providers/ollama/response_parser.rb +18 -0
- data/lib/llm/providers/ollama.rb +51 -0
- data/lib/llm/providers/openai/error_handler.rb +32 -0
- data/lib/llm/providers/openai/format.rb +28 -0
- data/lib/llm/providers/openai/response_parser.rb +35 -0
- data/lib/llm/providers/openai.rb +62 -0
- data/lib/llm/response/completion.rb +50 -0
- data/lib/llm/response/embedding.rb +23 -0
- data/lib/llm/response.rb +24 -0
- data/lib/llm/version.rb +5 -0
- data/lib/llm.rb +47 -0
- data/llm.gemspec +40 -0
- data/spec/anthropic/completion_spec.rb +76 -0
- data/spec/gemini/completion_spec.rb +80 -0
- data/spec/gemini/embedding_spec.rb +33 -0
- data/spec/llm/conversation_spec.rb +56 -0
- data/spec/llm/lazy_conversation_spec.rb +110 -0
- data/spec/ollama/completion_spec.rb +52 -0
- data/spec/ollama/embedding_spec.rb +15 -0
- data/spec/openai/completion_spec.rb +99 -0
- data/spec/openai/embedding_spec.rb +33 -0
- data/spec/readme_spec.rb +64 -0
- data/spec/setup.rb +29 -0
- metadata +194 -0
data/lib/llm/providers/anthropic/response_parser.rb
ADDED
@@ -0,0 +1,29 @@
+# frozen_string_literal: true
+
+class LLM::Anthropic
+  module ResponseParser
+    def parse_embedding(body)
+      {
+        model: body["model"],
+        embeddings: body["data"].map { _1["embedding"] },
+        total_tokens: body.dig("usage", "total_tokens")
+      }
+    end
+
+    ##
+    # @param [Hash] body
+    #  The response body from the LLM provider
+    # @return [Hash]
+    def parse_completion(body)
+      {
+        model: body["model"],
+        choices: body["content"].map do
+          # TODO: don't hardcode role
+          LLM::Message.new("assistant", _1["text"], {completion: self})
+        end,
+        prompt_tokens: body.dig("usage", "input_tokens"),
+        completion_tokens: body.dig("usage", "output_tokens")
+      }
+    end
+  end
+end

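As an illustration of the mapping above, here is a minimal sketch that feeds a hand-written Anthropic-style response body through the parser. It assumes `require "llm"` loads the provider classes and that LLM::Message (defined elsewhere in this release) exposes #role and #content readers for the values passed to its constructor.

    require "llm"

    # Hand-written body shaped like Anthropic's /v1/messages response.
    body = {
      "model" => "claude-3-5-sonnet-20240620",
      "content" => [{"type" => "text", "text" => "Hello!"}],
      "usage" => {"input_tokens" => 9, "output_tokens" => 3}
    }
    parser = Object.new.extend(LLM::Anthropic::ResponseParser)
    parsed = parser.parse_completion(body)
    parsed[:choices].map(&:content)   # => ["Hello!"]
    parsed[:prompt_tokens]            # => 9
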
data/lib/llm/providers/anthropic.rb
ADDED
@@ -0,0 +1,63 @@
+# frozen_string_literal: true
+
+module LLM
+  ##
+  # The Anthropic class implements a provider for
+  # [Anthropic](https://www.anthropic.com)
+  class Anthropic < Provider
+    require_relative "anthropic/error_handler"
+    require_relative "anthropic/response_parser"
+    require_relative "anthropic/format"
+    include Format
+
+    HOST = "api.anthropic.com"
+    DEFAULT_PARAMS = {max_tokens: 1024, model: "claude-3-5-sonnet-20240620"}.freeze
+
+    ##
+    # @param secret (see LLM::Provider#initialize)
+    def initialize(secret, **)
+      super(secret, host: HOST, **)
+    end
+
+    ##
+    # @param input (see LLM::Provider#embed)
+    # @return (see LLM::Provider#embed)
+    def embed(input, **params)
+      req = Net::HTTP::Post.new ["api.voyageai.com/v1", "embeddings"].join("/")
+      body = {input:, model: "voyage-2"}.merge!(params)
+      req = preflight(req, body)
+      res = request(@http, req)
+      Response::Embedding.new(res).extend(response_parser)
+    end
+
+    ##
+    # @see https://docs.anthropic.com/en/api/messages Anthropic docs
+    # @param prompt (see LLM::Provider#complete)
+    # @param role (see LLM::Provider#complete)
+    # @return (see LLM::Provider#complete)
+    def complete(prompt, role = :user, **params)
+      req = Net::HTTP::Post.new ["/v1", "messages"].join("/")
+      messages = [*(params.delete(:messages) || []), Message.new(role, prompt)]
+      params = DEFAULT_PARAMS.merge(params)
+      body = {messages: format(messages)}.merge!(params)
+      req = preflight(req, body)
+      res = request(@http, req)
+      Response::Completion.new(res).extend(response_parser)
+    end
+
+    private
+
+    def auth(req)
+      req["anthropic-version"] = "2023-06-01"
+      req["x-api-key"] = @secret
+    end
+
+    def response_parser
+      LLM::Anthropic::ResponseParser
+    end
+
+    def error_handler
+      LLM::Anthropic::ErrorHandler
+    end
+  end
+end

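A hedged usage sketch of the provider above. It assumes LLM::Provider#initialize (not part of this excerpt) takes the API key as its first argument and opens the HTTP connection to HOST, and that ANTHROPIC_API_KEY holds a valid key; the completion request goes to /v1/messages with DEFAULT_PARAMS merged in.

    require "llm"

    llm = LLM::Anthropic.new(ENV["ANTHROPIC_API_KEY"])
    res = llm.complete("Hello, Claude", :user, max_tokens: 256)
    res.choices.each { |message| puts "#{message.role}: #{message.content}" }
    puts res.total_tokens   # prompt_tokens + completion_tokens
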
data/lib/llm/providers/gemini/error_handler.rb
ADDED
@@ -0,0 +1,43 @@
+# frozen_string_literal: true
+
+class LLM::Gemini
+  class ErrorHandler
+    ##
+    # @return [Net::HTTPResponse]
+    #  Non-2XX response from the server
+    attr_reader :res
+
+    ##
+    # @param [Net::HTTPResponse] res
+    #  The response from the server
+    # @return [LLM::Gemini::ErrorHandler]
+    def initialize(res)
+      @res = res
+    end
+
+    ##
+    # @raise [LLM::Error]
+    #  Raises a subclass of {LLM::Error LLM::Error}
+    def raise_error!
+      case res
+      when Net::HTTPBadRequest
+        reason = body.dig("error", "details", 0, "reason")
+        if reason == "API_KEY_INVALID"
+          raise LLM::Error::Unauthorized.new { _1.response = res }, "Authentication error"
+        else
+          raise LLM::Error::BadResponse.new { _1.response = res }, "Unexpected response"
+        end
+      when Net::HTTPTooManyRequests
+        raise LLM::Error::RateLimit.new { _1.response = res }, "Too many requests"
+      else
+        raise LLM::Error::BadResponse.new { _1.response = res }, "Unexpected response"
+      end
+    end
+
+    private
+
+    def body
+      @body ||= JSON.parse(res.body)
+    end
+  end
+end

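A sketch of how a caller might rescue these errors. It assumes the provider's request helper (LLM::Provider#request, not shown in this diff) invokes raise_error! for non-2XX responses, and that the LLM::Error subclasses expose the response assigned via `_1.response = res` through a #response reader.

    begin
      llm.complete("Hello")
    rescue LLM::Error::Unauthorized => e
      warn "invalid API key (HTTP #{e.response.code})"
    rescue LLM::Error::RateLimit
      sleep 1
      retry
    rescue LLM::Error::BadResponse => e
      warn "unexpected response: #{e.message}"
    end
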
data/lib/llm/providers/gemini/format.rb
ADDED
@@ -0,0 +1,31 @@
+# frozen_string_literal: true
+
+class LLM::Gemini
+  module Format
+    ##
+    # @param [Array<LLM::Message>] messages
+    #  The messages to format
+    # @return [Array<Hash>]
+    def format(messages)
+      messages.map { {role: _1.role, parts: [format_content(_1.content)]} }
+    end
+
+    private
+
+    ##
+    # @param [String, LLM::File] content
+    #  The content to format
+    # @return [String, Hash]
+    #  The formatted content
+    def format_content(content)
+      if LLM::File === content
+        file = content
+        {
+          inline_data: {mime_type: file.mime_type, data: [File.binread(file.path)].pack("m0")}
+        }
+      else
+        {text: content}
+      end
+    end
+  end
+end

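For reference, a small sketch of the payload this produces for a plain-text message (the LLM::File branch instead emits an inline_data part with the file base64-encoded via Array#pack("m0")). It assumes LLM::Message#role and #content simply return the constructor arguments.

    require "llm"

    formatter = Object.new.extend(LLM::Gemini::Format)
    formatter.format([LLM::Message.new(:user, "What is the capital of France?")])
    # => [{role: :user, parts: [{text: "What is the capital of France?"}]}]
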
data/lib/llm/providers/gemini/response_parser.rb
ADDED
@@ -0,0 +1,31 @@
+# frozen_string_literal: true
+
+class LLM::Gemini
+  module ResponseParser
+    def parse_embedding(body)
+      {
+        model: "text-embedding-004",
+        embeddings: body.dig("embedding", "values")
+      }
+    end
+
+    ##
+    # @param [Hash] body
+    #  The response body from the LLM provider
+    # @return [Hash]
+    def parse_completion(body)
+      {
+        model: body["modelVersion"],
+        choices: body["candidates"].map do
+          LLM::Message.new(
+            _1.dig("content", "role"),
+            _1.dig("content", "parts", 0, "text"),
+            {completion: self}
+          )
+        end,
+        prompt_tokens: body.dig("usageMetadata", "promptTokenCount"),
+        completion_tokens: body.dig("usageMetadata", "candidatesTokenCount")
+      }
+    end
+  end
+end

data/lib/llm/providers/gemini.rb
ADDED
@@ -0,0 +1,64 @@
+# frozen_string_literal: true
+
+module LLM
+  ##
+  # The Gemini class implements a provider for
+  # [Gemini](https://ai.google.dev/)
+  class Gemini < Provider
+    require_relative "gemini/error_handler"
+    require_relative "gemini/response_parser"
+    require_relative "gemini/format"
+    include Format
+
+    HOST = "generativelanguage.googleapis.com"
+    DEFAULT_PARAMS = {model: "gemini-1.5-flash"}.freeze
+
+    ##
+    # @param secret (see LLM::Provider#initialize)
+    def initialize(secret, **)
+      super(secret, host: HOST, **)
+    end
+
+    ##
+    # @param input (see LLM::Provider#embed)
+    # @return (see LLM::Provider#embed)
+    def embed(input, **params)
+      path = ["/v1beta/models", "text-embedding-004"].join("/")
+      req = Net::HTTP::Post.new [path, "embedContent"].join(":")
+      body = {content: {parts: [{text: input}]}}
+      req = preflight(req, body)
+      res = request @http, req
+      Response::Embedding.new(res).extend(response_parser)
+    end
+
+    ##
+    # @see https://ai.google.dev/api/generate-content#v1beta.models.generateContent Gemini docs
+    # @param prompt (see LLM::Provider#complete)
+    # @param role (see LLM::Provider#complete)
+    # @return (see LLM::Provider#complete)
+    def complete(prompt, role = :user, **params)
+      params = DEFAULT_PARAMS.merge(params)
+      path = ["/v1beta/models", params.delete(:model)].join("/")
+      req = Net::HTTP::Post.new [path, "generateContent"].join(":")
+      messages = [*(params.delete(:messages) || []), LLM::Message.new(role, prompt)]
+      body = {contents: format(messages)}
+      req = preflight(req, body)
+      res = request(@http, req)
+      Response::Completion.new(res).extend(response_parser)
+    end
+
+    private
+
+    def auth(req)
+      req.path.replace [req.path, URI.encode_www_form(key: @secret)].join("?")
+    end
+
+    def response_parser
+      LLM::Gemini::ResponseParser
+    end
+
+    def error_handler
+      LLM::Gemini::ErrorHandler
+    end
+  end
+end

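A hedged usage sketch for the Gemini provider; note that auth appends the key as a ?key=... query parameter rather than sending a header. It assumes LLM::Provider#initialize accepts the key as its first argument and that GEMINI_API_KEY is set.

    require "llm"

    llm = LLM::Gemini.new(ENV["GEMINI_API_KEY"])
    res = llm.complete("Summarize RFC 2119 in one sentence")
    puts res.choices[0].content

    emb = llm.embed("hello world")
    p emb.embeddings.length   # dimension count returned by text-embedding-004
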
data/lib/llm/providers/ollama/error_handler.rb
ADDED
@@ -0,0 +1,32 @@
+# frozen_string_literal: true
+
+class LLM::Ollama
+  class ErrorHandler
+    ##
+    # @return [Net::HTTPResponse]
+    #  Non-2XX response from the server
+    attr_reader :res
+
+    ##
+    # @param [Net::HTTPResponse] res
+    #  The response from the server
+    # @return [LLM::OpenAI::ErrorHandler]
+    def initialize(res)
+      @res = res
+    end
+
+    ##
+    # @raise [LLM::Error]
+    #  Raises a subclass of {LLM::Error LLM::Error}
+    def raise_error!
+      case res
+      when Net::HTTPUnauthorized
+        raise LLM::Error::Unauthorized.new { _1.response = res }, "Authentication error"
+      when Net::HTTPTooManyRequests
+        raise LLM::Error::RateLimit.new { _1.response = res }, "Too many requests"
+      else
+        raise LLM::Error::BadResponse.new { _1.response = res }, "Unexpected response"
+      end
+    end
+  end
+end

data/lib/llm/providers/ollama/format.rb
ADDED
@@ -0,0 +1,28 @@
+# frozen_string_literal: true
+
+class LLM::Ollama
+  module Format
+    ##
+    # @param [Array<LLM::Message>] messages
+    #  The messages to format
+    # @return [Array<Hash>]
+    def format(messages)
+      messages.map { {role: _1.role, content: format_content(_1.content)} }
+    end
+
+    private
+
+    ##
+    # @param [String, URI] content
+    #  The content to format
+    # @return [String, Hash]
+    #  The formatted content
+    def format_content(content)
+      if URI === content
+        [{type: :image_url, image_url: {url: content.to_s}}]
+      else
+        content
+      end
+    end
+  end
+end

data/lib/llm/providers/ollama/response_parser.rb
ADDED
@@ -0,0 +1,18 @@
+# frozen_string_literal: true
+
+class LLM::Ollama
+  module ResponseParser
+    ##
+    # @param [Hash] body
+    #  The response body from the LLM provider
+    # @return [Hash]
+    def parse_completion(body)
+      {
+        model: body["model"],
+        choices: [LLM::Message.new(*body["message"].values_at("role", "content"), {completion: self})],
+        prompt_tokens: body.dig("prompt_eval_count"),
+        completion_tokens: body.dig("eval_count")
+      }
+    end
+  end
+end

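A minimal sketch of the parser against a hand-written body shaped like Ollama's /api/chat response (with stream: false), under the same assumption that LLM::Message exposes #role and #content:

    require "llm"

    body = {
      "model" => "llama3.2",
      "message" => {"role" => "assistant", "content" => "Hi there!"},
      "prompt_eval_count" => 12,
      "eval_count" => 4
    }
    parser = Object.new.extend(LLM::Ollama::ResponseParser)
    parsed = parser.parse_completion(body)
    parsed[:choices][0].content   # => "Hi there!"
    parsed[:completion_tokens]    # => 4
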
data/lib/llm/providers/ollama.rb
ADDED
@@ -0,0 +1,51 @@
+# frozen_string_literal: true
+
+module LLM
+  ##
+  # The Ollama class implements a provider for
+  # [Ollama](https://ollama.ai/)
+  class Ollama < Provider
+    require_relative "ollama/error_handler"
+    require_relative "ollama/response_parser"
+    require_relative "ollama/format"
+    include Format
+
+    HOST = "localhost"
+    DEFAULT_PARAMS = {model: "llama3.2", stream: false}.freeze
+
+    ##
+    # @param secret (see LLM::Provider#initialize)
+    def initialize(secret, **)
+      super(secret, host: HOST, port: 11434, ssl: false, **)
+    end
+
+    ##
+    # @see https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion Ollama docs
+    # @param prompt (see LLM::Provider#complete)
+    # @param role (see LLM::Provider#complete)
+    # @return (see LLM::Provider#complete)
+    def complete(prompt, role = :user, **params)
+      req = Net::HTTP::Post.new ["/api", "chat"].join("/")
+      messages = [*(params.delete(:messages) || []), LLM::Message.new(role, prompt)]
+      params = DEFAULT_PARAMS.merge(params)
+      body = {messages: messages.map(&:to_h)}.merge!(params)
+      req = preflight(req, body)
+      res = request(@http, req)
+      Response::Completion.new(res).extend(response_parser)
+    end
+
+    private
+
+    def auth(req)
+      req["Authorization"] = "Bearer #{@secret}"
+    end
+
+    def response_parser
+      LLM::Ollama::ResponseParser
+    end
+
+    def error_handler
+      LLM::Ollama::ErrorHandler
+    end
+  end
+end

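A hedged usage sketch against a local Ollama server (the defaults above point at http://localhost:11434 with SSL disabled). The secret is still sent as a Bearer token, but a local server typically ignores it; the llama3.2 model is assumed to be pulled already.

    require "llm"

    llm = LLM::Ollama.new("")
    res = llm.complete("Why is the sky blue?", :user, model: "llama3.2")
    puts res.choices[0].content
    puts res.total_tokens
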
data/lib/llm/providers/openai/error_handler.rb
ADDED
@@ -0,0 +1,32 @@
+# frozen_string_literal: true
+
+class LLM::OpenAI
+  class ErrorHandler
+    ##
+    # @return [Net::HTTPResponse]
+    #  Non-2XX response from the server
+    attr_reader :res
+
+    ##
+    # @param [Net::HTTPResponse] res
+    #  The response from the server
+    # @return [LLM::OpenAI::ErrorHandler]
+    def initialize(res)
+      @res = res
+    end
+
+    ##
+    # @raise [LLM::Error]
+    #  Raises a subclass of {LLM::Error LLM::Error}
+    def raise_error!
+      case res
+      when Net::HTTPUnauthorized
+        raise LLM::Error::Unauthorized.new { _1.response = res }, "Authentication error"
+      when Net::HTTPTooManyRequests
+        raise LLM::Error::RateLimit.new { _1.response = res }, "Too many requests"
+      else
+        raise LLM::Error::BadResponse.new { _1.response = res }, "Unexpected response"
+      end
+    end
+  end
+end

data/lib/llm/providers/openai/format.rb
ADDED
@@ -0,0 +1,28 @@
+# frozen_string_literal: true
+
+class LLM::OpenAI
+  module Format
+    ##
+    # @param [Array<LLM::Message>] messages
+    #  The messages to format
+    # @return [Array<Hash>]
+    def format(messages)
+      messages.map { {role: _1.role, content: format_content(_1.content)} }
+    end
+
+    private
+
+    ##
+    # @param [String, URI] content
+    #  The content to format
+    # @return [String, Hash]
+    #  The formatted content
+    def format_content(content)
+      if URI === content
+        [{type: :image_url, image_url: {url: content.to_s}}]
+      else
+        content
+      end
+    end
+  end
+end

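A small sketch of the two branches above: a plain string passes through unchanged, while a URI becomes an image_url content part. Same assumption about LLM::Message as before.

    require "llm"
    require "uri"

    formatter = Object.new.extend(LLM::OpenAI::Format)
    formatter.format([LLM::Message.new(:user, URI("https://example.com/cat.png"))])
    # => [{role: :user, content: [{type: :image_url, image_url: {url: "https://example.com/cat.png"}}]}]
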
data/lib/llm/providers/openai/response_parser.rb
ADDED
@@ -0,0 +1,35 @@
+# frozen_string_literal: true
+
+class LLM::OpenAI
+  module ResponseParser
+    def parse_embedding(body)
+      {
+        model: body["model"],
+        embeddings: body.dig("data").map do |data|
+          data["embedding"]
+        end,
+        prompt_tokens: body.dig("usage", "prompt_tokens"),
+        total_tokens: body.dig("usage", "total_tokens")
+      }
+    end
+
+    ##
+    # @param [Hash] body
+    #  The response body from the LLM provider
+    # @return [Hash]
+    def parse_completion(body)
+      {
+        model: body["model"],
+        choices: body["choices"].map do
+          mesg = _1["message"]
+          logprobs = _1["logprobs"]
+          role, content = mesg.values_at("role", "content")
+          LLM::Message.new(role, content, {completion: self, logprobs:})
+        end,
+        prompt_tokens: body.dig("usage", "prompt_tokens"),
+        completion_tokens: body.dig("usage", "completion_tokens"),
+        total_tokens: body.dig("usage", "total_tokens")
+      }
+    end
+  end
+end

data/lib/llm/providers/openai.rb
ADDED
@@ -0,0 +1,62 @@
+# frozen_string_literal: true
+
+module LLM
+  ##
+  # The OpenAI class implements a provider for
+  # [OpenAI](https://platform.openai.com/)
+  class OpenAI < Provider
+    require_relative "openai/error_handler"
+    require_relative "openai/response_parser"
+    require_relative "openai/format"
+    include Format
+
+    HOST = "api.openai.com"
+    DEFAULT_PARAMS = {model: "gpt-4o-mini"}.freeze
+
+    ##
+    # @param secret (see LLM::Provider#initialize)
+    def initialize(secret, **)
+      super(secret, host: HOST, **)
+    end
+
+    ##
+    # @param input (see LLM::Provider#embed)
+    # @return (see LLM::Provider#embed)
+    def embed(input, **params)
+      req = Net::HTTP::Post.new ["/v1", "embeddings"].join("/")
+      body = {input:, model: "text-embedding-3-small"}.merge!(params)
+      req = preflight(req, body)
+      res = request @http, req
+      Response::Embedding.new(res).extend(response_parser)
+    end
+
+    ##
+    # @see https://platform.openai.com/docs/api-reference/chat/create OpenAI docs
+    # @param prompt (see LLM::Provider#complete)
+    # @param role (see LLM::Provider#complete)
+    # @return (see LLM::Provider#complete)
+    def complete(prompt, role = :user, **params)
+      req = Net::HTTP::Post.new ["/v1", "chat", "completions"].join("/")
+      messages = [*(params.delete(:messages) || []), Message.new(role, prompt)]
+      params = DEFAULT_PARAMS.merge(params)
+      body = {messages: format(messages)}.merge!(params)
+      req = preflight(req, body)
+      res = request(@http, req)
+      Response::Completion.new(res).extend(response_parser)
+    end
+
+    private
+
+    def auth(req)
+      req["Authorization"] = "Bearer #{@secret}"
+    end
+
+    def response_parser
+      LLM::OpenAI::ResponseParser
+    end
+
+    def error_handler
+      LLM::OpenAI::ErrorHandler
+    end
+  end
+end

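A hedged usage sketch of the OpenAI provider, under the same assumption that LLM::Provider#initialize takes the API key first and with OPENAI_API_KEY set:

    require "llm"

    llm = LLM::OpenAI.new(ENV["OPENAI_API_KEY"])
    res = llm.complete("Hello!", :user, model: "gpt-4o-mini")
    puts res.choices[0].content

    emb = llm.embed("Ruby is a programmer's best friend")
    p [emb.model, emb.embeddings[0].length, emb.total_tokens]
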
data/lib/llm/response/completion.rb
ADDED
@@ -0,0 +1,50 @@
+# frozen_string_literal: true
+
+module LLM
+  class Response::Completion < Response
+    ##
+    # @return [String]
+    #  Returns the model name used for the completion
+    def model
+      parsed[:model]
+    end
+
+    ##
+    # @return [Array<LLM::Message>]
+    #  Returns an array of messages
+    def choices
+      parsed[:choices]
+    end
+
+    ##
+    # @return [Integer]
+    #  Returns the count of prompt tokens
+    def prompt_tokens
+      parsed[:prompt_tokens]
+    end
+
+    ##
+    # @return [Integer]
+    #  Returns the count of completion tokens
+    def completion_tokens
+      parsed[:completion_tokens]
+    end
+
+    ##
+    # @return [Integer]
+    #  Returns the total count of tokens
+    def total_tokens
+      prompt_tokens + completion_tokens
+    end
+
+    private
+
+    ##
+    # @private
+    # @return [Hash]
+    #  Returns the parsed completion response from the provider
+    def parsed
+      @parsed ||= parse_completion(body)
+    end
+  end
+end

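Every reader above goes through parsed, which memoizes the hash returned by parse_completion; that method is supplied when a provider extends its ResponseParser module onto the instance. A hypothetical offline sketch with a stubbed HTTP response (FakeResponse is an assumption standing in for Net::HTTPResponse) illustrates the wiring:

    require "llm"
    require "json"

    FakeResponse = Struct.new(:body)   # stands in for Net::HTTPResponse
    raw = JSON.dump(
      "model" => "gpt-4o-mini",
      "choices" => [{"message" => {"role" => "assistant", "content" => "Hi"}}],
      "usage" => {"prompt_tokens" => 5, "completion_tokens" => 2}
    )
    completion = LLM::Response::Completion.new(FakeResponse.new(raw))
    completion.extend(LLM::OpenAI::ResponseParser)
    completion.total_tokens   # => 7
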
data/lib/llm/response/embedding.rb
ADDED
@@ -0,0 +1,23 @@
+# frozen_string_literal: true
+
+module LLM
+  class Response::Embedding < Response
+    def model
+      parsed[:model]
+    end
+
+    def embeddings
+      parsed[:embeddings]
+    end
+
+    def total_tokens
+      parsed[:total_tokens]
+    end
+
+    private
+
+    def parsed
+      @parsed ||= parse_embedding(body)
+    end
+  end
+end

data/lib/llm/response.rb
ADDED
@@ -0,0 +1,24 @@
+# frozen_string_literal: true
+
+module LLM
+  class Response
+    require "json"
+    require_relative "response/completion"
+    require_relative "response/embedding"
+
+    ##
+    # @return [Hash]
+    #  Returns the response body
+    attr_reader :body
+
+    ##
+    # @param [Net::HTTPResponse] res
+    #  HTTP response
+    # @return [LLM::Response]
+    #  Returns an instance of LLM::Response
+    def initialize(res)
+      @res = res
+      @body = JSON.parse(res.body)
+    end
+  end
+end