mistral 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,229 @@
1
# frozen_string_literal: true

require 'mistral/client_base'

module Mistral
  # Synchronous client for the Mistral AI API. All calls block until the API
  # responds; the streaming endpoint returns a lazy Enumerator of chunks.
  class Client < ClientBase
    # @param api_key [String, nil] API key; defaults to the MISTRAL_API_KEY env var.
    # @param endpoint [String] Base URL of the API, e.g. https://api.mistral.ai
    # @param max_retries [Integer] Maximum attempts for retryable status codes.
    # @param timeout [Integer] Per-request timeout in seconds.
    def initialize(
      api_key: ENV['MISTRAL_API_KEY'],
      endpoint: ENDPOINT,
      max_retries: 5,
      timeout: 120
    )
      super(endpoint: endpoint, api_key: api_key, max_retries: max_retries, timeout: timeout)

      # Fix: connect to the configured +endpoint+ rather than the ENDPOINT
      # constant, so that custom endpoints (e.g. Azure) are actually used.
      @client = HTTP.persistent(endpoint)
                    .follow
                    .timeout(timeout)
                    .use(:line_iterable_body)
                    .headers(
                      'Accept' => 'application/json',
                      'User-Agent' => "mistral-client-ruby/#{VERSION}",
                      'Authorization' => "Bearer #{@api_key}",
                      'Content-Type' => 'application/json'
                    )
    end

    # A chat endpoint that returns a single response.
    #
    # @param messages [Array<ChatMessage>] An array of messages to chat with, e.g.
    #   [{role: 'user', content: 'What is the best French cheese?'}]
    # @param model [String] The name of the model to chat with, e.g. mistral-tiny
    # @param tools [Array<Hash>] A list of tools to use.
    # @param temperature [Float] The temperature to use for sampling, e.g. 0.5.
    # @param max_tokens [Integer] The maximum number of tokens to generate, e.g. 100.
    # @param top_p [Float] The cumulative probability of tokens to generate, e.g. 0.9.
    # @param random_seed [Integer] The random seed to use for sampling, e.g. 42.
    # @param safe_mode [Boolean] Deprecated, use safe_prompt instead.
    # @param safe_prompt [Boolean] Whether to use safe prompt, e.g. true.
    # @param tool_choice [String, ToolChoice] The tool choice.
    # @param response_format [Hash<String, String>, ResponseFormat] The response format.
    # @return [ChatCompletionResponse] A response object containing the generated text.
    # @raise [Mistral::Error] When no response is received.
    def chat(
      messages:,
      model: nil,
      tools: nil,
      temperature: nil,
      max_tokens: nil,
      top_p: nil,
      random_seed: nil,
      safe_mode: false,
      safe_prompt: false,
      tool_choice: nil,
      response_format: nil
    )
      # Local renamed from +request+ to +payload+ so it no longer shadows the
      # private #request helper invoked below.
      payload = make_chat_request(
        messages: messages,
        model: model,
        tools: tools,
        temperature: temperature,
        max_tokens: max_tokens,
        top_p: top_p,
        random_seed: random_seed,
        stream: false,
        safe_prompt: safe_mode || safe_prompt,
        tool_choice: tool_choice,
        response_format: response_format
      )

      single_response = request('post', 'v1/chat/completions', json: payload)

      single_response.each do |response|
        return ChatCompletionResponse.new(response)
      end

      raise Mistral::Error.new(message: 'No response received')
    end

    # A chat endpoint that streams responses.
    #
    # @param messages [Array<Any>] An array of messages to chat with, e.g.
    #   [{role: 'user', content: 'What is the best French cheese?'}]
    # @param model [String] The name of the model to chat with, e.g. mistral-tiny
    # @param tools [Array<Hash>] A list of tools to use.
    # @param temperature [Float] The temperature to use for sampling, e.g. 0.5.
    # @param max_tokens [Integer] The maximum number of tokens to generate, e.g. 100.
    # @param top_p [Float] The cumulative probability of tokens to generate, e.g. 0.9.
    # @param random_seed [Integer] The random seed to use for sampling, e.g. 42.
    # @param safe_mode [Boolean] Deprecated, use safe_prompt instead.
    # @param safe_prompt [Boolean] Whether to use safe prompt, e.g. true.
    # @param tool_choice [String, ToolChoice] The tool choice.
    # @param response_format [Hash<String, String>, ResponseFormat] The response format.
    # @return [Enumerator<ChatCompletionStreamResponse>] A generator that yields ChatCompletionStreamResponse objects.
    def chat_stream(
      messages:,
      model: nil,
      tools: nil,
      temperature: nil,
      max_tokens: nil,
      top_p: nil,
      random_seed: nil,
      safe_mode: false,
      safe_prompt: false,
      tool_choice: nil,
      response_format: nil
    )
      payload = make_chat_request(
        messages: messages,
        model: model,
        tools: tools,
        temperature: temperature,
        max_tokens: max_tokens,
        top_p: top_p,
        random_seed: random_seed,
        stream: true,
        safe_prompt: safe_mode || safe_prompt,
        tool_choice: tool_choice,
        response_format: response_format
      )

      # Lazy: nothing is sent until the caller starts iterating.
      Enumerator.new do |yielder|
        request('post', 'v1/chat/completions', json: payload, stream: true).each do |json_response|
          yielder << ChatCompletionStreamResponse.new(**json_response)
        end
      end
    end

    # An embeddings endpoint that returns embeddings for a single, or batch of inputs
    #
    # @param model [String] The embedding model to use, e.g. mistral-embed
    # @param input [String, Array<String>] The input to embed, e.g. ['What is the best French cheese?']
    # @return [EmbeddingResponse] A response object containing the embeddings.
    # @raise [Mistral::Error] When no response is received.
    def embeddings(model:, input:)
      payload = { model: model, input: input }
      singleton_response = request('post', 'v1/embeddings', json: payload)

      singleton_response.each do |response|
        return EmbeddingResponse.new(response)
      end

      raise Mistral::Error.new(message: 'No response received')
    end

    # Returns a list of the available models
    #
    # @return [ModelList] A response object containing the list of models.
    # @raise [Mistral::Error] When no response is received.
    def list_models
      singleton_response = request('get', 'v1/models')

      singleton_response.each do |response|
        return ModelList.new(response)
      end

      raise Mistral::Error.new(message: 'No response received')
    end

    private

    # Perform an HTTP request and yield the parsed JSON payload(s) through an
    # Enumerator. Retryable status codes are retried with exponential backoff
    # up to @max_retries attempts.
    #
    # @param method [String] HTTP verb, e.g. 'post'.
    # @param path [String] Path relative to the endpoint, e.g. 'v1/models'.
    # @param json [Hash, nil] JSON request body.
    # @param stream [Boolean] Whether to consume the body as an SSE stream.
    # @param attempt [Integer] Current attempt number (internal, for retries).
    # @return [Enumerator] Yields parsed JSON hashes.
    def request(method, path, json: nil, stream: false, attempt: 1)
      url = File.join(@endpoint, path)
      headers = {}
      headers['Accept'] = 'text/event-stream' if stream

      @logger.debug("Sending request: #{method.upcase} #{url} #{json}")

      Enumerator.new do |yielder|
        response = @client.headers(headers).request(method.downcase.to_sym, url, json: json)
        check_response_status_codes(response)

        if stream
          # Each SSE line becomes at most one parsed JSON hash.
          response.body.each_line do |line|
            processed_line = process_line(line)
            next if processed_line.nil?

            yielder << processed_line
          end
        else
          yielder << check_response(response)
        end
      rescue HTTP::ConnectionError => e
        raise Mistral::ConnectionError, e.message
      rescue HTTP::RequestError => e
        raise Mistral::Error, "Unexpected exception (#{e.class}): #{e.message}"
      rescue JSON::ParserError
        raise Mistral::APIError.from_response(response, message: "Failed to decode json body: #{response.body}")
      rescue Mistral::APIStatusError => e
        attempt += 1

        raise Mistral::APIStatusError.from_response(response, message: e.message) if attempt > @max_retries

        sleep(2.0**attempt) # exponential backoff

        # Retry and forward the retried responses to our caller.
        request(method, path, json: json, stream: stream, attempt: attempt).each do |r|
          yielder << r
        end
      end
    end

    # Parse a non-streaming response body and surface API-level errors.
    #
    # @param response [HTTP::Response]
    # @return [Hash] Parsed JSON body.
    # @raise [Mistral::Error, Mistral::APIError]
    def check_response(response)
      check_response_status_codes(response)

      json_response = JSON.parse(response.body.to_s)

      if !json_response.key?('object')
        raise Mistral::Error, "Unexpected response: #{json_response}"
      elsif json_response['object'] == 'error' # has errors
        raise Mistral::APIError.from_response(response, message: json_response['message'])
      end

      json_response
    end

    # Map HTTP status codes to the appropriate error class. Retryable codes
    # raise APIStatusError (caught and retried by #request), other 4xx raise
    # APIError, and any remaining 5xx raise the base Error.
    def check_response_status_codes(response)
      if RETRY_STATUS_CODES.include?(response.code)
        raise APIStatusError.from_response(response, message: "Status: #{response.code}. Message: #{response.body}")
      elsif response.code >= 400 && response.code < 500
        raise APIError.from_response(response, message: "Status: #{response.code}. Message: #{response.body}")
      elsif response.code >= 500
        raise Mistral::Error.new(message: "Status: #{response.code}. Message: #{response.body}")
      end
    end
  end
end
@@ -0,0 +1,126 @@
1
# frozen_string_literal: true

module Mistral
  # Shared configuration and request-building logic for Mistral API clients.
  class ClientBase
    attr_reader :endpoint, :api_key, :max_retries, :timeout

    def initialize(endpoint:, api_key: nil, max_retries: 5, timeout: 120)
      @endpoint = endpoint
      @api_key = api_key
      @max_retries = max_retries
      @timeout = timeout

      @logger = config_logger

      # For azure endpoints, we default to the mistral model
      @default_model = 'mistral' if endpoint.include?('inference.azure.com')
    end

    protected

    # Normalize tool definitions into plain hashes, keeping only 'function'
    # tools and expanding Function structs via #to_h.
    def parse_tools(tools)
      tools.each_with_object([]) do |tool, parsed|
        next unless tool['type'] == 'function'

        fn = tool['function']
        parsed << {
          'type' => tool['type'],
          'function' => fn.is_a?(Function) ? fn.to_h : fn
        }
      end
    end

    # Coerce a ToolChoice value to its string representation.
    def parse_tool_choice(tool_choice)
      tool_choice.is_a?(ToolChoice) ? tool_choice.to_s : tool_choice
    end

    # Coerce a ResponseFormat struct to a plain hash.
    def parse_response_format(response_format)
      response_format.is_a?(ResponseFormat) ? response_format.to_h : response_format
    end

    # Expand ChatMessage structs into plain hashes; pass other values through.
    def parse_messages(messages)
      messages.map { |message| message.is_a?(ChatMessage) ? message.to_h : message }
    end

    # Build the JSON-ready payload for a chat completion request.
    #
    # @raise [Mistral::Error] When no model is given and no default applies.
    def make_chat_request(
      messages:,
      model: nil,
      tools: nil,
      temperature: nil,
      max_tokens: nil,
      top_p: nil,
      random_seed: nil,
      stream: nil,
      safe_prompt: false,
      tool_choice: nil,
      response_format: nil
    )
      request_data = {
        messages: parse_messages(messages),
        safe_prompt: safe_prompt
      }

      if model.nil?
        raise Mistral::Error.new(message: 'model must be provided') if @default_model.nil?

        request_data[:model] = @default_model
      else
        request_data[:model] = model
      end

      request_data[:tools] = parse_tools(tools) unless tools.nil?

      # Copy the simple optional scalars, skipping any left unset.
      {
        temperature: temperature,
        max_tokens: max_tokens,
        top_p: top_p,
        random_seed: random_seed,
        stream: stream
      }.each do |key, value|
        request_data[key] = value unless value.nil?
      end

      request_data[:tool_choice] = parse_tool_choice(tool_choice) unless tool_choice.nil?
      request_data[:response_format] = parse_response_format(response_format) unless response_format.nil?

      @logger.debug("Chat request: #{request_data}")

      request_data
    end

    # Parse one SSE line: returns the decoded JSON hash, or nil for non-data
    # lines and the terminal '[DONE]' sentinel.
    def process_line(line)
      return unless line.start_with?('data: ')

      payload = line[6..].to_s.strip
      return if payload == '[DONE]'

      JSON.parse(payload)
    end

    # Build the stdout logger; level comes from LOG_LEVEL (default ERROR).
    def config_logger
      logger = Logger.new($stdout)
      logger.level = ENV.fetch('LOG_LEVEL', 'ERROR')
      logger.formatter = proc do |severity, datetime, progname, msg|
        format("%s %s %s: %s\n", datetime.strftime('%Y-%m-%d %H:%M:%S'), severity, progname, msg)
      end
      logger
    end
  end
end
@@ -0,0 +1,6 @@
1
# frozen_string_literal: true

module Mistral
  # HTTP status codes that trigger a retry with exponential backoff.
  RETRY_STATUS_CODES = [429, 500, 502, 503, 504].freeze

  # Default base URL of the Mistral AI API.
  ENDPOINT = 'https://api.mistral.ai'
end
@@ -0,0 +1,38 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Mistral
4
+ # Base error class, returned when nothing more specific applies
5
+ class Error < StandardError
6
+ end
7
+
8
+ class APIError < Error
9
+ attr_reader :http_status, :headers
10
+
11
+ def initialize(message: nil, http_status: nil, headers: nil)
12
+ super(message: message)
13
+
14
+ @http_status = http_status
15
+ @headers = headers
16
+ end
17
+
18
+ def self.from_response(response, message: nil)
19
+ new(
20
+ message: message || response.to_s,
21
+ http_status: response.code,
22
+ headers: response.headers.to_h
23
+ )
24
+ end
25
+
26
+ def to_s
27
+ "#{self.class.name}(message=#{super}, http_status=#{http_status})"
28
+ end
29
+ end
30
+
31
+ # Returned when we receive a non-200 response from the API that we should retry
32
+ class APIStatusError < APIError
33
+ end
34
+
35
+ # Returned when the SDK can not reach the API server for any reason
36
+ class ConnectionError < Error
37
+ end
38
+ end
@@ -0,0 +1,95 @@
1
# frozen_string_literal: true

module Mistral
  # Definition of a callable function exposed to the model as a tool.
  class Function < Dry::Struct
    transform_keys(&:to_sym)

    attribute :name, Types::Strict::String
    attribute :description, Types::Strict::String
    # JSON-schema-style parameter description as a plain Hash.
    attribute :parameters, Types::Strict::Hash
  end

  # Only 'function' tools are currently supported.
  ToolType = Types::Strict::String.default('function').enum('function')

  # A concrete function invocation requested by the model.
  class FunctionCall < Dry::Struct
    transform_keys(&:to_sym)

    attribute :name, Types::Strict::String
    # NOTE(review): a String, presumably JSON-encoded arguments — confirm
    # against the API response format.
    attribute :arguments, Types::Strict::String
  end

  # A tool call attached to an assistant message.
  class ToolCall < Dry::Struct
    transform_keys(&:to_sym)

    # NOTE(review): default id is the literal string 'null' — presumably
    # mirroring a JSON null from the API; verify against live responses.
    attribute :id, Types::Strict::String.default('null')
    attribute :type, ToolType
    attribute :function, FunctionCall
  end

  # Allowed values for a response_format's type field.
  ResponseFormats = Types::Strict::String.default('text').enum('text', 'json_object')

  # Allowed values for tool_choice.
  ToolChoice = Types::Strict::String.enum('auto', 'any', 'none')

  # Requested response format (plain text or JSON object).
  class ResponseFormat < Dry::Struct
    transform_keys(&:to_sym)

    attribute :type, ResponseFormats
  end

  # A single chat message, used on both the request and response side.
  class ChatMessage < Dry::Struct
    transform_keys(&:to_sym)

    attribute :role, Types::Strict::String
    # Content is either a plain string or an array of string fragments.
    attribute :content, Types::Strict::Array.of(Types::Strict::String) | Types::Strict::String
    attribute? :name, Types::String.optional
    attribute? :tool_calls, Types::Strict::Array.of(ToolCall).optional
  end

  # Incremental message fragment received while streaming; every field is
  # optional because each chunk carries only the delta.
  class DeltaMessage < Dry::Struct
    transform_keys(&:to_sym)

    attribute? :role, Types::Strict::String.optional
    attribute? :content, Types::Strict::String.optional
    attribute? :tool_calls, Types::Strict::Array.of(ToolCall).optional
  end

  # Why a completion stopped generating.
  FinishReason = Types::Strict::String.enum('stop', 'length', 'error', 'tool_calls')

  # One choice within a streaming response chunk.
  class ChatCompletionResponseStreamChoice < Dry::Struct
    transform_keys(&:to_sym)

    attribute :index, Types::Strict::Integer
    attribute :delta, DeltaMessage
    attribute? :finish_reason, FinishReason.optional
  end

  # One chunk of a streamed chat completion.
  class ChatCompletionStreamResponse < Dry::Struct
    transform_keys(&:to_sym)

    attribute :id, Types::Strict::String
    attribute :model, Types::Strict::String
    attribute :choices, Types::Strict::Array.of(ChatCompletionResponseStreamChoice)
    attribute? :created, Types::Strict::Integer.optional
    attribute? :object, Types::Strict::String.optional
    # Usage is only present on the final chunk in typical SSE streams —
    # NOTE(review): not verifiable from this file; confirm against the API.
    attribute? :usage, UsageInfo.optional
  end

  # One choice within a non-streaming chat completion response.
  class ChatCompletionResponseChoice < Dry::Struct
    transform_keys(&:to_sym)

    attribute :index, Types::Strict::Integer
    attribute :message, ChatMessage
    attribute? :finish_reason, FinishReason.optional
  end

  # A complete (non-streaming) chat completion response.
  class ChatCompletionResponse < Dry::Struct
    transform_keys(&:to_sym)

    attribute :id, Types::Strict::String
    attribute :object, Types::Strict::String
    attribute :created, Types::Strict::Integer
    attribute :model, Types::Strict::String
    attribute :choices, Types::Strict::Array.of(ChatCompletionResponseChoice)
    attribute :usage, UsageInfo
  end
end
@@ -0,0 +1,11 @@
1
# frozen_string_literal: true

module Mistral
  # Token accounting attached to API responses.
  class UsageInfo < Dry::Struct
    transform_keys(&:to_sym)

    attribute :prompt_tokens, Types::Strict::Integer
    attribute :total_tokens, Types::Strict::Integer
    # Optional — NOTE(review): presumably absent for endpoints that generate
    # no completion (e.g. embeddings); confirm against API responses.
    attribute? :completion_tokens, Types::Strict::Integer.optional
  end
end
@@ -0,0 +1,21 @@
1
# frozen_string_literal: true

module Mistral
  # A single embedding vector with its position in the request batch.
  class EmbeddingObject < Dry::Struct
    transform_keys(&:to_sym)

    attribute :object, Types::Strict::String
    attribute :embedding, Types::Strict::Array.of(Types::Strict::Float)
    # Index of the corresponding input in the request batch.
    attribute :index, Types::Strict::Integer
  end

  # Response from the embeddings endpoint.
  class EmbeddingResponse < Dry::Struct
    transform_keys(&:to_sym)

    attribute :id, Types::Strict::String
    attribute :object, Types::Strict::String
    attribute :data, Types::Strict::Array.of(EmbeddingObject)
    attribute :model, Types::Strict::String
    attribute :usage, UsageInfo
  end
end
@@ -0,0 +1,39 @@
1
# frozen_string_literal: true

module Mistral
  # Permission flags attached to a model card.
  class ModelPermission < Dry::Struct
    transform_keys(&:to_sym)

    attribute :id, Types::Strict::String
    attribute :object, Types::Strict::String
    attribute :created, Types::Strict::Integer
    attribute :allow_create_engine, Types::Strict::Bool.default(false)
    attribute :allow_sampling, Types::Strict::Bool.default(true)
    attribute :allow_logprobs, Types::Strict::Bool.default(true)
    attribute :allow_search_indices, Types::Strict::Bool.default(false)
    attribute :allow_view, Types::Strict::Bool.default(true)
    attribute :allow_fine_tuning, Types::Strict::Bool.default(false)
    attribute :organization, Types::Strict::String.default('*')
    attribute? :group, Types::Strict::String.optional
    attribute :is_blocking, Types::Strict::Bool.default(false)
  end

  # Metadata for one available model.
  class ModelCard < Dry::Struct
    transform_keys(&:to_sym)

    attribute :id, Types::Strict::String
    attribute :object, Types::Strict::String
    attribute :created, Types::Strict::Integer
    attribute :owned_by, Types::Strict::String
    attribute? :root, Types::Strict::String.optional
    attribute? :parent, Types::Strict::String.optional
    attribute :permission, Types::Strict::Array.of(ModelPermission).default([].freeze)
  end

  # Response from the list-models endpoint.
  class ModelList < Dry::Struct
    transform_keys(&:to_sym)

    attribute :object, Types::Strict::String
    attribute :data, Types::Strict::Array.of(ModelCard)
  end
end
@@ -0,0 +1,5 @@
1
# frozen_string_literal: true

module Mistral
  # Current gem version (also reported in the User-Agent header).
  VERSION = '0.1.0'
end
data/lib/mistral.rb ADDED
@@ -0,0 +1,24 @@
1
# frozen_string_literal: true

require 'dry-struct'
require 'http'
require 'json'
require 'logger'
require 'time'

# http.rb feature that exposes response bodies line by line (used for SSE
# streaming in the client).
require 'http/features/line_iterable_body'

module Mistral
  # Dry-types namespace used by all of the model structs below.
  module Types
    include Dry.Types()
  end
end

# NOTE(review): load order appears intentional — constants/exceptions first,
# then the model structs (chat_completion references UsageInfo, which must be
# defined by an earlier file), then the client. Confirm before reordering.
require 'mistral/constants'
require 'mistral/exceptions'
require 'mistral/models/models'
require 'mistral/models/common'
require 'mistral/models/embeddings'
require 'mistral/models/chat_completion'
require 'mistral/version'
require 'mistral/client'