inst_llm 0.1.0 → 0.2.0
This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between versions as they appear in the public registry.
- checksums.yaml +4 -4
- data/lib/inst_llm/client.rb +123 -0
- data/lib/inst_llm/parameter/all.rb +5 -0
- data/lib/inst_llm/parameter/claude_parameters.rb +27 -0
- data/lib/inst_llm/parameter/cohere_embed_parameters.rb +22 -0
- data/lib/inst_llm/parameter/mistral_parameters.rb +51 -0
- data/lib/inst_llm/response/all.rb +4 -0
- data/lib/inst_llm/response/chat_response.rb +46 -0
- data/lib/inst_llm/response/embedding_response.rb +23 -0
- data/lib/inst_llm/version.rb +5 -0
- data/lib/inst_llm.rb +13 -0
- metadata +12 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 105c04ef5d12358f4ff663bc8c1278e731c0664a40ac016e29ec2defc999b6a3
+  data.tar.gz: 4873e34b6e2677822a6f8cabb3cac06367a6062d04301810856b33b751b260fb
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 6d7c85ea828972e507b8c680e340fb5ad76b5784277a441ecad7fb1ee215a2bc839c426daf76207cacf63b1ca9fe54ff4d7bf1bf3b9d1f78ef413a28e64ddc33
+  data.tar.gz: 7ffed97282a142a6adb4227334ddc333825b2de7225ec61b2f8d1796681a68b693be4799f73aee09cc4f5b482c2c8722a7633f9973d47bb42a8546880d75b997
data/lib/inst_llm/client.rb
ADDED
@@ -0,0 +1,123 @@
+# frozen_string_literal: true
+
+require "aws-sdk-bedrockruntime"
+require "json"
+
+require_relative "parameter/all"
+require_relative "response/all"
+
+module InstLLM
+  class Client
+    MODELS = {
+      "anthropic.claude-3-sonnet-20240229-v1:0": { format: :claude, provider: :bedrock, type: :chat },
+      "anthropic.claude-3-haiku-20240307-v1:0": { format: :claude, provider: :bedrock, type: :chat },
+
+      "mistral.mistral-7b-instruct-v0:2": { format: :mistral, provider: :bedrock, type: :chat },
+      "mistral.mixtral-8x7b-instruct-v0:1": { format: :mistral, provider: :bedrock, type: :chat },
+      "mistral.mistral-large-2402-v1:0": { format: :mistral, provider: :bedrock, type: :chat },
+
+      "cohere.embed-english-v3": { format: :cohere_embed, provider: :bedrock, type: :embedding },
+      "cohere.embed-multilingual-v3": { format: :cohere_embed, provider: :bedrock, type: :embedding },
+    }.freeze
+
+    def initialize(model, **options)
+      model = model.to_sym
+      raise UnknownArgumentError unless MODELS.key?(model)
+
+      @model = model
+      @options = options
+    end
+
+    def chat(messages, **options)
+      model = (options[:model] || options[:model_id] || @model).to_sym
+      raise ArgumentError, "Model #{model} is not a chat model" unless chat_model?(model)
+
+      response_factory(model, call(model, messages, **options))
+    end
+
+    def embedding(message, **options)
+      model = (options[:model] || options[:model_id] || @model).to_sym
+      raise ArgumentError, "Model #{model} is not an embedding model" unless embedding_model?(model)
+
+      embedding_response_factory(model, call(model, message, **options))
+    end
+
+    private
+
+    def call(model, messages, **options)
+      params = params_factory(model, messages, **options)
+
+      begin
+        res = client.invoke_model(**params)
+      rescue => error
+        raise map_error_type(error)
+      end
+
+      JSON.parse(res.body.read)
+    end
+
+    def chat_model?(model)
+      MODELS[model][:type] == :chat
+    end
+
+    def embedding_model?(model)
+      MODELS[model][:type] == :embedding
+    end
+
+    def client
+      return @client if @client
+
+      case MODELS[@model][:provider]
+      when :bedrock
+        @client = Aws::BedrockRuntime::Client.new(**@options)
+      else
+        raise UnknownArgumentError
+      end
+
+      @client
+    end
+
+    def map_error_type(error)
+      mapped_error_type = nil
+
+      case MODELS[@model][:provider]
+      when :bedrock
+        case error
+        when Aws::BedrockRuntime::Errors::ServiceQuotaExceededException
+          mapped_error_type = ServiceQuotaExceededError
+        when Aws::BedrockRuntime::Errors::ThrottlingException
+          mapped_error_type = ThrottlingError
+        when Aws::BedrockRuntime::Errors::ValidationException
+          if error.message.include?("too long")
+            mapped_error_type = ValidationTooLongError
+          else
+            mapped_error_type = ValidationError
+          end
+        else
+          mapped_error_type = Error
+        end
+      else
+        raise UnknownArgumentError
+      end
+
+      mapped_error_type.new(error.message)
+    end
+
+    def params_factory(model, messages, **options)
+      params_table = {
+        claude: Parameter::ClaudeParameters,
+        cohere_embed: Parameter::CohereEmbedParameters,
+        mistral: Parameter::MistralParameters
+      }
+      params_table[MODELS[model][:format]].new(model: model, messages: messages, **options)
+    end
+
+    def embedding_response_factory(model, response)
+      Response::EmbeddingResponse.send(:"from_#{MODELS[model][:format]}", model: model, response: response)
+    end
+
+    def response_factory(model, response)
+      Response::ChatResponse.send(:"from_#{MODELS[model][:format]}", model: model, response: response)
+    end
+  end
+end
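
For orientation, a minimal usage sketch (mine, not part of the package). It assumes the options hash is forwarded untouched to Aws::BedrockRuntime::Client.new, as the private client method above suggests; the region value and message contents are illustrative:

require "inst_llm"

# Chat: options go straight to Aws::BedrockRuntime::Client.new, so standard
# AWS SDK options (region, credentials, ...) should apply here.
chat_client = InstLLM::Client.new(:"anthropic.claude-3-haiku-20240307-v1:0", region: "us-east-1")
response = chat_client.chat([{ role: "user", content: "Hello!" }])
response.message # => { role: :assistant, content: "..." }

# Embeddings: CohereEmbedParameters reads a :texts keyword, so the input is
# supplied via options; the positional argument is sliced away and unused.
embed_client = InstLLM::Client.new(:"cohere.embed-english-v3", region: "us-east-1")
embed_client.embedding(nil, texts: ["Hello!"]).embeddings
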
data/lib/inst_llm/parameter/claude_parameters.rb
ADDED
@@ -0,0 +1,27 @@
+# frozen_string_literal: true
+
+module InstLLM
+  module Parameter
+    class ClaudeParameters
+      DEFAULT_OPTIONS = {
+        anthropic_version: "bedrock-2023-05-31",
+        max_tokens: 2000,
+        stop_sequences: nil,
+        temperature: nil,
+        top_k: nil,
+        top_p: nil,
+        system: nil,
+      }.freeze
+
+      def initialize(model:, messages: [], **options)
+        @messages = messages
+        @model = model
+        @options = DEFAULT_OPTIONS.merge(options.slice(*DEFAULT_OPTIONS.keys)).compact
+      end
+
+      def to_hash
+        { model_id: @model, body: { messages: @messages }.merge(@options).to_json }
+      end
+    end
+  end
+end
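
To illustrate the option handling (a sketch of mine, not from the diff): unrecognized keys are sliced away, nil defaults are compacted out, and everything that remains is folded into the JSON body:

params = InstLLM::Parameter::ClaudeParameters.new(
  model: "anthropic.claude-3-haiku-20240307-v1:0",
  messages: [{ role: "user", content: "Hi" }],
  temperature: 0.2,
  foo: "dropped" # not in DEFAULT_OPTIONS, removed by slice
)
params.to_hash
# => { model_id: "anthropic.claude-3-haiku-20240307-v1:0",
#      body: "{\"messages\":[...],\"anthropic_version\":\"bedrock-2023-05-31\",\"max_tokens\":2000,\"temperature\":0.2}" }
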
data/lib/inst_llm/parameter/cohere_embed_parameters.rb
ADDED
@@ -0,0 +1,22 @@
+# frozen_string_literal: true
+
+module InstLLM
+  module Parameter
+    class CohereEmbedParameters
+      DEFAULT_OPTIONS = {
+        input_type: nil,
+        truncate: nil
+      }.freeze
+
+      def initialize(model:, texts: [], **options)
+        @model = model
+        @texts = texts
+        @options = DEFAULT_OPTIONS.merge(options.slice(*DEFAULT_OPTIONS.keys)).compact
+      end
+
+      def to_hash
+        { model_id: @model, body: { texts: @texts }.merge(@options).to_json }
+      end
+    end
+  end
+end
data/lib/inst_llm/parameter/mistral_parameters.rb
ADDED
@@ -0,0 +1,51 @@
+# frozen_string_literal: true
+
+module InstLLM
+  module Parameter
+    class MistralParameters
+      DEFAULT_OPTIONS = {
+        max_tokens: nil,
+        stop: nil,
+        temperature: nil,
+        top_p: nil,
+        top_k: nil
+      }.freeze
+
+      def initialize(model:, messages:, **options)
+        @model = model
+        @messages = messages
+        @options = DEFAULT_OPTIONS.merge(options.slice(*DEFAULT_OPTIONS.keys)).compact
+      end
+
+      def to_hash
+        { model_id: @model, body: { prompt: prompt }.merge(@options).to_json }
+      end
+
+      private
+
+      def prompt
+        system_message = nil
+        prompt = @messages.map do |message|
+          case message[:role].to_sym
+          when :assistant
+            "#{message[:content]}"
+          when :system
+            system_message = message[:content]
+          when :user
+            "[INST] #{message[:content]} [/INST]"
+          else
+            raise UnknownArgumentError
+          end
+        end
+
+        prompt = "<s>" + prompt.join("\n\n")
+
+        if system_message
+          prompt.sub("\[INST\]", "[INST] #{system_message}\n")
+        end
+
+        prompt
+      end
+    end
+  end
+end
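
A sketch of the prompt string this builds (my example; the messages are illustrative). Note that String#sub returns a new string whose value is discarded here, so in this version the system-message insertion is a no-op and the unmodified prompt is returned:

params = InstLLM::Parameter::MistralParameters.new(
  model: "mistral.mistral-7b-instruct-v0:2",
  messages: [
    { role: "user", content: "What is 2 + 2?" },
    { role: "assistant", content: "4" },
    { role: "user", content: "Double it." }
  ]
)
# The private prompt method yields the Mistral instruction format:
# "<s>[INST] What is 2 + 2? [/INST]\n\n4\n\n[INST] Double it. [/INST]"
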
data/lib/inst_llm/response/chat_response.rb
ADDED
@@ -0,0 +1,46 @@
+# frozen_string_literal: true
+
+require "securerandom"
+
+module InstLLM
+  module Response
+    class ChatResponse
+      attr_reader :created, :fingerprint, :stop_reason, :message, :model, :usage
+
+      def initialize(model:, message:, stop_reason:, usage:)
+        @created = Time.now.to_i
+        @fingerprint = SecureRandom.uuid
+        @message = message
+        @model = model
+        @stop_reason = stop_reason
+        @usage = usage
+      end
+
+      class << self
+        def from_claude(model:, response:)
+          new(
+            model: model,
+            message: { role: :assistant, content: response["content"][0]["text"] },
+            stop_reason: response["stop_reason"],
+            usage: {
+              input_tokens: response["usage"]["input_tokens"],
+              output_tokens: response["usage"]["output_tokens"]
+            }
+          )
+        end
+
+        def from_mistral(model:, response:)
+          new(
+            model: model,
+            message: { role: :assistant, content: response["outputs"][0]["text"] },
+            stop_reason: response["outputs"][0]["stop_reason"],
+            usage: {
+              input_tokens: -1,
+              output_tokens: -1
+            }
+          )
+        end
+      end
+    end
+  end
+end
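
For reference, a hand-written response hash covering only the keys from_claude reads (a sketch; real Bedrock payloads carry more fields, and from_mistral pins usage to -1 presumably because the Mistral response body carries no token counts):

raw = {
  "content"     => [{ "text" => "Hello! How can I help?" }],
  "stop_reason" => "end_turn",
  "usage"       => { "input_tokens" => 12, "output_tokens" => 9 }
}
res = InstLLM::Response::ChatResponse.from_claude(
  model: :"anthropic.claude-3-haiku-20240307-v1:0", response: raw
)
res.message # => { role: :assistant, content: "Hello! How can I help?" }
res.usage   # => { input_tokens: 12, output_tokens: 9 }
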
data/lib/inst_llm/response/embedding_response.rb
ADDED
@@ -0,0 +1,23 @@
+# frozen_string_literal: true
+
+module InstLLM
+  module Response
+    class EmbeddingResponse
+      attr_reader :model, :embeddings
+
+      def initialize(model, embeddings)
+        @model = model
+        @embeddings = embeddings
+      end
+
+      class << self
+        def from_cohere_embed(model:, response:)
+          embeddings = response["embeddings"].map.with_index do |embedding, i|
+            { object: "embedding", embedding: embedding, index: i }
+          end
+          new(model, embeddings)
+        end
+      end
+    end
+  end
+end
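
Again a hand-written sketch with toy vectors, showing the indexed records the mapper produces:

raw = { "embeddings" => [[0.1, -0.2, 0.3], [0.4, 0.5, -0.6]] }
res = InstLLM::Response::EmbeddingResponse.from_cohere_embed(
  model: :"cohere.embed-english-v3", response: raw
)
res.embeddings
# => [{ object: "embedding", embedding: [0.1, -0.2, 0.3], index: 0 },
#     { object: "embedding", embedding: [0.4, 0.5, -0.6], index: 1 }]
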
data/lib/inst_llm.rb
ADDED
@@ -0,0 +1,13 @@
+# frozen_string_literal: true
+
+require_relative "inst_llm/version"
+require_relative "inst_llm/client"
+
+module InstLLM
+  class Error < StandardError; end
+  class UnknownArgumentError < StandardError; end
+  class ServiceQuotaExceededError < StandardError; end
+  class ThrottlingError < StandardError; end
+  class ValidationTooLongError < StandardError; end
+  class ValidationError < StandardError; end
+end
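
Note that each error class inherits from StandardError directly rather than from InstLLM::Error, so rescuing InstLLM::Error will not catch the mapped errors; callers must rescue them individually, as in this sketch of mine:

client = InstLLM::Client.new(:"anthropic.claude-3-haiku-20240307-v1:0", region: "us-east-1")
attempts = 0
begin
  client.chat([{ role: "user", content: "Hello!" }])
rescue InstLLM::ThrottlingError
  # Bedrock throttled the request; retry a few times before giving up.
  attempts += 1
  retry if attempts < 3
  raise
rescue InstLLM::ValidationTooLongError => e
  warn "prompt exceeded the model's limit: #{e.message}"
end
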
metadata
CHANGED
@@ -1,7 +1,7 @@
 --- !ruby/object:Gem::Specification
 name: inst_llm
 version: !ruby/object:Gem::Version
-  version: 0.1.0
+  version: 0.2.0
 platform: ruby
 authors:
 - Zach Pendleton
@@ -30,7 +30,17 @@ email:
 executables: []
 extensions: []
 extra_rdoc_files: []
-files:
+files:
+- lib/inst_llm.rb
+- lib/inst_llm/client.rb
+- lib/inst_llm/parameter/all.rb
+- lib/inst_llm/parameter/claude_parameters.rb
+- lib/inst_llm/parameter/cohere_embed_parameters.rb
+- lib/inst_llm/parameter/mistral_parameters.rb
+- lib/inst_llm/response/all.rb
+- lib/inst_llm/response/chat_response.rb
+- lib/inst_llm/response/embedding_response.rb
+- lib/inst_llm/version.rb
 homepage: https://instructure.com
 licenses:
 - MIT