mistral 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,229 @@
1
# frozen_string_literal: true

require 'mistral/client_base'

module Mistral
  # Synchronous client for the Mistral AI API.
  #
  # Holds a persistent HTTP connection to the configured endpoint and exposes
  # the chat, chat_stream, embeddings and list_models endpoints.
  class Client < ClientBase
    # @param api_key [String, nil] API key, defaults to the MISTRAL_API_KEY env var.
    # @param endpoint [String] Base URL of the API, defaults to Mistral::ENDPOINT.
    # @param max_retries [Integer] Maximum attempts for retryable status codes.
    # @param timeout [Integer] Request timeout in seconds.
    def initialize(
      api_key: ENV['MISTRAL_API_KEY'],
      endpoint: ENDPOINT,
      max_retries: 5,
      timeout: 120
    )
      super(endpoint: endpoint, api_key: api_key, max_retries: max_retries, timeout: timeout)

      # Bug fix: connect to the configured +endpoint+ rather than the
      # hard-coded ENDPOINT constant, so custom endpoints (e.g. Azure
      # inference endpoints) are actually honoured.
      @client = HTTP.persistent(endpoint)
                    .follow
                    .timeout(timeout)
                    .use(:line_iterable_body)
                    .headers(
                      'Accept' => 'application/json',
                      'User-Agent' => "mistral-client-ruby/#{VERSION}",
                      'Authorization' => "Bearer #{@api_key}",
                      'Content-Type' => 'application/json'
                    )
    end

    # A chat endpoint that returns a single response.
    #
    # @param messages [Array<ChatMessage>] An array of messages to chat with, e.g.
    #   [{role: 'user', content: 'What is the best French cheese?'}]
    # @param model [String] The name of the model to chat with, e.g. mistral-tiny
    # @param tools [Array<Hash>] A list of tools to use.
    # @param temperature [Float] The temperature to use for sampling, e.g. 0.5.
    # @param max_tokens [Integer] The maximum number of tokens to generate, e.g. 100.
    # @param top_p [Float] The cumulative probability of tokens to generate, e.g. 0.9.
    # @param random_seed [Integer] The random seed to use for sampling, e.g. 42.
    # @param safe_mode [Boolean] Deprecated, use safe_prompt instead.
    # @param safe_prompt [Boolean] Whether to use safe prompt, e.g. true.
    # @param tool_choice [String, ToolChoice] The tool choice.
    # @param response_format [Hash<String, String>, ResponseFormat] The response format.
    # @return [ChatCompletionResponse] A response object containing the generated text.
    # @raise [Mistral::Error] When no response is received.
    #
    def chat(
      messages:,
      model: nil,
      tools: nil,
      temperature: nil,
      max_tokens: nil,
      top_p: nil,
      random_seed: nil,
      safe_mode: false,
      safe_prompt: false,
      tool_choice: nil,
      response_format: nil
    )
      # Named +payload+ (not +request+) to avoid shadowing the private
      # #request method below.
      payload = make_chat_request(
        messages: messages,
        model: model,
        tools: tools,
        temperature: temperature,
        max_tokens: max_tokens,
        top_p: top_p,
        random_seed: random_seed,
        stream: false,
        safe_prompt: safe_mode || safe_prompt,
        tool_choice: tool_choice,
        response_format: response_format
      )

      single_response = request('post', 'v1/chat/completions', json: payload)

      single_response.each do |response|
        return ChatCompletionResponse.new(response)
      end

      # Error is a plain StandardError: raise with a positional message so
      # e.message is the string itself, not an inspected keyword hash.
      raise Mistral::Error, 'No response received'
    end

    # A chat endpoint that streams responses.
    #
    # @param messages [Array<Any>] An array of messages to chat with, e.g.
    #   [{role: 'user', content: 'What is the best French cheese?'}]
    # @param model [String] The name of the model to chat with, e.g. mistral-tiny
    # @param tools [Array<Hash>] A list of tools to use.
    # @param temperature [Float] The temperature to use for sampling, e.g. 0.5.
    # @param max_tokens [Integer] The maximum number of tokens to generate, e.g. 100.
    # @param top_p [Float] The cumulative probability of tokens to generate, e.g. 0.9.
    # @param random_seed [Integer] The random seed to use for sampling, e.g. 42.
    # @param safe_mode [Boolean] Deprecated, use safe_prompt instead.
    # @param safe_prompt [Boolean] Whether to use safe prompt, e.g. true.
    # @param tool_choice [String, ToolChoice] The tool choice.
    # @param response_format [Hash<String, String>, ResponseFormat] The response format.
    # @return [Enumerator<ChatCompletionStreamResponse>] A generator that yields
    #   ChatCompletionStreamResponse objects.
    #
    def chat_stream(
      messages:,
      model: nil,
      tools: nil,
      temperature: nil,
      max_tokens: nil,
      top_p: nil,
      random_seed: nil,
      safe_mode: false,
      safe_prompt: false,
      tool_choice: nil,
      response_format: nil
    )
      payload = make_chat_request(
        messages: messages,
        model: model,
        tools: tools,
        temperature: temperature,
        max_tokens: max_tokens,
        top_p: top_p,
        random_seed: random_seed,
        stream: true,
        safe_prompt: safe_mode || safe_prompt,
        tool_choice: tool_choice,
        response_format: response_format
      )

      # Lazily consume the SSE stream; nothing is sent until iteration starts.
      Enumerator.new do |yielder|
        request('post', 'v1/chat/completions', json: payload, stream: true).each do |json_response|
          yielder << ChatCompletionStreamResponse.new(**json_response)
        end
      end
    end

    # An embeddings endpoint that returns embeddings for a single, or batch of inputs
    #
    # @param model [String] The embedding model to use, e.g. mistral-embed
    # @param input [String, Array<String>] The input to embed, e.g. ['What is the best French cheese?']
    #
    # @return [EmbeddingResponse] A response object containing the embeddings.
    # @raise [Mistral::Error] When no response is received.
    #
    def embeddings(model:, input:)
      payload = { model: model, input: input }
      singleton_response = request('post', 'v1/embeddings', json: payload)

      singleton_response.each do |response|
        return EmbeddingResponse.new(response)
      end

      raise Mistral::Error, 'No response received'
    end

    # Returns a list of the available models
    #
    # @return [ModelList] A response object containing the list of models.
    # @raise [Mistral::Error] When no response is received.
    #
    def list_models
      singleton_response = request('get', 'v1/models')

      singleton_response.each do |response|
        return ModelList.new(response)
      end

      raise Mistral::Error, 'No response received'
    end

    private

    # Performs an HTTP request and yields parsed JSON payloads.
    #
    # Returns an Enumerator so that streaming and non-streaming calls share
    # one code path: non-streaming yields exactly one parsed body, streaming
    # yields one payload per SSE "data:" line. Retryable status codes
    # (RETRY_STATUS_CODES) are retried with exponential backoff up to
    # @max_retries attempts.
    def request(method, path, json: nil, stream: false, attempt: 1)
      url = File.join(@endpoint, path)
      headers = {}
      headers['Accept'] = 'text/event-stream' if stream

      @logger.debug("Sending request: #{method.upcase} #{url} #{json}")

      Enumerator.new do |yielder|
        response = @client.headers(headers).request(method.downcase.to_sym, url, json: json)
        check_response_status_codes(response)

        if stream
          response.body.each_line do |line|
            processed_line = process_line(line)
            next if processed_line.nil?

            yielder << processed_line
          end
        else
          yielder << check_response(response)
        end
      rescue HTTP::ConnectionError => e
        raise Mistral::ConnectionError, e.message
      rescue HTTP::RequestError => e
        raise Mistral::Error, "Unexpected exception (#{e.class}): #{e.message}"
      rescue JSON::ParserError
        raise Mistral::APIError.from_response(response, message: "Failed to decode json body: #{response.body}")
      rescue Mistral::APIStatusError => e
        attempt += 1

        raise Mistral::APIStatusError.from_response(response, message: e.message) if attempt > @max_retries

        backoff = 2.0**attempt # exponential backoff
        sleep(backoff)

        # Retry and yield the response
        request(method, path, json: json, stream: stream, attempt: attempt).each do |r|
          yielder << r
        end
      end
    end

    # Parses a non-streaming response body and raises on API-level errors.
    def check_response(response)
      check_response_status_codes(response)

      json_response = JSON.parse(response.body.to_s)

      if !json_response.key?('object')
        raise Mistral::Error, "Unexpected response: #{json_response}"
      elsif json_response['object'] == 'error' # has errors
        raise Mistral::APIError.from_response(response, message: json_response['message'])
      end

      json_response
    end

    # Maps HTTP status codes to errors: retryable codes raise APIStatusError
    # (caught and retried in #request), other 4xx raise APIError, and any
    # remaining 5xx (e.g. 501) raise a plain Error.
    def check_response_status_codes(response)
      if RETRY_STATUS_CODES.include?(response.code)
        raise APIStatusError.from_response(response, message: "Status: #{response.code}. Message: #{response.body}")
      elsif response.code >= 400 && response.code < 500
        raise APIError.from_response(response, message: "Status: #{response.code}. Message: #{response.body}")
      elsif response.code >= 500
        raise Mistral::Error, "Status: #{response.code}. Message: #{response.body}"
      end
    end
  end
end
@@ -0,0 +1,126 @@
1
# frozen_string_literal: true

module Mistral
  # Shared behaviour for API clients: configuration, chat request payload
  # construction, and SSE line parsing.
  class ClientBase
    attr_reader :endpoint, :api_key, :max_retries, :timeout

    # @param endpoint [String] Base URL of the API.
    # @param api_key [String, nil] API key used by subclasses for auth headers.
    # @param max_retries [Integer] Maximum attempts for retryable status codes.
    # @param timeout [Integer] Request timeout in seconds.
    def initialize(endpoint:, api_key: nil, max_retries: 5, timeout: 120)
      @endpoint = endpoint
      @api_key = api_key
      @max_retries = max_retries
      @timeout = timeout

      @logger = config_logger

      # For azure endpoints, we default to the mistral model
      @default_model = 'mistral' if endpoint.include?('inference.azure.com')
    end

    protected

    # Normalizes tools into plain hashes; only tools of type 'function' are
    # kept, and Function structs are converted via #to_h.
    def parse_tools(tools)
      parsed_tools = []

      tools.each do |tool|
        next unless tool['type'] == 'function'

        parsed_function = {}
        parsed_function['type'] = tool['type']
        parsed_function['function'] = if tool['function'].is_a?(Function)
                                        tool['function'].to_h
                                      else
                                        tool['function']
                                      end

        parsed_tools << parsed_function
      end

      parsed_tools
    end

    # Accepts either a ToolChoice value or a raw string.
    def parse_tool_choice(tool_choice)
      tool_choice.is_a?(ToolChoice) ? tool_choice.to_s : tool_choice
    end

    # Accepts either a ResponseFormat struct or a raw hash.
    def parse_response_format(response_format)
      if response_format.is_a?(ResponseFormat)
        response_format.to_h
      else
        response_format
      end
    end

    # Converts ChatMessage structs to hashes; raw hashes pass through.
    def parse_messages(messages)
      parsed_messages = []

      messages.each do |message|
        parsed_messages << if message.is_a?(ChatMessage)
                             message.to_h
                           else
                             message
                           end
      end

      parsed_messages
    end

    # Builds the JSON payload for the chat completions endpoint. Optional
    # fields are only included when explicitly provided.
    #
    # @raise [Mistral::Error] If no model is given and no default applies.
    def make_chat_request(
      messages:,
      model: nil,
      tools: nil,
      temperature: nil,
      max_tokens: nil,
      top_p: nil,
      random_seed: nil,
      stream: nil,
      safe_prompt: false,
      tool_choice: nil,
      response_format: nil
    )
      request_data = {
        messages: parse_messages(messages),
        safe_prompt: safe_prompt
      }

      request_data[:model] = if model.nil?
                               # Error is a plain StandardError: raise with a positional
                               # message so e.message is the string, not a keyword hash.
                               raise Mistral::Error, 'model must be provided' if @default_model.nil?

                               @default_model
                             else
                               model
                             end

      request_data[:tools] = parse_tools(tools) unless tools.nil?
      request_data[:temperature] = temperature unless temperature.nil?
      request_data[:max_tokens] = max_tokens unless max_tokens.nil?
      request_data[:top_p] = top_p unless top_p.nil?
      request_data[:random_seed] = random_seed unless random_seed.nil?
      request_data[:stream] = stream unless stream.nil?
      request_data[:tool_choice] = parse_tool_choice(tool_choice) unless tool_choice.nil?
      request_data[:response_format] = parse_response_format(response_format) unless response_format.nil?

      @logger.debug("Chat request: #{request_data}")

      request_data
    end

    # Parses a single SSE line. Returns the parsed JSON payload, or nil for
    # non-data lines and the terminal '[DONE]' sentinel.
    def process_line(line)
      return unless line.start_with?('data: ')

      line = line[6..].to_s.strip # drop the 'data: ' prefix
      return if line == '[DONE]'

      JSON.parse(line)
    end

    # Builds a stdout logger; level comes from LOG_LEVEL (default ERROR).
    def config_logger
      Logger.new($stdout).tap do |logger|
        logger.level = ENV.fetch('LOG_LEVEL', 'ERROR')

        logger.formatter = proc do |severity, datetime, progname, msg|
          "#{datetime.strftime("%Y-%m-%d %H:%M:%S")} #{severity} #{progname}: #{msg}\n"
        end
      end
    end
  end
end
@@ -0,0 +1,6 @@
1
# frozen_string_literal: true

module Mistral
  # HTTP status codes considered transient; requests receiving one of these
  # are retried with exponential backoff by the client.
  RETRY_STATUS_CODES = [429, 500, 502, 503, 504].freeze

  # Default base URL of the Mistral AI API.
  ENDPOINT = 'https://api.mistral.ai'
end
@@ -0,0 +1,38 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Mistral
4
+ # Base error class, returned when nothing more specific applies
5
+ class Error < StandardError
6
+ end
7
+
8
+ class APIError < Error
9
+ attr_reader :http_status, :headers
10
+
11
+ def initialize(message: nil, http_status: nil, headers: nil)
12
+ super(message: message)
13
+
14
+ @http_status = http_status
15
+ @headers = headers
16
+ end
17
+
18
+ def self.from_response(response, message: nil)
19
+ new(
20
+ message: message || response.to_s,
21
+ http_status: response.code,
22
+ headers: response.headers.to_h
23
+ )
24
+ end
25
+
26
+ def to_s
27
+ "#{self.class.name}(message=#{super}, http_status=#{http_status})"
28
+ end
29
+ end
30
+
31
+ # Returned when we receive a non-200 response from the API that we should retry
32
+ class APIStatusError < APIError
33
+ end
34
+
35
+ # Returned when the SDK can not reach the API server for any reason
36
+ class ConnectionError < Error
37
+ end
38
+ end
@@ -0,0 +1,95 @@
1
# frozen_string_literal: true

module Mistral
  # Schema describing a callable function exposed to the model as a tool.
  class Function < Dry::Struct
    transform_keys(&:to_sym)

    attribute :name, Types::Strict::String
    attribute :description, Types::Strict::String
    attribute :parameters, Types::Strict::Hash
  end

  # Only 'function' tools are currently supported by the API.
  ToolType = Types::Strict::String.default('function').enum('function')

  # A function invocation requested by the model; arguments arrive as a
  # JSON-encoded string.
  class FunctionCall < Dry::Struct
    transform_keys(&:to_sym)

    attribute :name, Types::Strict::String
    attribute :arguments, Types::Strict::String
  end

  # A single tool call within an assistant message.
  class ToolCall < Dry::Struct
    transform_keys(&:to_sym)

    # NOTE(review): default id is the literal string 'null' — presumably
    # mirroring the API's JSON null; verify against the upstream SDK.
    attribute :id, Types::Strict::String.default('null')
    attribute :type, ToolType
    attribute :function, FunctionCall
  end

  # Allowed response formats for chat completions.
  ResponseFormats = Types::Strict::String.default('text').enum('text', 'json_object')

  # Allowed tool-choice strategies.
  ToolChoice = Types::Strict::String.enum('auto', 'any', 'none')

  # Wrapper selecting one of the ResponseFormats values.
  class ResponseFormat < Dry::Struct
    transform_keys(&:to_sym)

    attribute :type, ResponseFormats
  end

  # One message of a conversation (request or non-streaming response).
  class ChatMessage < Dry::Struct
    transform_keys(&:to_sym)

    attribute :role, Types::Strict::String
    attribute :content, Types::Strict::Array.of(Types::Strict::String) | Types::Strict::String
    attribute? :name, Types::String.optional
    attribute? :tool_calls, Types::Strict::Array.of(ToolCall).optional
  end

  # Incremental message fragment delivered while streaming; all fields are
  # optional because each chunk carries only what changed.
  class DeltaMessage < Dry::Struct
    transform_keys(&:to_sym)

    attribute? :role, Types::Strict::String.optional
    attribute? :content, Types::Strict::String.optional
    attribute? :tool_calls, Types::Strict::Array.of(ToolCall).optional
  end

  # Why the model stopped generating.
  FinishReason = Types::Strict::String.enum('stop', 'length', 'error', 'tool_calls')

  # One choice within a streaming response chunk.
  class ChatCompletionResponseStreamChoice < Dry::Struct
    transform_keys(&:to_sym)

    attribute :index, Types::Strict::Integer
    attribute :delta, DeltaMessage
    attribute? :finish_reason, FinishReason.optional
  end

  # One SSE chunk of a streamed chat completion.
  class ChatCompletionStreamResponse < Dry::Struct
    transform_keys(&:to_sym)

    attribute :id, Types::Strict::String
    attribute :model, Types::Strict::String
    attribute :choices, Types::Strict::Array.of(ChatCompletionResponseStreamChoice)
    attribute? :created, Types::Strict::Integer.optional
    attribute? :object, Types::Strict::String.optional
    # UsageInfo is typically present only on the final chunk.
    attribute? :usage, UsageInfo.optional
  end

  # One choice within a non-streaming chat completion response.
  class ChatCompletionResponseChoice < Dry::Struct
    transform_keys(&:to_sym)

    attribute :index, Types::Strict::Integer
    attribute :message, ChatMessage
    attribute? :finish_reason, FinishReason.optional
  end

  # A complete (non-streaming) chat completion response.
  class ChatCompletionResponse < Dry::Struct
    transform_keys(&:to_sym)

    attribute :id, Types::Strict::String
    attribute :object, Types::Strict::String
    attribute :created, Types::Strict::Integer
    attribute :model, Types::Strict::String
    attribute :choices, Types::Strict::Array.of(ChatCompletionResponseChoice)
    attribute :usage, UsageInfo
  end
end
@@ -0,0 +1,11 @@
1
# frozen_string_literal: true

module Mistral
  # Token accounting attached to API responses.
  class UsageInfo < Dry::Struct
    transform_keys(&:to_sym)

    attribute :prompt_tokens, Types::Strict::Integer
    attribute :total_tokens, Types::Strict::Integer
    # Absent for endpoints that do not generate text (e.g. embeddings).
    attribute? :completion_tokens, Types::Strict::Integer.optional
  end
end
@@ -0,0 +1,21 @@
1
# frozen_string_literal: true

module Mistral
  # A single embedding vector with its position in the input batch.
  class EmbeddingObject < Dry::Struct
    transform_keys(&:to_sym)

    attribute :object, Types::Strict::String
    attribute :embedding, Types::Strict::Array.of(Types::Strict::Float)
    attribute :index, Types::Strict::Integer
  end

  # Response of the embeddings endpoint: one EmbeddingObject per input.
  class EmbeddingResponse < Dry::Struct
    transform_keys(&:to_sym)

    attribute :id, Types::Strict::String
    attribute :object, Types::Strict::String
    attribute :data, Types::Strict::Array.of(EmbeddingObject)
    attribute :model, Types::Strict::String
    attribute :usage, UsageInfo
  end
end
@@ -0,0 +1,39 @@
1
# frozen_string_literal: true

module Mistral
  # Permission flags attached to a model card; defaults mirror the API's
  # usual values for hosted models.
  class ModelPermission < Dry::Struct
    transform_keys(&:to_sym)

    attribute :id, Types::Strict::String
    attribute :object, Types::Strict::String
    attribute :created, Types::Strict::Integer
    attribute :allow_create_engine, Types::Strict::Bool.default(false)
    attribute :allow_sampling, Types::Strict::Bool.default(true)
    attribute :allow_logprobs, Types::Strict::Bool.default(true)
    attribute :allow_search_indices, Types::Strict::Bool.default(false)
    attribute :allow_view, Types::Strict::Bool.default(true)
    attribute :allow_fine_tuning, Types::Strict::Bool.default(false)
    attribute :organization, Types::Strict::String.default('*')
    attribute? :group, Types::Strict::String.optional
    attribute :is_blocking, Types::Strict::Bool.default(false)
  end

  # Metadata for a single available model.
  class ModelCard < Dry::Struct
    transform_keys(&:to_sym)

    attribute :id, Types::Strict::String
    attribute :object, Types::Strict::String
    attribute :created, Types::Strict::Integer
    attribute :owned_by, Types::Strict::String
    attribute? :root, Types::Strict::String.optional
    attribute? :parent, Types::Strict::String.optional
    attribute :permission, Types::Strict::Array.of(ModelPermission).default([].freeze)
  end

  # Response of the list-models endpoint.
  class ModelList < Dry::Struct
    transform_keys(&:to_sym)

    attribute :object, Types::Strict::String
    attribute :data, Types::Strict::Array.of(ModelCard)
  end
end
@@ -0,0 +1,5 @@
1
# frozen_string_literal: true

module Mistral
  # Gem version; also reported in the client's User-Agent header.
  VERSION = '0.1.0'
end
data/lib/mistral.rb ADDED
@@ -0,0 +1,24 @@
1
# frozen_string_literal: true

require 'dry-struct'
require 'http'
require 'json'
require 'logger'
require 'time'

require 'http/features/line_iterable_body'

module Mistral
  # Dry-types container used by the model structs (Types::Strict::String etc.).
  module Types
    include Dry.Types()
  end
end

# NOTE: require order matters — constants, exceptions and the model structs
# (UsageInfo before the chat/embedding models that reference it) must be
# loaded before the client that uses them.
require 'mistral/constants'
require 'mistral/exceptions'
require 'mistral/models/models'
require 'mistral/models/common'
require 'mistral/models/embeddings'
require 'mistral/models/chat_completion'
require 'mistral/version'
require 'mistral/client'