inst_llm 0.1.0 → 0.2.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/lib/inst_llm/client.rb +126 -0
- data/lib/inst_llm/parameter/all.rb +5 -0
- data/lib/inst_llm/parameter/claude_parameters.rb +27 -0
- data/lib/inst_llm/parameter/cohere_embed_parameters.rb +22 -0
- data/lib/inst_llm/parameter/mistral_parameters.rb +51 -0
- data/lib/inst_llm/response/all.rb +4 -0
- data/lib/inst_llm/response/chat_response.rb +46 -0
- data/lib/inst_llm/response/embedding_response.rb +23 -0
- data/lib/inst_llm/version.rb +5 -0
- data/lib/inst_llm.rb +13 -0
- metadata +13 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 9dd355405ae5b7659a6cee01536e9ca422a0658208e125303a6206e4e6129863
|
4
|
+
data.tar.gz: 8e0955a5bf08923385e093a7a6d739ff224695a345d52684d2c435675e3aa4c9
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 787201e52ae44238a3a22d53a93f83fd15718b251a49b46fded527aab1e30eda9fe805574c1f42c0369a5586679a9d8f27c0cae5bb689c47c2c8fb65355e81b3
|
7
|
+
data.tar.gz: 14ea0088107fc9be654483530f314b5f00ea4d72686326b0d6c1349470446de8b8731fe06313081a1b78c78d6b9f2d8bff125752e60183b2dc3978626ffaf83a
|
@@ -0,0 +1,126 @@
|
|
1
|
+
# frozen_string_literal: true

require "aws-sdk-bedrockruntime"
require "json"

require_relative "parameter/all"
require_relative "response/all"

module InstLLM
  # Entry point for invoking chat and embedding models through a provider
  # SDK (currently only AWS Bedrock).
  class Client
    # Supported models. :format picks the Parameter/Response classes,
    # :provider picks the SDK client, :type gates chat vs. embedding use.
    MODELS = {
      "anthropic.claude-3-sonnet-20240229-v1:0": { format: :claude, provider: :bedrock, type: :chat },
      "anthropic.claude-3-haiku-20240307-v1:0": { format: :claude, provider: :bedrock, type: :chat },

      "mistral.mistral-7b-instruct-v0:2": { format: :mistral, provider: :bedrock, type: :chat },
      "mistral.mixtral-8x7b-instruct-v0:1": { format: :mistral, provider: :bedrock, type: :chat },
      "mistral.mistral-large-2402-v1:0": { format: :mistral, provider: :bedrock, type: :chat },

      "cohere.embed-english-v3": { format: :cohere_embed, provider: :bedrock, type: :embedding },
      "cohere.embed-multilingual-v3": { format: :cohere_embed, provider: :bedrock, type: :embedding },
    }.freeze

    # @param model [String, Symbol] a key of MODELS
    # @param options [Hash] forwarded to the provider SDK client constructor
    # @raise [UnknownArgumentError] for a model not listed in MODELS
    def initialize(model, **options)
      model = model.to_sym
      raise UnknownArgumentError unless MODELS.key?(model)

      @model = model
      @options = options
    end

    # Invokes a chat model.
    #
    # @param messages [Array<Hash>] chat messages ({role:, content:})
    # @param options [Hash] sampling options, plus optional :model/:model_id
    #   to override the default model for this call
    # @return [Response::ChatResponse]
    # @raise [ArgumentError] if the resolved model is not a chat model
    def chat(messages, **options)
      model = (options[:model] || options[:model_id] || @model).to_sym
      raise ArgumentError, "Model #{model} is not a chat model" unless chat_model?(model)

      response_factory(model, call(model, messages, **options))
    end

    # Invokes an embedding model.
    #
    # @param message [Array<String>] texts to embed
    # @param options [Hash] provider options, plus optional :model/:model_id
    # @return [Response::EmbeddingResponse]
    # @raise [ArgumentError] if the resolved model is not an embedding model
    def embedding(message, **options)
      model = (options[:model] || options[:model_id] || @model).to_sym
      raise ArgumentError, "Model #{model} is not an embedding model" unless embedding_model?(model)

      embedding_response_factory(model, call(model, message, **options))
    end

    private

    # Builds the request parameters, invokes the model, and parses the JSON
    # response body. Provider SDK errors are re-raised as InstLLM errors.
    def call(model, messages, **options)
      params = params_factory(model, messages, **options)

      begin
        # `**params` implicitly calls the parameter object's #to_hash.
        res = client.invoke_model(
          content_type: "application/json",
          **params
        )
      rescue => error
        raise map_error_type(error)
      end

      JSON.parse(res.body.read)
    end

    def chat_model?(model)
      MODELS[model][:type] == :chat
    end

    def embedding_model?(model)
      MODELS[model][:type] == :embedding
    end

    # Lazily builds and memoizes the provider SDK client.
    # NOTE(review): this keys off @model even when a per-call :model
    # override was given, so both models must share a provider — confirm.
    def client
      return @client if @client

      case MODELS[@model][:provider]
      when :bedrock
        @client = Aws::BedrockRuntime::Client.new(**@options)
      else
        raise UnknownArgumentError
      end

      @client
    end

    # Translates provider SDK exceptions into this gem's error classes.
    # Returns an error *instance* carrying the original message.
    def map_error_type(error)
      mapped_error_type = nil

      case MODELS[@model][:provider]
      when :bedrock
        case error
        when Aws::BedrockRuntime::Errors::ServiceQuotaExceededException
          mapped_error_type = ServiceQuotaExceededError
        when Aws::BedrockRuntime::Errors::ThrottlingException
          mapped_error_type = ThrottlingError
        when Aws::BedrockRuntime::Errors::ValidationException
          # Bedrock signals context-length overflow via the message text.
          if error.message.include?("too long")
            mapped_error_type = ValidationTooLongError
          else
            mapped_error_type = ValidationError
          end
        else
          mapped_error_type = Error
        end
      else
        raise UnknownArgumentError
      end

      mapped_error_type.new(error.message)
    end

    # Builds the provider-specific parameter object for the request.
    #
    # BUG FIX: embedding parameter classes take their input as `texts:`
    # while chat classes take `messages:`. Previously everything was passed
    # as `messages:`, which CohereEmbedParameters swallowed into **options
    # and discarded, so `embedding(message)` silently embedded nothing.
    # Callers that worked around this by passing `texts:` in options still
    # win, because the trailing **options overrides the positional input.
    def params_factory(model, messages, **options)
      params_table = {
        claude: Parameter::ClaudeParameters,
        cohere_embed: Parameter::CohereEmbedParameters,
        mistral: Parameter::MistralParameters
      }
      input_key = MODELS[model][:type] == :embedding ? :texts : :messages
      params_table[MODELS[model][:format]].new(model: model, input_key => messages, **options)
    end

    def embedding_response_factory(model, response)
      Response::EmbeddingResponse.send(:"from_#{MODELS[model][:format]}", model: model, response: response)
    end

    def response_factory(model, response)
      Response::ChatResponse.send(:"from_#{MODELS[model][:format]}", model: model, response: response)
    end
  end
end
|
@@ -0,0 +1,27 @@
|
|
1
|
+
# frozen_string_literal: true

module InstLLM
  module Parameter
    # Builds the Bedrock `invoke_model` arguments for Anthropic Claude
    # chat models: a model id plus a JSON request body.
    class ClaudeParameters
      # Recognized request options; entries left nil are dropped from the body.
      DEFAULT_OPTIONS = {
        anthropic_version: "bedrock-2023-05-31",
        max_tokens: 2000,
        stop_sequences: nil,
        temperature: nil,
        top_k: nil,
        top_p: nil,
        system: nil
      }.freeze

      # @param model [String, Symbol] Bedrock model id
      # @param messages [Array<Hash>] chat messages ({role:, content:})
      # @param options [Hash] only keys present in DEFAULT_OPTIONS are kept
      def initialize(model:, messages: [], **options)
        @model = model
        @messages = messages
        recognized = options.slice(*DEFAULT_OPTIONS.keys)
        @options = DEFAULT_OPTIONS.merge(recognized).compact
      end

      # @return [Hash] arguments for Aws::BedrockRuntime::Client#invoke_model
      def to_hash
        body = { messages: @messages }.merge(@options)
        { model_id: @model, body: body.to_json }
      end
    end
  end
end
|
@@ -0,0 +1,22 @@
|
|
1
|
+
# frozen_string_literal: true

module InstLLM
  module Parameter
    # Builds the Bedrock `invoke_model` arguments for Cohere embedding
    # models: a model id plus a JSON request body of texts to embed.
    class CohereEmbedParameters
      # Recognized request options; entries left nil are dropped from the body.
      DEFAULT_OPTIONS = {
        input_type: nil,
        truncate: nil
      }.freeze

      # @param model [String, Symbol] Bedrock model id
      # @param texts [Array<String>] texts to embed
      # @param options [Hash] only keys present in DEFAULT_OPTIONS are kept
      def initialize(model:, texts: [], **options)
        @model = model
        @texts = texts
        recognized = options.slice(*DEFAULT_OPTIONS.keys)
        @options = DEFAULT_OPTIONS.merge(recognized).compact
      end

      # @return [Hash] arguments for Aws::BedrockRuntime::Client#invoke_model
      def to_hash
        body = { texts: @texts }.merge(@options)
        { model_id: @model, body: body.to_json }
      end
    end
  end
end
|
@@ -0,0 +1,51 @@
|
|
1
|
+
# frozen_string_literal: true

module InstLLM
  module Parameter
    # Builds the Bedrock `invoke_model` arguments for Mistral chat models.
    # Mistral takes a single flattened prompt string rather than a
    # structured message list, so the messages are serialized into the
    # `<s>[INST] ... [/INST]` instruction format.
    class MistralParameters
      # Recognized sampling options; entries left nil are dropped from the body.
      DEFAULT_OPTIONS = {
        max_tokens: nil,
        stop: nil,
        temperature: nil,
        top_p: nil,
        top_k: nil
      }.freeze

      # @param model [String, Symbol] Bedrock model id
      # @param messages [Array<Hash>] chat messages ({role:, content:})
      # @param options [Hash] only keys present in DEFAULT_OPTIONS are kept
      def initialize(model:, messages:, **options)
        @model = model
        @messages = messages
        @options = DEFAULT_OPTIONS.merge(options.slice(*DEFAULT_OPTIONS.keys)).compact
      end

      # @return [Hash] arguments for Aws::BedrockRuntime::Client#invoke_model
      def to_hash
        { model_id: @model, body: { prompt: prompt }.merge(@options).to_json }
      end

      private

      # Flattens the chat messages into Mistral's prompt format.
      #
      # Fixes two defects in the previous implementation:
      # 1. a :system message's content leaked into the prompt verbatim (the
      #    case branch returned the assignment's value into the map), and
      # 2. String#sub's return value was discarded, so the system message
      #    was never actually spliced into the first [INST] block.
      #
      # @raise [UnknownArgumentError] for an unrecognized message role
      def prompt
        system_message = nil

        parts = @messages.filter_map do |message|
          case message[:role].to_sym
          when :assistant
            message[:content].to_s
          when :system
            system_message = message[:content]
            nil # collected above; not emitted inline
          when :user
            "[INST] #{message[:content]} [/INST]"
          else
            raise UnknownArgumentError
          end
        end

        result = "<s>" + parts.join("\n\n")
        # Splice the system message into the first instruction block.
        result = result.sub("[INST]", "[INST] #{system_message}\n") if system_message
        result
      end
    end
  end
end
|
@@ -0,0 +1,46 @@
|
|
1
|
+
# frozen_string_literal: true

require "securerandom"

module InstLLM
  module Response
    # Provider-agnostic chat result. Built from a provider-specific parsed
    # JSON response via the `from_*` factory methods.
    class ChatResponse
      attr_reader :created, :fingerprint, :stop_reason, :message, :model, :usage

      # @param model [String, Symbol] model id the response came from
      # @param message [Hash] {role:, content:} assistant message
      # @param stop_reason [String] provider-reported stop reason
      # @param usage [Hash] {input_tokens:, output_tokens:}
      def initialize(model:, message:, stop_reason:, usage:)
        @created = Time.now.to_i
        @fingerprint = SecureRandom.uuid
        @message = message
        @model = model
        @stop_reason = stop_reason
        @usage = usage
      end

      # Builds a ChatResponse from a parsed Claude (Bedrock) response body.
      def self.from_claude(model:, response:)
        usage = response["usage"]
        new(
          model: model,
          message: { role: :assistant, content: response["content"][0]["text"] },
          stop_reason: response["stop_reason"],
          usage: {
            input_tokens: usage["input_tokens"],
            output_tokens: usage["output_tokens"]
          }
        )
      end

      # Builds a ChatResponse from a parsed Mistral (Bedrock) response body.
      # Mistral does not report token counts, so usage is sentinel -1s.
      def self.from_mistral(model:, response:)
        output = response["outputs"][0]
        new(
          model: model,
          message: { role: :assistant, content: output["text"] },
          stop_reason: output["stop_reason"],
          usage: { input_tokens: -1, output_tokens: -1 }
        )
      end
    end
  end
end
|
@@ -0,0 +1,23 @@
|
|
1
|
+
# frozen_string_literal: true

module InstLLM
  module Response
    # Provider-agnostic embedding result: a model id and a list of
    # {object:, embedding:, index:} hashes (OpenAI-style shape).
    class EmbeddingResponse
      attr_reader :model, :embeddings

      # @param model [String, Symbol] model id the response came from
      # @param embeddings [Array<Hash>] wrapped embedding vectors
      def initialize(model, embeddings)
        @model = model
        @embeddings = embeddings
      end

      # Builds an EmbeddingResponse from a parsed Cohere (Bedrock)
      # response body, wrapping each raw vector with its position.
      def self.from_cohere_embed(model:, response:)
        wrapped = response["embeddings"].each_with_index.map do |vector, idx|
          { object: "embedding", embedding: vector, index: idx }
        end
        new(model, wrapped)
      end
    end
  end
end
|
data/lib/inst_llm.rb
ADDED
@@ -0,0 +1,13 @@
|
|
1
|
+
# frozen_string_literal: true

require_relative "inst_llm/version"
require_relative "inst_llm/client"

module InstLLM
  # Base error for the gem; rescue this to catch any InstLLM failure.
  class Error < StandardError; end

  # Previously every error subclassed StandardError directly, so
  # `rescue InstLLM::Error` did not catch them; they now share the base
  # (still StandardError-compatible for existing rescue clauses).
  class UnknownArgumentError < Error; end
  class ServiceQuotaExceededError < Error; end
  class ThrottlingError < Error; end
  class ValidationError < Error; end
  # Raised when the provider reports the input exceeded the model's limit.
  class ValidationTooLongError < ValidationError; end
end
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: inst_llm
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.1
|
4
|
+
version: 0.2.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Zach Pendleton
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2024-04-
|
11
|
+
date: 2024-04-23 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: aws-sdk-bedrockruntime
|
@@ -30,7 +30,17 @@ email:
|
|
30
30
|
executables: []
|
31
31
|
extensions: []
|
32
32
|
extra_rdoc_files: []
|
33
|
-
files:
|
33
|
+
files:
|
34
|
+
- lib/inst_llm.rb
|
35
|
+
- lib/inst_llm/client.rb
|
36
|
+
- lib/inst_llm/parameter/all.rb
|
37
|
+
- lib/inst_llm/parameter/claude_parameters.rb
|
38
|
+
- lib/inst_llm/parameter/cohere_embed_parameters.rb
|
39
|
+
- lib/inst_llm/parameter/mistral_parameters.rb
|
40
|
+
- lib/inst_llm/response/all.rb
|
41
|
+
- lib/inst_llm/response/chat_response.rb
|
42
|
+
- lib/inst_llm/response/embedding_response.rb
|
43
|
+
- lib/inst_llm/version.rb
|
34
44
|
homepage: https://instructure.com
|
35
45
|
licenses:
|
36
46
|
- MIT
|