inst_llm 0.1.0 → 0.2.0

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 33b83be8e20977ac085278558163be88d0acc3b35436e2360bb7e8d0102cb907
- data.tar.gz: a47ac4068509bef5454dc37904ff2752203194fd3d855a06d126c2d0d9f425e8
+ metadata.gz: 105c04ef5d12358f4ff663bc8c1278e731c0664a40ac016e29ec2defc999b6a3
+ data.tar.gz: 4873e34b6e2677822a6f8cabb3cac06367a6062d04301810856b33b751b260fb
  SHA512:
- metadata.gz: 313914f2b440882aab196834e08f9ca9eed991380fd17174c74a97656a9a00eb34809a4e2b03ad87201a1c7521698288dd57b71cfacc32ee6b5b265b0737033e
- data.tar.gz: 26467fb44454136ccb202219d88991c2c3f8a5a304cbda5707929ce98f059fdb231400b1f356cc2037df62f8d85a3eea7d1d79220575bc4679ede4a4141b3914
+ metadata.gz: 6d7c85ea828972e507b8c680e340fb5ad76b5784277a441ecad7fb1ee215a2bc839c426daf76207cacf63b1ca9fe54ff4d7bf1bf3b9d1f78ef413a28e64ddc33
+ data.tar.gz: 7ffed97282a142a6adb4227334ddc333825b2de7225ec61b2f8d1796681a68b693be4799f73aee09cc4f5b482c2c8722a7633f9973d47bb42a8546880d75b997
data/lib/inst_llm/client.rb ADDED
@@ -0,0 +1,123 @@
+ # frozen_string_literal: true
+
+ require "aws-sdk-bedrockruntime"
+ require "json"
+
+ require_relative "parameter/all"
+ require_relative "response/all"
+
+ module InstLLM
+   class Client
+     MODELS = {
+       "anthropic.claude-3-sonnet-20240229-v1:0": { format: :claude, provider: :bedrock, type: :chat },
+       "anthropic.claude-3-haiku-20240307-v1:0": { format: :claude, provider: :bedrock, type: :chat },
+
+       "mistral.mistral-7b-instruct-v0:2": { format: :mistral, provider: :bedrock, type: :chat },
+       "mistral.mixtral-8x7b-instruct-v0:1": { format: :mistral, provider: :bedrock, type: :chat },
+       "mistral.mistral-large-2402-v1:0": { format: :mistral, provider: :bedrock, type: :chat },
+
+       "cohere.embed-english-v3": { format: :cohere_embed, provider: :bedrock, type: :embedding },
+       "cohere.embed-multilingual-v3": { format: :cohere_embed, provider: :bedrock, type: :embedding },
+     }.freeze
+
+     def initialize(model, **options)
+       model = model.to_sym
+       raise UnknownArgumentError unless MODELS.key?(model)
+
+       @model = model
+       @options = options
+     end
+
+     def chat(messages, **options)
+       model = (options[:model] || options[:model_id] || @model).to_sym
+       raise ArgumentError, "Model #{model} is not a chat model" unless chat_model?(model)
+
+       response_factory(model, call(model, messages, **options))
+     end
+
+     def embedding(message, **options)
+       model = (options[:model] || options[:model_id] || @model).to_sym
+       raise ArgumentError, "Model #{model} is not an embedding model" unless embedding_model?(model)
+
+       embedding_response_factory(model, call(model, message, **options))
+     end
+
+     private
+
+     def call(model, messages, **options)
+       params = params_factory(model, messages, **options)
+
+       begin
+         res = client.invoke_model(**params)
+       rescue => error
+         raise map_error_type(error)
+       end
+
+       JSON.parse(res.body.read)
+     end
+
+     def chat_model?(model)
+       MODELS[model][:type] == :chat
+     end
+
+     def embedding_model?(model)
+       MODELS[model][:type] == :embedding
+     end
+
+     def client
+       return @client if @client
+
+       case MODELS[@model][:provider]
+       when :bedrock
+         @client = Aws::BedrockRuntime::Client.new(**@options)
+       else
+         raise UnknownArgumentError
+       end
+
+       @client
+     end
+
+     def map_error_type(error)
+       mapped_error_type = nil
+
+       case MODELS[@model][:provider]
+       when :bedrock
+         case error
+         when Aws::BedrockRuntime::Errors::ServiceQuotaExceededException
+           mapped_error_type = ServiceQuotaExceededError
+         when Aws::BedrockRuntime::Errors::ThrottlingException
+           mapped_error_type = ThrottlingError
+         when Aws::BedrockRuntime::Errors::ValidationException
+           if error.message.include?("too long")
+             mapped_error_type = ValidationTooLongError
+           else
+             mapped_error_type = ValidationError
+           end
+         else
+           mapped_error_type = Error
+         end
+       else
+         raise UnknownArgumentError
+       end
+
+       mapped_error_type.new(error.message)
+     end
+
+     def params_factory(model, messages, **options)
+       params_table = {
+         claude: Parameter::ClaudeParameters,
+         cohere_embed: Parameter::CohereEmbedParameters,
+         mistral: Parameter::MistralParameters
+       }
+       params_table[MODELS[model][:format]].new(model: model, messages: messages, **options)
+     end
+
+     def embedding_response_factory(model, response)
+       Response::EmbeddingResponse.send(:"from_#{MODELS[model][:format]}", model: model, response: response)
+     end
+
+     def response_factory(model, response)
+       Response::ChatResponse.send(:"from_#{MODELS[model][:format]}", model: model, response: response)
+     end
+   end
+ end
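For orientation, here is a minimal usage sketch of the new Client. It is not part of the diff; it assumes AWS credentials and a Bedrock-enabled region are available to Aws::BedrockRuntime::Client, and the region value below is illustrative.

  require "inst_llm"

  # Options passed to Client.new are forwarded to Aws::BedrockRuntime::Client.new.
  client = InstLLM::Client.new(
    "anthropic.claude-3-haiku-20240307-v1:0",
    region: "us-east-1"
  )

  response = client.chat(
    [{ role: "user", content: "Summarize this gem in one sentence." }]
  )

  response.message     # => { role: :assistant, content: "..." }
  response.stop_reason # => "end_turn", for example
  response.usage       # => { input_tokens: ..., output_tokens: ... }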
data/lib/inst_llm/parameter/all.rb ADDED
@@ -0,0 +1,5 @@
+ # frozen_string_literal: true
+
+ require_relative 'claude_parameters'
+ require_relative 'cohere_embed_parameters'
+ require_relative 'mistral_parameters'
data/lib/inst_llm/parameter/claude_parameters.rb ADDED
@@ -0,0 +1,27 @@
+ # frozen_string_literal: true
+
+ module InstLLM
+   module Parameter
+     class ClaudeParameters
+       DEFAULT_OPTIONS = {
+         anthropic_version: "bedrock-2023-05-31",
+         max_tokens: 2000,
+         stop_sequences: nil,
+         temperature: nil,
+         top_k: nil,
+         top_p: nil,
+         system: nil,
+       }.freeze
+
+       def initialize(model:, messages: [], **options)
+         @messages = messages
+         @model = model
+         @options = DEFAULT_OPTIONS.merge(options.slice(*DEFAULT_OPTIONS.keys)).compact
+       end
+
+       def to_hash
+         { model_id: @model, body: { messages: @messages }.merge(@options).to_json }
+       end
+     end
+   end
+ end
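As a sketch of the request these parameter objects build for invoke_model (values illustrative; only keys present in DEFAULT_OPTIONS survive the slice, and nil defaults are dropped by compact):

  params = InstLLM::Parameter::ClaudeParameters.new(
    model: "anthropic.claude-3-haiku-20240307-v1:0",
    messages: [{ role: "user", content: "Hello" }],
    temperature: 0.2,
    verbose: true # not a DEFAULT_OPTIONS key, so it is sliced away
  )

  params.to_hash
  # => { model_id: "anthropic.claude-3-haiku-20240307-v1:0",
  #      body: "{\"messages\":[{\"role\":\"user\",\"content\":\"Hello\"}]," \
  #            "\"anthropic_version\":\"bedrock-2023-05-31\",\"max_tokens\":2000,\"temperature\":0.2}" }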
data/lib/inst_llm/parameter/cohere_embed_parameters.rb ADDED
@@ -0,0 +1,22 @@
+ # frozen_string_literal: true
+
+ module InstLLM
+   module Parameter
+     class CohereEmbedParameters
+       DEFAULT_OPTIONS = {
+         input_type: nil,
+         truncate: nil
+       }.freeze
+
+       def initialize(model:, texts: [], **options)
+         @model = model
+         @texts = texts
+         @options = DEFAULT_OPTIONS.merge(options.slice(*DEFAULT_OPTIONS.keys)).compact
+       end
+
+       def to_hash
+         { model_id: @model, body: { texts: @texts }.merge(@options).to_json }
+       end
+     end
+   end
+ end
data/lib/inst_llm/parameter/mistral_parameters.rb ADDED
@@ -0,0 +1,51 @@
+ # frozen_string_literal: true
+
+ module InstLLM
+   module Parameter
+     class MistralParameters
+       DEFAULT_OPTIONS = {
+         max_tokens: nil,
+         stop: nil,
+         temperature: nil,
+         top_p: nil,
+         top_k: nil
+       }.freeze
+
+       def initialize(model:, messages:, **options)
+         @model = model
+         @messages = messages
+         @options = DEFAULT_OPTIONS.merge(options.slice(*DEFAULT_OPTIONS.keys)).compact
+       end
+
+       def to_hash
+         { model_id: @model, body: { prompt: prompt }.merge(@options).to_json }
+       end
+
+       private
+
+       def prompt
+         system_message = nil
+         prompt = @messages.map do |message|
+           case message[:role].to_sym
+           when :assistant
+             message[:content].to_s
+           when :system
+             # Captured here and folded into the first [INST] block below.
+             system_message = message[:content]
+             nil
+           when :user
+             "[INST] #{message[:content]} [/INST]"
+           else
+             raise UnknownArgumentError
+           end
+         end.compact
+
+         prompt = "<s>" + prompt.join("\n\n")
+
+         prompt = prompt.sub("[INST]", "[INST] #{system_message}\n") if system_message
+
+         prompt
+       end
+     end
+   end
+ end
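For reference, a sketch of the string the private prompt method assembles from a short conversation (messages illustrative; the system message is spliced into the first [INST] block):

  messages = [
    { role: "system", content: "You are terse." },
    { role: "user", content: "What is 2 + 2?" },
    { role: "assistant", content: "4" },
    { role: "user", content: "And 3 + 3?" }
  ]

  InstLLM::Parameter::MistralParameters.new(
    model: "mistral.mistral-large-2402-v1:0",
    messages: messages
  ).send(:prompt)
  # => "<s>[INST] You are terse.\n What is 2 + 2? [/INST]\n\n4\n\n[INST] And 3 + 3? [/INST]"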
data/lib/inst_llm/response/all.rb ADDED
@@ -0,0 +1,4 @@
+ # frozen_string_literal: true
+
+ require_relative "chat_response"
+ require_relative "embedding_response"
data/lib/inst_llm/response/chat_response.rb ADDED
@@ -0,0 +1,46 @@
+ # frozen_string_literal: true
+
+ require "securerandom"
+
+ module InstLLM
+   module Response
+     class ChatResponse
+       attr_reader :created, :fingerprint, :stop_reason, :message, :model, :usage
+
+       def initialize(model:, message:, stop_reason:, usage:)
+         @created = Time.now.to_i
+         @fingerprint = SecureRandom.uuid
+         @message = message
+         @model = model
+         @stop_reason = stop_reason
+         @usage = usage
+       end
+
+       class << self
+         def from_claude(model:, response:)
+           new(
+             model: model,
+             message: { role: :assistant, content: response["content"][0]["text"] },
+             stop_reason: response["stop_reason"],
+             usage: {
+               input_tokens: response["usage"]["input_tokens"],
+               output_tokens: response["usage"]["output_tokens"]
+             }
+           )
+         end
+
+         def from_mistral(model:, response:)
+           new(
+             model: model,
+             message: { role: :assistant, content: response["outputs"][0]["text"] },
+             stop_reason: response["outputs"][0]["stop_reason"],
+             usage: {
+               input_tokens: -1,
+               output_tokens: -1
+             }
+           )
+         end
+       end
+     end
+   end
+ end
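A sketch of how from_claude maps the parsed Bedrock body onto the response object (hash shape as read by the code above; values illustrative):

  raw = {
    "content" => [{ "text" => "Hello!" }],
    "stop_reason" => "end_turn",
    "usage" => { "input_tokens" => 12, "output_tokens" => 5 }
  }

  res = InstLLM::Response::ChatResponse.from_claude(
    model: "anthropic.claude-3-haiku-20240307-v1:0",
    response: raw
  )
  res.message # => { role: :assistant, content: "Hello!" }
  res.usage   # => { input_tokens: 12, output_tokens: 5 }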
data/lib/inst_llm/response/embedding_response.rb ADDED
@@ -0,0 +1,23 @@
+ # frozen_string_literal: true
+
+ module InstLLM
+   module Response
+     class EmbeddingResponse
+       attr_reader :model, :embeddings
+
+       def initialize(model, embeddings)
+         @model = model
+         @embeddings = embeddings
+       end
+
+       class << self
+         def from_cohere_embed(model:, response:)
+           embeddings = response["embeddings"].map.with_index do |embedding, i|
+             { object: "embedding", embedding: embedding, index: i }
+           end
+           new(model, embeddings)
+         end
+       end
+     end
+   end
+ end
data/lib/inst_llm/version.rb ADDED
@@ -0,0 +1,5 @@
+ # frozen_string_literal: true
+
+ module InstLLM
+   VERSION = "0.2.0"
+ end
data/lib/inst_llm.rb ADDED
@@ -0,0 +1,13 @@
+ # frozen_string_literal: true
+
+ require_relative "inst_llm/version"
+ require_relative "inst_llm/client"
+
+ module InstLLM
+   class Error < StandardError; end
+   class UnknownArgumentError < StandardError; end
+   class ServiceQuotaExceededError < StandardError; end
+   class ThrottlingError < StandardError; end
+   class ValidationTooLongError < StandardError; end
+   class ValidationError < StandardError; end
+ end
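Since Client#call re-raises SDK exceptions as these classes, a hedged sketch of call-site error handling (client as in the earlier sketch; the retry policy is an application choice, not something the gem provides):

  begin
    client.chat([{ role: "user", content: "Hello" }])
  rescue InstLLM::ThrottlingError
    sleep 1
    retry
  rescue InstLLM::ValidationTooLongError => e
    warn "Prompt too long: #{e.message}"
  end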
metadata CHANGED
@@ -1,7 +1,7 @@
  --- !ruby/object:Gem::Specification
  name: inst_llm
  version: !ruby/object:Gem::Version
- version: 0.1.0
+ version: 0.2.0
  platform: ruby
  authors:
  - Zach Pendleton
@@ -30,7 +30,17 @@ email:
  executables: []
  extensions: []
  extra_rdoc_files: []
- files: []
+ files:
+ - lib/inst_llm.rb
+ - lib/inst_llm/client.rb
+ - lib/inst_llm/parameter/all.rb
+ - lib/inst_llm/parameter/claude_parameters.rb
+ - lib/inst_llm/parameter/cohere_embed_parameters.rb
+ - lib/inst_llm/parameter/mistral_parameters.rb
+ - lib/inst_llm/response/all.rb
+ - lib/inst_llm/response/chat_response.rb
+ - lib/inst_llm/response/embedding_response.rb
+ - lib/inst_llm/version.rb
  homepage: https://instructure.com
  licenses:
  - MIT