inst_llm 0.1.0 → 0.2.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/lib/inst_llm/client.rb +123 -0
- data/lib/inst_llm/parameter/all.rb +5 -0
- data/lib/inst_llm/parameter/claude_parameters.rb +27 -0
- data/lib/inst_llm/parameter/cohere_embed_parameters.rb +22 -0
- data/lib/inst_llm/parameter/mistral_parameters.rb +51 -0
- data/lib/inst_llm/response/all.rb +4 -0
- data/lib/inst_llm/response/chat_response.rb +46 -0
- data/lib/inst_llm/response/embedding_response.rb +23 -0
- data/lib/inst_llm/version.rb +5 -0
- data/lib/inst_llm.rb +13 -0
- metadata +12 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 105c04ef5d12358f4ff663bc8c1278e731c0664a40ac016e29ec2defc999b6a3
|
4
|
+
data.tar.gz: 4873e34b6e2677822a6f8cabb3cac06367a6062d04301810856b33b751b260fb
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 6d7c85ea828972e507b8c680e340fb5ad76b5784277a441ecad7fb1ee215a2bc839c426daf76207cacf63b1ca9fe54ff4d7bf1bf3b9d1f78ef413a28e64ddc33
|
7
|
+
data.tar.gz: 7ffed97282a142a6adb4227334ddc333825b2de7225ec61b2f8d1796681a68b693be4799f73aee09cc4f5b482c2c8722a7633f9973d47bb42a8546880d75b997
|
@@ -0,0 +1,123 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require "aws-sdk-bedrockruntime"
|
4
|
+
require "json"
|
5
|
+
|
6
|
+
require_relative "parameter/all"
|
7
|
+
require_relative "response/all"
|
8
|
+
|
9
|
+
module InstLLM
  # Entry point of the gem: validates the requested model, builds the provider-specific
  # request through a Parameter::* builder, invokes the provider SDK (currently only AWS
  # Bedrock Runtime) and normalizes the result into a Response::* object.
  class Client
    # Supported model ids mapped to the request/response format, the provider used to
    # call them, and whether they serve chat or embedding requests.
    MODELS = {
      "anthropic.claude-3-sonnet-20240229-v1:0": { format: :claude, provider: :bedrock, type: :chat },
      "anthropic.claude-3-haiku-20240307-v1:0": { format: :claude, provider: :bedrock, type: :chat },

      "mistral.mistral-7b-instruct-v0:2": { format: :mistral, provider: :bedrock, type: :chat },
      "mistral.mixtral-8x7b-instruct-v0:1": { format: :mistral, provider: :bedrock, type: :chat },
      "mistral.mistral-large-2402-v1:0": { format: :mistral, provider: :bedrock, type: :chat },

      "cohere.embed-english-v3": { format: :cohere_embed, provider: :bedrock, type: :embedding },
      "cohere.embed-multilingual-v3": { format: :cohere_embed, provider: :bedrock, type: :embedding },
    }.freeze

    # Per-call option keys that select the model. They are consumed by the client and are
    # not forwarded to the parameter builders (which receive the resolved model explicitly).
    MODEL_OPTION_KEYS = %i[model model_id].freeze

    # @param model [String, Symbol] default model id; must be a key of MODELS
    # @param options [Hash] forwarded to the provider SDK client constructor
    #   (Aws::BedrockRuntime::Client.new for :bedrock models)
    # @raise [UnknownArgumentError] if +model+ is not a key of MODELS
    def initialize(model, **options)
      model = model.to_sym
      raise UnknownArgumentError, "Unknown model: #{model}" unless MODELS.key?(model)

      @model = model
      @options = options
    end

    # Sends a chat request.
    #
    # @param messages [Array<Hash>] chat turns, each with :role and :content
    # @param options [Hash] :model / :model_id override the default model for this call;
    #   the remaining keys are passed to the model's parameter builder
    # @return [Response::ChatResponse]
    # @raise [UnknownArgumentError] if the selected model is not a key of MODELS
    # @raise [ArgumentError] if the selected model is not a chat model
    def chat(messages, **options)
      model = resolve_model(options)
      raise ArgumentError, "Model #{model} is not a chat model" unless chat_model?(model)

      response_factory(model, call(model, messages, **options))
    end

    # Requests embeddings.
    #
    # @param message [String, Array<String>] text(s) to embed
    # @param options [Hash] :model / :model_id override the default model for this call;
    #   the remaining keys are passed to the model's parameter builder
    # @return [Response::EmbeddingResponse]
    # @raise [UnknownArgumentError] if the selected model is not a key of MODELS
    # @raise [ArgumentError] if the selected model is not an embedding model
    def embedding(message, **options)
      model = resolve_model(options)
      raise ArgumentError, "Model #{model} is not an embedding model" unless embedding_model?(model)

      embedding_response_factory(model, call(model, message, **options))
    end

    private

    # Returns the per-call model override (:model, then :model_id) or the default model,
    # as a Symbol, after checking that it is supported.
    def resolve_model(options)
      model = (options[:model] || options[:model_id] || @model).to_sym
      raise UnknownArgumentError, "Unknown model: #{model}" unless MODELS.key?(model)

      model
    end

    # Invokes +model+ with +messages+ and returns the parsed JSON response body.
    # Provider errors are re-raised as InstLLM errors (see #map_error_type).
    def call(model, messages, **options)
      request_options = options.reject { |key, _| MODEL_OPTION_KEYS.include?(key) }
      params = params_factory(model, messages, **request_options)

      begin
        # Parameter builders implement #to_hash, which `**` uses to expand them.
        res = client.invoke_model(**params)
      rescue => error
        # Raised inside the rescue clause, so the provider error is kept as #cause.
        raise map_error_type(error)
      end

      JSON.parse(res.body.read)
    end

    def chat_model?(model)
      MODELS.dig(model, :type) == :chat
    end

    def embedding_model?(model)
      MODELS.dig(model, :type) == :embedding
    end

    # Lazily builds and memoizes the SDK client for the provider of the default model.
    # NOTE: per-call model overrides reuse this client; all MODELS currently share :bedrock.
    def client
      @client ||=
        case MODELS[@model][:provider]
        when :bedrock
          Aws::BedrockRuntime::Client.new(**@options)
        else
          raise UnknownArgumentError, "Unknown provider for model #{@model}"
        end
    end

    # Translates a provider exception into an instance of the matching InstLLM error
    # class, carrying the original message.
    def map_error_type(error)
      error_class =
        case MODELS[@model][:provider]
        when :bedrock
          bedrock_error_class(error)
        else
          raise UnknownArgumentError, "Unknown provider for model #{@model}"
        end

      error_class.new(error.message)
    end

    # Chooses the InstLLM error class for an exception raised by the Bedrock runtime SDK;
    # anything unrecognized maps to the generic Error.
    def bedrock_error_class(error)
      case error
      when Aws::BedrockRuntime::Errors::ServiceQuotaExceededException
        ServiceQuotaExceededError
      when Aws::BedrockRuntime::Errors::ThrottlingException
        ThrottlingError
      when Aws::BedrockRuntime::Errors::ValidationException
        error.message.include?("too long") ? ValidationTooLongError : ValidationError
      else
        Error
      end
    end

    # Builds the request parameters for +model+'s format. Only the builder that is actually
    # needed is referenced.
    def params_factory(model, messages, **options)
      case MODELS[model][:format]
      when :claude
        Parameter::ClaudeParameters.new(model: model, messages: messages, **options)
      when :cohere_embed
        # Cohere's builder takes its input as `texts:` (an array of strings); passing it as
        # `messages:` would be swallowed by the builder's **options and dropped.
        Parameter::CohereEmbedParameters.new(model: model, texts: Array(messages), **options)
      when :mistral
        Parameter::MistralParameters.new(model: model, messages: messages, **options)
      else
        raise UnknownArgumentError, "Unsupported format for model #{model}"
      end
    end

    def embedding_response_factory(model, response)
      Response::EmbeddingResponse.public_send(:"from_#{MODELS[model][:format]}", model: model, response: response)
    end

    def response_factory(model, response)
      Response::ChatResponse.public_send(:"from_#{MODELS[model][:format]}", model: model, response: response)
    end
  end
end
|
@@ -0,0 +1,27 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module InstLLM
  module Parameter
    # Request builder for Anthropic Claude models invoked through Bedrock.
    # Produces the keyword arguments expected by `invoke_model`.
    class ClaudeParameters
      DEFAULT_OPTIONS = {
        anthropic_version: "bedrock-2023-05-31",
        max_tokens: 2000,
        stop_sequences: nil,
        temperature: nil,
        top_k: nil,
        top_p: nil,
        system: nil,
      }.freeze

      # @param model [String, Symbol] Bedrock model id, sent as `model_id`
      # @param messages [Array<Hash>] chat messages, serialized as-is into the body
      # @param options [Hash] overrides for DEFAULT_OPTIONS; unknown keys are ignored and
      #   nil-valued entries are left out of the request body
      def initialize(model:, messages: [], **options)
        @model = model
        @messages = messages
        recognized = options.select { |key, _| DEFAULT_OPTIONS.key?(key) }
        @options = DEFAULT_OPTIONS.merge(recognized).reject { |_, value| value.nil? }
      end

      # @return [Hash] `{ model_id:, body: }` where body is the JSON request document,
      #   with `messages` first followed by the configured options
      def to_hash
        request_body = { messages: @messages }.merge(@options)
        { model_id: @model, body: request_body.to_json }
      end
    end
  end
end
|
@@ -0,0 +1,22 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module InstLLM
  module Parameter
    # Request builder for Cohere embedding models invoked through Bedrock.
    # Produces the keyword arguments expected by `invoke_model`.
    class CohereEmbedParameters
      DEFAULT_OPTIONS = {
        input_type: nil,
        truncate: nil
      }.freeze

      # @param model [String, Symbol] Bedrock model id, sent as `model_id`
      # @param texts [String, Array<String>, nil] text(s) to embed; a single String is
      #   wrapped into a one-element array
      # @param messages [String, Array<String>, nil] alias for +texts+, used only when
      #   +texts+ is nil, so callers that pass their input generically as `messages:` are not
      #   silently dropped (it would otherwise land in **options and be filtered out)
      # @param options [Hash] overrides for DEFAULT_OPTIONS; unknown keys are ignored and
      #   nil-valued entries are left out of the request body
      def initialize(model:, texts: nil, messages: nil, **options)
        @model = model
        @texts = Array(texts.nil? ? messages : texts)
        @options = DEFAULT_OPTIONS.merge(options.slice(*DEFAULT_OPTIONS.keys)).compact
      end

      # @return [Hash] `{ model_id:, body: }` where body is the JSON request document
      def to_hash
        { model_id: @model, body: { texts: @texts }.merge(@options).to_json }
      end
    end
  end
end
|
@@ -0,0 +1,51 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module InstLLM
  module Parameter
    # Request builder for Mistral instruct models invoked through Bedrock. These models take
    # a single prompt string, so chat messages are rendered into `[INST] ... [/INST]` form.
    class MistralParameters
      DEFAULT_OPTIONS = {
        max_tokens: nil,
        stop: nil,
        temperature: nil,
        top_p: nil,
        top_k: nil
      }.freeze

      # @param model [String, Symbol] Bedrock model id, sent as `model_id`
      # @param messages [Array<Hash>] chat turns with :role (:system, :user or :assistant,
      #   as String or Symbol) and :content
      # @param options [Hash] overrides for DEFAULT_OPTIONS; unknown keys are ignored and
      #   nil-valued entries are left out of the request body
      def initialize(model:, messages:, **options)
        @model = model
        @messages = messages
        @options = DEFAULT_OPTIONS.merge(options.slice(*DEFAULT_OPTIONS.keys)).compact
      end

      # @return [Hash] `{ model_id:, body: }` where body is the JSON request document
      def to_hash
        { model_id: @model, body: { prompt: prompt }.merge(@options).to_json }
      end

      private

      # Renders @messages as one prompt string starting with "<s>", with turns separated by
      # blank lines. System messages are not emitted as turns of their own: the last one is
      # prepended to the first user instruction. With no user message, it is dropped.
      #
      # @raise [UnknownArgumentError] for a role other than system/user/assistant
      def prompt
        system_message = nil

        turns = @messages.each_with_object([]) do |message, rendered|
          case message[:role].to_sym
          when :assistant
            rendered << message[:content].to_s
          when :system
            system_message = message[:content]
          when :user
            rendered << "[INST] #{message[:content]} [/INST]"
          else
            raise UnknownArgumentError, "Unknown message role: #{message[:role]}"
          end
        end

        text = "<s>" + turns.join("\n\n")
        return text unless system_message

        # Block form: the replacement is used verbatim, so backslashes in the system
        # message are not interpreted as substitution back-references.
        text.sub("[INST]") { "[INST] #{system_message}\n" }
      end
    end
  end
end
|
@@ -0,0 +1,46 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require "securerandom"
|
4
|
+
|
5
|
+
module InstLLM
  module Response
    # Chat completion normalized across model families.
    class ChatResponse
      attr_reader :created, :fingerprint, :stop_reason, :message, :model, :usage

      # Builds a response from the parsed JSON body of a Claude invocation: the reply is the
      # text of the first "content" block, and token counts come from "usage".
      def self.from_claude(model:, response:)
        reply = response["content"][0]["text"]
        reason = response["stop_reason"]
        token_counts = response["usage"]

        new(
          model: model,
          message: { role: :assistant, content: reply },
          stop_reason: reason,
          usage: {
            input_tokens: token_counts["input_tokens"],
            output_tokens: token_counts["output_tokens"]
          }
        )
      end

      # Builds a response from the parsed JSON body of a Mistral invocation: text and stop
      # reason come from the first entry of "outputs". Usage is not read from the body;
      # both token counts are reported as -1.
      def self.from_mistral(model:, response:)
        first_output = response["outputs"][0]

        new(
          model: model,
          message: { role: :assistant, content: first_output["text"] },
          stop_reason: first_output["stop_reason"],
          usage: { input_tokens: -1, output_tokens: -1 }
        )
      end

      # @param model [Symbol, String] model that produced the reply
      # @param message [Hash] `{ role: :assistant, content: String }`
      # @param stop_reason [String, nil] provider-reported stop reason, passed through as-is
      # @param usage [Hash] `{ input_tokens:, output_tokens: }`
      def initialize(model:, message:, stop_reason:, usage:)
        @model = model
        @message = message
        @stop_reason = stop_reason
        @usage = usage
        # Integer seconds from Time.now at construction, not a provider timestamp.
        @created = Time.now.to_i
        # Random UUID generated locally for each response object.
        @fingerprint = SecureRandom.uuid
      end
    end
  end
end
|
@@ -0,0 +1,23 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module InstLLM
  module Response
    # Embedding result normalized across model families.
    class EmbeddingResponse
      attr_reader :model, :embeddings

      # @param model [Symbol, String] model that produced the vectors
      # @param embeddings [Array<Hash>] entries of `{ object: "embedding", embedding:, index: }`
      def initialize(model, embeddings)
        @model = model
        @embeddings = embeddings
      end

      # Builds a response from the parsed JSON body of a Cohere embed invocation, tagging
      # each vector in "embeddings" with its zero-based position.
      def self.from_cohere_embed(model:, response:)
        vectors = response["embeddings"]
        entries = vectors.each_with_index.map do |vector, position|
          { object: "embedding", embedding: vector, index: position }
        end

        new(model, entries)
      end
    end
  end
end
|
data/lib/inst_llm.rb
ADDED
@@ -0,0 +1,13 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative "inst_llm/version"
|
4
|
+
require_relative "inst_llm/client"
|
5
|
+
|
6
|
+
module InstLLM
  # Base class for every error defined by this gem; `rescue InstLLM::Error` catches them all.
  # Also used directly for provider failures that have no more specific mapping.
  class Error < StandardError; end

  # Raised for an unsupported model, provider, model format or message role.
  class UnknownArgumentError < Error; end

  # Provider reported that a service quota was exceeded.
  class ServiceQuotaExceededError < Error; end

  # Provider throttled the request.
  class ThrottlingError < Error; end

  # Provider rejected the request as invalid.
  class ValidationError < Error; end

  # Validation failure whose provider message says the input is "too long"; it is a
  # ValidationError, so rescuing ValidationError also catches it.
  class ValidationTooLongError < ValidationError; end
end
|
metadata
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: inst_llm
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.1.0
|
4
|
+
version: 0.2.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Zach Pendleton
|
@@ -30,7 +30,17 @@ email:
|
|
30
30
|
executables: []
|
31
31
|
extensions: []
|
32
32
|
extra_rdoc_files: []
|
33
|
-
files:
|
33
|
+
files:
|
34
|
+
- lib/inst_llm.rb
|
35
|
+
- lib/inst_llm/client.rb
|
36
|
+
- lib/inst_llm/parameter/all.rb
|
37
|
+
- lib/inst_llm/parameter/claude_parameters.rb
|
38
|
+
- lib/inst_llm/parameter/cohere_embed_parameters.rb
|
39
|
+
- lib/inst_llm/parameter/mistral_parameters.rb
|
40
|
+
- lib/inst_llm/response/all.rb
|
41
|
+
- lib/inst_llm/response/chat_response.rb
|
42
|
+
- lib/inst_llm/response/embedding_response.rb
|
43
|
+
- lib/inst_llm/version.rb
|
34
44
|
homepage: https://instructure.com
|
35
45
|
licenses:
|
36
46
|
- MIT
|