inst_llm 0.1.0 → 0.2.1

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 33b83be8e20977ac085278558163be88d0acc3b35436e2360bb7e8d0102cb907
4
- data.tar.gz: a47ac4068509bef5454dc37904ff2752203194fd3d855a06d126c2d0d9f425e8
3
+ metadata.gz: 9dd355405ae5b7659a6cee01536e9ca422a0658208e125303a6206e4e6129863
4
+ data.tar.gz: 8e0955a5bf08923385e093a7a6d739ff224695a345d52684d2c435675e3aa4c9
5
5
  SHA512:
6
- metadata.gz: 313914f2b440882aab196834e08f9ca9eed991380fd17174c74a97656a9a00eb34809a4e2b03ad87201a1c7521698288dd57b71cfacc32ee6b5b265b0737033e
7
- data.tar.gz: 26467fb44454136ccb202219d88991c2c3f8a5a304cbda5707929ce98f059fdb231400b1f356cc2037df62f8d85a3eea7d1d79220575bc4679ede4a4141b3914
6
+ metadata.gz: 787201e52ae44238a3a22d53a93f83fd15718b251a49b46fded527aab1e30eda9fe805574c1f42c0369a5586679a9d8f27c0cae5bb689c47c2c8fb65355e81b3
7
+ data.tar.gz: 14ea0088107fc9be654483530f314b5f00ea4d72686326b0d6c1349470446de8b8731fe06313081a1b78c78d6b9f2d8bff125752e60183b2dc3978626ffaf83a
@@ -0,0 +1,126 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "aws-sdk-bedrockruntime"
4
+ require "json"
5
+
6
+ require_relative "parameter/all"
7
+ require_relative "response/all"
8
+
9
module InstLLM
  # Entry point for chat and embedding requests against the supported models.
  #
  # A client is bound to a default model at construction time; individual
  # calls may override it with the :model (or :model_id) option.
  class Client
    # Supported model ids mapped to their request/response :format, the
    # :provider that serves them, and whether they are :chat or :embedding
    # models.
    MODELS = {
      "anthropic.claude-3-sonnet-20240229-v1:0": { format: :claude, provider: :bedrock, type: :chat },
      "anthropic.claude-3-haiku-20240307-v1:0": { format: :claude, provider: :bedrock, type: :chat },

      "mistral.mistral-7b-instruct-v0:2": { format: :mistral, provider: :bedrock, type: :chat },
      "mistral.mixtral-8x7b-instruct-v0:1": { format: :mistral, provider: :bedrock, type: :chat },
      "mistral.mistral-large-2402-v1:0": { format: :mistral, provider: :bedrock, type: :chat },

      "cohere.embed-english-v3": { format: :cohere_embed, provider: :bedrock, type: :embedding },
      "cohere.embed-multilingual-v3": { format: :cohere_embed, provider: :bedrock, type: :embedding },
    }.freeze

    # @param model [String, Symbol] default model id; must be a key of MODELS
    # @param options [Hash] forwarded verbatim to the provider SDK client
    #   (Aws::BedrockRuntime::Client.new for :bedrock models)
    # @raise [UnknownArgumentError] if the model is not in MODELS
    def initialize(model, **options)
      model = model.to_sym
      raise UnknownArgumentError, "Unknown model: #{model}" unless MODELS.key?(model)

      @model = model
      @options = options
    end

    # Sends a chat conversation to a chat model.
    #
    # @param messages [Array<Hash>] chat messages, passed to the model's
    #   parameter builder
    # @param options [Hash] per-call settings; :model / :model_id override the
    #   client's default model, other keys go to the parameter builder
    # @return [Response::ChatResponse]
    # @raise [ArgumentError] if the resolved model is unknown or not a chat model
    def chat(messages, **options)
      model = resolve_model(options)
      raise ArgumentError, "Model #{model} is not a chat model" unless chat_model?(model)

      response_factory(model, call(model, messages, **options))
    end

    # Requests embeddings for the given input from an embedding model.
    #
    # @param message [Object] input to embed, passed to the model's parameter
    #   builder
    # @param options [Hash] per-call settings; :model / :model_id override the
    #   client's default model
    # @return [Response::EmbeddingResponse]
    # @raise [ArgumentError] if the resolved model is unknown or not an
    #   embedding model
    def embedding(message, **options)
      model = resolve_model(options)
      raise ArgumentError, "Model #{model} is not an embedding model" unless embedding_model?(model)

      embedding_response_factory(model, call(model, message, **options))
    end

    private

    # Picks the per-call override (:model first, then :model_id) or falls back
    # to the client's default model.
    def resolve_model(options)
      (options[:model] || options[:model_id] || @model).to_sym
    end

    # Invokes the model and returns the parsed JSON response body. Any error
    # from the SDK call is re-raised as one of this gem's error classes (the
    # original error remains available as #cause).
    def call(model, messages, **options)
      params = params_factory(model, messages, **options)

      begin
        res = client.invoke_model(
          content_type: "application/json",
          **params
        )
      rescue => error
        raise map_error_type(error)
      end

      JSON.parse(res.body.read)
    end

    # dig, not [], so a model id missing from MODELS yields false instead of
    # crashing with NoMethodError on nil.
    def chat_model?(model)
      MODELS.dig(model, :type) == :chat
    end

    def embedding_model?(model)
      MODELS.dig(model, :type) == :embedding
    end

    # Lazily builds and memoizes the SDK client for the default model's provider.
    def client
      @client ||=
        case MODELS[@model][:provider]
        when :bedrock then Aws::BedrockRuntime::Client.new(**@options)
        else raise UnknownArgumentError
        end
    end

    # Translates a provider exception into the matching gem error, keeping the
    # provider's message.
    def map_error_type(error)
      raise UnknownArgumentError unless MODELS[@model][:provider] == :bedrock

      error_class =
        case error
        when Aws::BedrockRuntime::Errors::ServiceQuotaExceededException then ServiceQuotaExceededError
        when Aws::BedrockRuntime::Errors::ThrottlingException then ThrottlingError
        when Aws::BedrockRuntime::Errors::ValidationException
          error.message.include?("too long") ? ValidationTooLongError : ValidationError
        else Error
        end

      error_class.new(error.message)
    end

    # Builds the request-parameter object for the model's format. Instances
    # respond to #to_hash, which lets #call splat them into invoke_model.
    def params_factory(model, messages, **options)
      params_table = {
        claude: Parameter::ClaudeParameters,
        cohere_embed: Parameter::CohereEmbedParameters,
        mistral: Parameter::MistralParameters
      }
      params_table[MODELS[model][:format]].new(model: model, messages: messages, **options)
    end

    def embedding_response_factory(model, response)
      Response::EmbeddingResponse.public_send(:"from_#{MODELS[model][:format]}", model: model, response: response)
    end

    def response_factory(model, response)
      Response::ChatResponse.public_send(:"from_#{MODELS[model][:format]}", model: model, response: response)
    end
  end
end
@@ -0,0 +1,5 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative 'claude_parameters'
4
+ require_relative 'cohere_embed_parameters'
5
+ require_relative 'mistral_parameters'
@@ -0,0 +1,27 @@
1
+ # frozen_string_literal: true
2
+
3
module InstLLM
  module Parameter
    # Builds an InvokeModel request for Anthropic Claude models.
    #
    # Only the keys listed in DEFAULT_OPTIONS are forwarded in the request
    # body; other options are ignored, and settings whose value is nil are
    # left out.
    class ClaudeParameters
      DEFAULT_OPTIONS = {
        anthropic_version: "bedrock-2023-05-31",
        max_tokens: 2000,
        stop_sequences: nil,
        temperature: nil,
        top_k: nil,
        top_p: nil,
        system: nil,
      }.freeze

      # @param model [Symbol, String] model id sent as model_id
      # @param messages [Array<Hash>] chat messages, serialized as-is
      # @param options [Hash] request settings; unknown keys are dropped
      def initialize(model:, messages: [], **options)
        @messages = messages
        @model = model
        @options = DEFAULT_OPTIONS.each_with_object({}) do |(key, default), settings|
          value = options.key?(key) ? options[key] : default
          settings[key] = value unless value.nil?
        end
      end

      # Keyword arguments for invoke_model. Naming this #to_hash lets callers
      # splat an instance directly with `**`.
      def to_hash
        payload = { messages: @messages }
        payload.update(@options)
        { model_id: @model, body: payload.to_json }
      end
    end
  end
end
@@ -0,0 +1,22 @@
1
+ # frozen_string_literal: true
2
+
3
module InstLLM
  module Parameter
    # Builds an InvokeModel request for Cohere embedding models.
    class CohereEmbedParameters
      DEFAULT_OPTIONS = {
        input_type: nil,
        truncate: nil
      }.freeze

      # @param model [Symbol, String] model id sent as model_id
      # @param texts [String, Array<String>, nil] text(s) to embed
      # @param messages [String, Array<String>, nil] alias for +texts+, used
      #   when +texts+ is nil. Client#params_factory passes its input under this
      #   keyword; previously it landed in **options and was discarded, so an
      #   empty "texts" list was sent.
      # @param options [Hash] request settings; only DEFAULT_OPTIONS keys are
      #   kept, and nil values are omitted
      def initialize(model:, texts: nil, messages: nil, **options)
        @model = model
        @texts = normalize_texts(texts.nil? ? messages : texts)
        @options = DEFAULT_OPTIONS.merge(options.slice(*DEFAULT_OPTIONS.keys)).compact
      end

      # Keyword arguments for invoke_model (splattable via `**`).
      def to_hash
        { model_id: @model, body: { texts: @texts }.merge(@options).to_json }
      end

      private

      # The "texts" field is a list: wrap a single value and treat nil as empty.
      # Explicit checks instead of Kernel#Array, which would split a Hash into pairs.
      def normalize_texts(value)
        case value
        when nil then []
        when Array then value
        else [value]
        end
      end
    end
  end
end
@@ -0,0 +1,51 @@
1
+ # frozen_string_literal: true
2
+
3
module InstLLM
  module Parameter
    # Builds an InvokeModel request for Mistral instruct models. The chat
    # messages are flattened into Mistral's single-string prompt format.
    class MistralParameters
      DEFAULT_OPTIONS = {
        max_tokens: nil,
        stop: nil,
        temperature: nil,
        top_p: nil,
        top_k: nil
      }.freeze

      # @param model [Symbol, String] model id sent as model_id
      # @param messages [Array<Hash>] messages with :role (:system, :user or
      #   :assistant; strings are accepted) and :content
      # @param options [Hash] request settings; only DEFAULT_OPTIONS keys are
      #   kept, and nil values are omitted
      def initialize(model:, messages:, **options)
        @model = model
        @messages = messages
        @options = DEFAULT_OPTIONS.merge(options.slice(*DEFAULT_OPTIONS.keys)).compact
      end

      # Keyword arguments for invoke_model (splattable via `**`).
      def to_hash
        { model_id: @model, body: { prompt: prompt }.merge(@options).to_json }
      end

      private

      # Renders the conversation as
      # "<s>[INST] user [/INST]\n\nassistant\n\n[INST] user [/INST]...".
      # A system message (the last one wins) is not emitted as a turn of its
      # own. It is prepended to the first user turn instead.
      #
      # @raise [UnknownArgumentError] for any other role
      def prompt
        system_message = nil
        turns = @messages.map do |message|
          case message[:role].to_sym
          when :assistant
            message[:content].to_s
          when :system
            system_message = message[:content]
            nil # dropped below; previously the content leaked in as a bare turn
          when :user
            "[INST] #{message[:content]} [/INST]"
          else
            raise UnknownArgumentError
          end
        end

        text = "<s>#{turns.compact.join("\n\n")}"
        return text unless system_message

        # The result of the non-mutating `sub` must be returned (it used to be
        # discarded). The block form keeps backslashes in the system message
        # from being read as back-references such as \0.
        text.sub("[INST] ") { "[INST] #{system_message}\n" }
      end
    end
  end
end
@@ -0,0 +1,4 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "chat_response"
4
+ require_relative "embedding_response"
@@ -0,0 +1,46 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "securerandom"
4
+
5
module InstLLM
  module Response
    # Provider-neutral result of a chat completion. Each instance is stamped
    # with its creation time (integer Unix seconds) and a random UUID
    # fingerprint.
    class ChatResponse
      attr_reader :created, :fingerprint, :stop_reason, :message, :model, :usage

      # @param model [Symbol, String] model id that produced the response
      # @param message [Hash] the assistant message, { role:, content: }
      # @param stop_reason [Object] stop reason as reported by the provider
      # @param usage [Hash] { input_tokens:, output_tokens: }
      def initialize(model:, message:, stop_reason:, usage:)
        @created = Time.now.to_i
        @fingerprint = SecureRandom.uuid
        @message = message
        @model = model
        @stop_reason = stop_reason
        @usage = usage
      end

      # Builds a response from a parsed Claude response body. Only the text of
      # the first content block is kept.
      def self.from_claude(model:, response:)
        reply_text = response["content"][0]["text"]
        reason = response["stop_reason"]
        counts = response["usage"]
        token_usage = { input_tokens: counts["input_tokens"], output_tokens: counts["output_tokens"] }

        new(
          model: model,
          message: { role: :assistant, content: reply_text },
          stop_reason: reason,
          usage: token_usage
        )
      end

      # Builds a response from a parsed Mistral response body (first output
      # only). Token counts are reported as -1 because they are not read from
      # this response.
      def self.from_mistral(model:, response:)
        first_output = response["outputs"][0]

        new(
          model: model,
          message: { role: :assistant, content: first_output["text"] },
          stop_reason: first_output["stop_reason"],
          usage: { input_tokens: -1, output_tokens: -1 }
        )
      end
    end
  end
end
@@ -0,0 +1,23 @@
1
+ # frozen_string_literal: true
2
+
3
module InstLLM
  module Response
    # Provider-neutral result of an embedding request.
    class EmbeddingResponse
      attr_reader :model, :embeddings

      # @param model [Symbol, String] model id that produced the embeddings
      # @param embeddings [Array<Hash>] entries shaped like
      #   { object: "embedding", embedding: <vector>, index: <position> }
      def initialize(model, embeddings)
        @model = model
        @embeddings = embeddings
      end

      # Builds a response from a parsed Cohere embed body. Each vector is
      # tagged with its position in the "embeddings" array.
      def self.from_cohere_embed(model:, response:)
        vectors = response["embeddings"]
        entries = vectors.each_with_index.map do |vector, position|
          { object: "embedding", embedding: vector, index: position }
        end
        new(model, entries)
      end
    end
  end
end
@@ -0,0 +1,5 @@
1
+ # frozen_string_literal: true
2
+
3
module InstLLM
  # Current release of the inst_llm gem.
  VERSION = "0.2.1"
end
data/lib/inst_llm.rb ADDED
@@ -0,0 +1,13 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "inst_llm/version"
4
+ require_relative "inst_llm/client"
5
+
6
module InstLLM
  # Root of every error this gem raises, so callers can `rescue InstLLM::Error`
  # to handle all of them. Also raised directly for provider errors that have
  # no more specific mapping.
  class Error < StandardError; end

  # Unsupported model, provider, or message role.
  class UnknownArgumentError < Error; end

  # The provider reported that a service quota was exceeded.
  class ServiceQuotaExceededError < Error; end

  # The provider throttled the request.
  class ThrottlingError < Error; end

  # The provider rejected the request with a validation error whose message
  # says the input is "too long".
  class ValidationTooLongError < Error; end

  # The provider rejected the request for any other validation reason.
  class ValidationError < Error; end
end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: inst_llm
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.0
4
+ version: 0.2.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Zach Pendleton
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2024-04-18 00:00:00.000000000 Z
11
+ date: 2024-04-23 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: aws-sdk-bedrockruntime
@@ -30,7 +30,17 @@ email:
30
30
  executables: []
31
31
  extensions: []
32
32
  extra_rdoc_files: []
33
- files: []
33
+ files:
34
+ - lib/inst_llm.rb
35
+ - lib/inst_llm/client.rb
36
+ - lib/inst_llm/parameter/all.rb
37
+ - lib/inst_llm/parameter/claude_parameters.rb
38
+ - lib/inst_llm/parameter/cohere_embed_parameters.rb
39
+ - lib/inst_llm/parameter/mistral_parameters.rb
40
+ - lib/inst_llm/response/all.rb
41
+ - lib/inst_llm/response/chat_response.rb
42
+ - lib/inst_llm/response/embedding_response.rb
43
+ - lib/inst_llm/version.rb
34
44
  homepage: https://instructure.com
35
45
  licenses:
36
46
  - MIT