inst_llm 0.1.0 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 33b83be8e20977ac085278558163be88d0acc3b35436e2360bb7e8d0102cb907
4
- data.tar.gz: a47ac4068509bef5454dc37904ff2752203194fd3d855a06d126c2d0d9f425e8
3
+ metadata.gz: 105c04ef5d12358f4ff663bc8c1278e731c0664a40ac016e29ec2defc999b6a3
4
+ data.tar.gz: 4873e34b6e2677822a6f8cabb3cac06367a6062d04301810856b33b751b260fb
5
5
  SHA512:
6
- metadata.gz: 313914f2b440882aab196834e08f9ca9eed991380fd17174c74a97656a9a00eb34809a4e2b03ad87201a1c7521698288dd57b71cfacc32ee6b5b265b0737033e
7
- data.tar.gz: 26467fb44454136ccb202219d88991c2c3f8a5a304cbda5707929ce98f059fdb231400b1f356cc2037df62f8d85a3eea7d1d79220575bc4679ede4a4141b3914
6
+ metadata.gz: 6d7c85ea828972e507b8c680e340fb5ad76b5784277a441ecad7fb1ee215a2bc839c426daf76207cacf63b1ca9fe54ff4d7bf1bf3b9d1f78ef413a28e64ddc33
7
+ data.tar.gz: 7ffed97282a142a6adb4227334ddc333825b2de7225ec61b2f8d1796681a68b693be4799f73aee09cc4f5b482c2c8722a7633f9973d47bb42a8546880d75b997
@@ -0,0 +1,123 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "aws-sdk-bedrockruntime"
4
+ require "json"
5
+
6
+ require_relative "parameter/all"
7
+ require_relative "response/all"
8
+
9
+ module InstLLM
10
+ class Client
11
+ MODELS = {
12
+ "anthropic.claude-3-sonnet-20240229-v1:0": { format: :claude, provider: :bedrock, type: :chat },
13
+ "anthropic.claude-3-haiku-20240307-v1:0": { format: :claude, provider: :bedrock, type: :chat },
14
+
15
+ "mistral.mistral-7b-instruct-v0:2": { format: :mistral, provider: :bedrock, type: :chat },
16
+ "mistral.mixtral-8x7b-instruct-v0:1": { format: :mistral, provider: :bedrock, type: :chat },
17
+ "mistral.mistral-large-2402-v1:0": { format: :mistral, provider: :bedrock, type: :chat },
18
+
19
+ "cohere.embed-english-v3": { format: :cohere_embed, provider: :bedrock, type: :embedding },
20
+ "cohere.embed-multilingual-v3": { format: :cohere_embed, provider: :bedrock, type: :embedding },
21
+ }.freeze
22
+
23
+ def initialize(model, **options)
24
+ model = model.to_sym
25
+ raise UnknownArgumentError unless MODELS.key?(model)
26
+
27
+ @model = model
28
+ @options = options
29
+ end
30
+
31
+ def chat(messages, **options)
32
+ model = (options[:model] || options[:model_id] || @model).to_sym
33
+ raise ArgumentError, "Model #{model} is not a chat model" unless chat_model?(model)
34
+
35
+ response_factory(model, call(model, messages, **options))
36
+ end
37
+
38
+ def embedding(message, **options)
39
+ model = (options[:model] || options[:model_id] || @model).to_sym
40
+ raise ArgumentError, "Model #{model} is not an embedding model" unless embedding_model?(model)
41
+
42
+ embedding_response_factory(model, call(model, message, **options))
43
+ end
44
+
45
+ private
46
+
47
+ def call(model, messages, **options)
48
+ params = params_factory(model, messages, **options)
49
+
50
+ begin
51
+ res = client.invoke_model(**params)
52
+ rescue => error
53
+ raise map_error_type(error)
54
+ end
55
+
56
+ JSON.parse(res.body.read)
57
+ end
58
+
59
+ def chat_model?(model)
60
+ MODELS[model][:type] == :chat
61
+ end
62
+
63
+ def embedding_model?(model)
64
+ MODELS[model][:type] == :embedding
65
+ end
66
+
67
+ def client
68
+ return @client if @client
69
+
70
+ case MODELS[@model][:provider]
71
+ when :bedrock
72
+ @client = Aws::BedrockRuntime::Client.new(**@options)
73
+ else
74
+ raise UnknownArgumentError
75
+ end
76
+
77
+ @client
78
+ end
79
+
80
+ def map_error_type(error)
81
+ mapped_error_type = nil
82
+
83
+ case MODELS[@model][:provider]
84
+ when :bedrock
85
+ case error
86
+ when Aws::BedrockRuntime::Errors::ServiceQuotaExceededException
87
+ mapped_error_type = ServiceQuotaExceededError
88
+ when Aws::BedrockRuntime::Errors::ThrottlingException
89
+ mapped_error_type = ThrottlingError
90
+ when Aws::BedrockRuntime::Errors::ValidationException
91
+ if error.message.include?("too long")
92
+ mapped_error_type = ValidationTooLongError
93
+ else
94
+ mapped_error_type = ValidationError
95
+ end
96
+ else
97
+ mapped_error_type = Error
98
+ end
99
+ else
100
+ raise UnknownArgumentError
101
+ end
102
+
103
+ mapped_error_type.new(error.message)
104
+ end
105
+
106
+ def params_factory(model, messages, **options)
107
+ params_table = {
108
+ claude: Parameter::ClaudeParameters,
109
+ cohere_embed: Parameter::CohereEmbedParameters,
110
+ mistral: Parameter::MistralParameters
111
+ }
112
+ params_table[MODELS[model][:format]].new(model: model, messages: messages, **options)
113
+ end
114
+
115
+ def embedding_response_factory(model, response)
116
+ Response::EmbeddingResponse.send(:"from_#{MODELS[model][:format]}", model: model, response: response)
117
+ end
118
+
119
+ def response_factory(model, response)
120
+ Response::ChatResponse.send(:"from_#{MODELS[model][:format]}", model: model, response: response)
121
+ end
122
+ end
123
+ end
@@ -0,0 +1,5 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative 'claude_parameters'
4
+ require_relative 'cohere_embed_parameters'
5
+ require_relative 'mistral_parameters'
@@ -0,0 +1,27 @@
1
+ # frozen_string_literal: true
2
+
3
module InstLLM
  module Parameter
    # Builds the invoke arguments (model id plus JSON body) for Claude models.
    class ClaudeParameters
      # Recognised body options and their defaults. Entries that are still nil
      # after merging the caller's options are left out of the body.
      DEFAULT_OPTIONS = {
        anthropic_version: "bedrock-2023-05-31",
        max_tokens: 2000,
        stop_sequences: nil,
        temperature: nil,
        top_k: nil,
        top_p: nil,
        system: nil,
      }.freeze

      # @param model [Object] model id, returned unchanged as :model_id
      # @param messages [Array] conversation placed under "messages" in the body
      # @param options [Hash] only keys listed in DEFAULT_OPTIONS are used
      def initialize(model:, messages: [], **options)
        recognised = options.select { |key, _value| DEFAULT_OPTIONS.key?(key) }

        @model = model
        @messages = messages
        @options = DEFAULT_OPTIONS.merge(recognised).compact
      end

      # @return [Hash] :model_id and :body (a JSON string)
      def to_hash
        payload = { messages: @messages }.merge(@options)

        { model_id: @model, body: payload.to_json }
      end
    end
  end
end
@@ -0,0 +1,22 @@
1
+ # frozen_string_literal: true
2
+
3
module InstLLM
  module Parameter
    # Builds the invoke arguments (model id plus JSON body) for Cohere Embed
    # models.
    class CohereEmbedParameters
      # Recognised body options. Entries that are still nil after merging the
      # caller's options are left out of the body.
      DEFAULT_OPTIONS = {
        input_type: nil,
        truncate: nil
      }.freeze

      # @param model [Object] model id, returned unchanged as :model_id
      # @param texts [Array<String>, String] inputs to embed
      # @param messages [Array<String>, String, nil] alias for +texts+. The
      #   client's parameter factory passes the input under this keyword for
      #   every format; without accepting it here the input was swallowed by
      #   +options+ and the request always carried an empty "texts" list.
      #   Takes precedence over +texts+ when not nil.
      # @param options [Hash] only keys listed in DEFAULT_OPTIONS are used
      def initialize(model:, texts: [], messages: nil, **options)
        inputs = messages.nil? ? texts : messages

        @model = model
        # The body sends a list under "texts", so wrap a lone string.
        @texts = inputs.is_a?(String) ? [inputs] : inputs
        @options = DEFAULT_OPTIONS.merge(options.slice(*DEFAULT_OPTIONS.keys)).compact
      end

      # @return [Hash] :model_id and :body (a JSON string)
      def to_hash
        { model_id: @model, body: { texts: @texts }.merge(@options).to_json }
      end
    end
  end
end
@@ -0,0 +1,51 @@
1
+ # frozen_string_literal: true
2
+
3
module InstLLM
  module Parameter
    # Builds the invoke arguments (model id plus JSON body) for Mistral models,
    # flattening a list of chat messages into a single prompt string.
    class MistralParameters
      # Recognised body options. Entries that are still nil after merging the
      # caller's options are left out of the body.
      DEFAULT_OPTIONS = {
        max_tokens: nil,
        stop: nil,
        temperature: nil,
        top_p: nil,
        top_k: nil
      }.freeze

      # @param model [Object] model id, returned unchanged as :model_id
      # @param messages [Array<Hash>] entries with :role (assistant, system or
      #   user) and :content
      # @param options [Hash] only keys listed in DEFAULT_OPTIONS are used
      def initialize(model:, messages:, **options)
        @model = model
        @messages = messages
        @options = DEFAULT_OPTIONS.merge(options.slice(*DEFAULT_OPTIONS.keys)).compact
      end

      # @return [Hash] :model_id and :body (a JSON string)
      # @raise [UnknownArgumentError] when a message has an unsupported role
      def to_hash
        { model_id: @model, body: { prompt: prompt }.merge(@options).to_json }
      end

      private

      # Renders the conversation as "<s>" followed by the turns joined with
      # blank lines. User turns are wrapped in [INST] ... [/INST]; assistant
      # turns are emitted verbatim. A system message is not a turn of its own:
      # it is spliced into the first [INST] block. When several system messages
      # are present the last one is used.
      def prompt
        system_message = nil

        turns = @messages.map do |message|
          content = message[:content]

          case message[:role].to_sym
          when :assistant
            content.to_s
          when :system
            system_message = content
            nil # keep the system text out of the turn list
          when :user
            "[INST] #{content} [/INST]"
          else
            raise UnknownArgumentError
          end
        end

        text = "<s>#{turns.compact.join("\n\n")}"
        return text unless system_message

        # Block form so backslashes in the system text are inserted literally
        # rather than being read as back-references.
        text.sub("[INST]") { "[INST] #{system_message}\n" }
      end
    end
  end
end
@@ -0,0 +1,4 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "chat_response"
4
+ require_relative "embedding_response"
@@ -0,0 +1,46 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "securerandom"
4
+
5
module InstLLM
  module Response
    # Provider-independent result of a chat request.
    class ChatResponse
      attr_reader :created, :fingerprint, :stop_reason, :message, :model, :usage

      # Builds a response from a parsed Claude reply: the text of the first
      # content entry, the stop reason and the reported token counts.
      def self.from_claude(model:, response:)
        text = response["content"][0]["text"]
        counts = response["usage"]

        new(
          model: model,
          message: { role: :assistant, content: text },
          stop_reason: response["stop_reason"],
          usage: { input_tokens: counts["input_tokens"], output_tokens: counts["output_tokens"] }
        )
      end

      # Builds a response from a parsed Mistral reply using its first output.
      # Both token counts are set to -1.
      def self.from_mistral(model:, response:)
        first_output = response["outputs"][0]

        new(
          model: model,
          message: { role: :assistant, content: first_output["text"] },
          stop_reason: first_output["stop_reason"],
          usage: { input_tokens: -1, output_tokens: -1 }
        )
      end

      # Stamps the response with the current epoch seconds and a random UUID
      # in addition to storing the given fields.
      def initialize(model:, message:, stop_reason:, usage:)
        @model = model
        @message = message
        @stop_reason = stop_reason
        @usage = usage
        @created = Time.now.to_i
        @fingerprint = SecureRandom.uuid
      end
    end
  end
end
@@ -0,0 +1,23 @@
1
+ # frozen_string_literal: true
2
+
3
module InstLLM
  module Response
    # Provider-independent result of an embedding request.
    class EmbeddingResponse
      attr_reader :model, :embeddings

      # Builds a response from a parsed Cohere Embed reply, turning each entry
      # of "embeddings" into a hash tagged with its position in the list.
      def self.from_cohere_embed(model:, response:)
        entries = response["embeddings"].each_with_index.map do |vector, position|
          { object: "embedding", embedding: vector, index: position }
        end

        new(model, entries)
      end

      # @param model [Object] model the embeddings came from
      # @param embeddings [Array] embedding entries, stored as given
      def initialize(model, embeddings)
        @embeddings = embeddings
        @model = model
      end
    end
  end
end
@@ -0,0 +1,5 @@
1
+ # frozen_string_literal: true
2
+
3
module InstLLM
  # Version string of this gem release.
  VERSION = "0.2.0".freeze
end
data/lib/inst_llm.rb ADDED
@@ -0,0 +1,13 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "inst_llm/version"
4
+ require_relative "inst_llm/client"
5
+
6
module InstLLM
  # Fallback for provider failures that have no more specific class below.
  class Error < StandardError
  end

  # Raised for an unsupported model, provider or message role.
  class UnknownArgumentError < StandardError
  end

  # Counterpart of the provider's service-quota-exceeded failure.
  class ServiceQuotaExceededError < StandardError
  end

  # Counterpart of the provider's throttling failure.
  class ThrottlingError < StandardError
  end

  # Validation failure whose provider message contains "too long".
  class ValidationTooLongError < StandardError
  end

  # Any other validation failure reported by the provider.
  class ValidationError < StandardError
  end
end
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: inst_llm
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.0
4
+ version: 0.2.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Zach Pendleton
@@ -30,7 +30,17 @@ email:
30
30
  executables: []
31
31
  extensions: []
32
32
  extra_rdoc_files: []
33
- files: []
33
+ files:
34
+ - lib/inst_llm.rb
35
+ - lib/inst_llm/client.rb
36
+ - lib/inst_llm/parameter/all.rb
37
+ - lib/inst_llm/parameter/claude_parameters.rb
38
+ - lib/inst_llm/parameter/cohere_embed_parameters.rb
39
+ - lib/inst_llm/parameter/mistral_parameters.rb
40
+ - lib/inst_llm/response/all.rb
41
+ - lib/inst_llm/response/chat_response.rb
42
+ - lib/inst_llm/response/embedding_response.rb
43
+ - lib/inst_llm/version.rb
34
44
  homepage: https://instructure.com
35
45
  licenses:
36
46
  - MIT