mistral 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/.env.example +3 -0
- data/.rubocop.yml +60 -0
- data/.tool-versions +1 -0
- data/CHANGELOG.md +12 -0
- data/CODE_OF_CONDUCT.md +84 -0
- data/LICENSE.txt +21 -0
- data/PYTHON_CLIENT_COMPARISON.md +184 -0
- data/README.md +145 -0
- data/Rakefile +12 -0
- data/examples/chat_no_streaming.rb +18 -0
- data/examples/chat_with_streaming.rb +18 -0
- data/examples/chatbot_with_streaming.rb +289 -0
- data/examples/embeddings.rb +16 -0
- data/examples/function_calling.rb +104 -0
- data/examples/json_format.rb +21 -0
- data/examples/list_models.rb +13 -0
- data/lib/http/features/line_iterable_body.rb +35 -0
- data/lib/mistral/client.rb +229 -0
- data/lib/mistral/client_base.rb +126 -0
- data/lib/mistral/constants.rb +6 -0
- data/lib/mistral/exceptions.rb +38 -0
- data/lib/mistral/models/chat_completion.rb +95 -0
- data/lib/mistral/models/common.rb +11 -0
- data/lib/mistral/models/embeddings.rb +21 -0
- data/lib/mistral/models/models.rb +39 -0
- data/lib/mistral/version.rb +5 -0
- data/lib/mistral.rb +24 -0
- metadata +172 -0
@@ -0,0 +1,289 @@
|
|
1
|
+
#!/usr/bin/env ruby
|
2
|
+
# frozen_string_literal: true
|
3
|
+
|
4
|
+
# Simple chatbot example -- run with -h argument to see options.
|
5
|
+
|
6
|
+
require 'bundler/setup'
|
7
|
+
require 'dotenv/load'
|
8
|
+
require 'readline'
|
9
|
+
require 'optparse'
|
10
|
+
require 'mistral'
|
11
|
+
|
12
|
+
# Models that can be selected with --model on the command line or /model in the chat.
MODEL_LIST = %w[mistral-tiny mistral-small mistral-medium].freeze

# Values used when the matching command-line option is not given.
DEFAULT_MODEL = 'mistral-small'
DEFAULT_TEMPERATURE = 0.7

# Python-logging style format string. Nothing in the visible script reads it;
# the logger below builds its lines with its own formatter proc.
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'

# Every chat command mapped to a hash of its nested completions (used for tab
# completion). Only /model has nested entries: one per known model name.
COMMAND_LIST = {
  '/new' => {},
  '/help' => {},
  '/model' => MODEL_LIST.to_h { |name| [name, {}] },
  '/system' => {},
  '/temperature' => {},
  '/config' => {},
  '/quit' => {},
  '/exit' => {}
}.freeze
|
31
|
+
|
32
|
+
# Logger is used below but this script never requires it itself; load it
# explicitly instead of relying on another library having pulled it in.
require 'logger'

# Script-wide logger. Writes "<timestamp> - <severity> - <message>" lines to
# stdout; the level starts at INFO and is adjusted after option parsing.
$logger = Logger.new($stdout)
$logger.level = Logger::INFO
$logger.formatter = proc do |severity, datetime, _progname, msg|
  "#{datetime.strftime('%Y-%m-%d %H:%M:%S')} - #{severity} - #{msg}\n"
end
|
37
|
+
|
38
|
+
# Walks a nested command hash along the already-typed words and returns the
# candidate completions for the next word.
#
# @param command_dict [Hash{String => Hash}] commands mapped to their nested completions
# @param parts [Array<String>] words typed so far
# @return [Array<String>] all keys when +parts+ is empty; the completions nested
#   under the first word when it is an exact key; otherwise the keys that start
#   with the first word
def find_completions(command_dict, parts)
  return command_dict.keys if parts.empty?

  head, *rest = parts
  # An exact match descends one level and continues with the remaining words.
  return find_completions(command_dict[head], rest) if command_dict.key?(head)

  command_dict.each_key.select { |name| name.start_with?(head) }
end
|
47
|
+
|
48
|
+
# Enable tab completion.
#
# The words before the cursor select the level of COMMAND_LIST to complete
# from; the word under the cursor is the prefix the candidates must match.
Readline.completion_proc = proc do |_input|
  buffer = Readline.line_buffer.to_s.lstrip
  words = buffer.split
  # With an empty buffer or a trailing space the word under the cursor is still
  # empty. Represent it explicitly: otherwise the prefix would be nil (TypeError
  # in start_with?) or the previous word would be mistaken for the prefix, so
  # "/model <TAB>" would never offer the model names.
  words << '' if buffer.empty? || buffer.end_with?(' ')

  *typed, current = words
  find_completions(COMMAND_LIST, typed).select { |option| option.start_with?(current) }
end
|
54
|
+
|
55
|
+
# Interactive command-line chat session against the Mistral API.
#
# Keeps the current model, temperature, system message and message history,
# dispatches slash commands (see COMMAND_LIST) and streams model replies to
# stdout.
class ChatBot
  # @param api_key [String] Mistral API key
  # @param model [String] model used for chat inference
  # @param system_message [String, nil] optional system message prepended to each new chat
  # @param temperature [Float] sampling temperature
  # @raise [ArgumentError] when +api_key+ is nil
  def initialize(api_key, model, system_message = nil, temperature = DEFAULT_TEMPERATURE)
    raise ArgumentError, 'An API key must be provided to use the Mistral API.' if api_key.nil?

    @client = Mistral::Client.new(api_key: api_key)
    @model = model
    @temperature = temperature
    @system_message = system_message
  end

  # Prints the list of available commands.
  def opening_instructions
    puts '
To chat: type your message and hit enter
To start a new chat: /new
To switch model: /model <model name>
To switch system message: /system <message>
To switch temperature: /temperature <temperature>
To see current config: /config
To exit: /exit, /quit, or hit CTRL+C
To see this help: /help
'
  end

  # Clears the message history and re-adds the system message when one is set.
  def new_chat
    puts ''
    puts "Starting new chat with model: #{@model}, temperature: #{@temperature}"
    puts ''
    @messages = []
    @messages << Mistral::ChatMessage.new(role: 'system', content: @system_message) if @system_message
  end

  # Handles "/model <name>": switches model when the name is in MODEL_LIST.
  def switch_model(input)
    model = get_arguments(input)

    unless MODEL_LIST.include?(model)
      $logger.error("Invalid model name: #{model}")
      return
    end

    @model = model
    $logger.info("Switching model: #{model}")
  end

  # Handles "/system <message>": stores the new system message and starts a
  # fresh chat so it takes effect.
  def switch_system_message(input)
    system_message = get_arguments(input)

    # get_arguments returns '' (truthy in Ruby) when nothing follows the
    # command, so emptiness has to be checked explicitly.
    if system_message.empty?
      $logger.error("Invalid system message: #{system_message}")
      return
    end

    @system_message = system_message
    $logger.info("Switching system message: #{system_message}")
    new_chat
  end

  # Handles "/temperature <float>": accepts values between 0 and 1 inclusive.
  def switch_temperature(input)
    temperature = get_arguments(input)

    begin
      # Float() raises ArgumentError for '' and non-numeric text.
      temperature = Float(temperature)
      raise ArgumentError unless temperature.between?(0, 1)

      @temperature = temperature
      $logger.info("Switching temperature: #{temperature}")
    rescue ArgumentError
      $logger.error("Invalid temperature: #{temperature}")
    end
  end

  # Prints the current model, temperature and system message.
  def show_config
    puts ''
    puts "Current model: #{@model}"
    puts "Current temperature: #{@temperature}"
    puts "Current system message: #{@system_message}"
    puts ''
  end

  # Prompts for one line of input.
  #
  # Reads through Readline so the completion proc configured for this script
  # is actually used (plain gets bypasses it). Exits cleanly on end of input.
  #
  # @return [String] the entered line without its trailing newline
  def collect_user_input
    puts ''
    line = Readline.readline('YOU: ', true)
    exit if line.nil? # end of input (e.g. CTRL+D)

    line.chomp
  end

  # Sends +content+ as a user message, streams the reply to stdout and records
  # both sides of the exchange in the message history.
  def run_inference(content)
    puts ''
    puts 'MISTRAL:'
    puts ''

    @messages << Mistral::ChatMessage.new(role: 'user', content: content)

    # Unfrozen buffer so streamed pieces can be appended in place.
    assistant_response = +''

    $logger.debug("Running inference with model: #{@model}, temperature: #{@temperature}")
    $logger.debug("Sending messages: #{@messages}")

    @client.chat_stream(model: @model, temperature: @temperature, messages: @messages).each do |chunk|
      piece = chunk.choices[0].delta.content
      next unless piece

      print piece
      assistant_response << piece
    end

    puts ''

    # Always recorded: the previous `if assistant_response` guard could never be
    # false because a String is always truthy.
    @messages << Mistral::ChatMessage.new(role: 'assistant', content: assistant_response)

    $logger.debug("Current messages: #{@messages}")
  end

  # @return [String] first whitespace-separated word of +input+, or '' when blank
  def get_command(input)
    input.split.first.to_s
  end

  # @return [String] everything after the first word, joined with single
  #   spaces; '' when there is nothing after it
  def get_arguments(input)
    input.split.drop(1).join(' ')
  end

  # @return [Boolean] whether the first word of +input+ is a known command
  def is_command?(input)
    COMMAND_LIST.key?(get_command(input))
  end

  # Runs the handler for the command at the start of +input+; unknown commands
  # are ignored.
  def execute_command(input)
    case get_command(input)
    when '/exit', '/quit' then exit
    when '/help' then opening_instructions
    when '/new' then new_chat
    when '/model' then switch_model(input)
    when '/system' then switch_system_message(input)
    when '/temperature' then switch_temperature(input)
    when '/config' then show_config
    end
  end

  # Main loop: reads input until the user quits, dispatching commands and
  # sending everything else to the model. CTRL+C exits.
  def start
    opening_instructions
    new_chat

    loop do
      input = collect_user_input
      # Nothing to run or send; prompt again.
      next if input.strip.empty?

      if is_command?(input)
        execute_command(input)
      else
        run_inference(input)
      end
    rescue Interrupt
      exit
    end
  end

  # Says goodbye and terminates the process with status 0. Shadows Kernel#exit
  # for calls made inside this class.
  def exit
    $logger.debug('Exiting chatbot')
    puts 'Goodbye!'
    Kernel.exit(0)
  end
end
|
223
|
+
|
224
|
+
# Command-line entry point: parse options, configure logging, run the bot.
options = {}
OptionParser.new do |opts|
  opts.banner = 'Usage: chatbot.rb [options]'

  opts.on(
    '--api-key KEY',
    'Mistral API key. Defaults to environment variable MISTRAL_API_KEY'
  ) do |key|
    options[:api_key] = key
  end

  # Passing MODEL_LIST makes OptionParser reject values outside the list.
  opts.on(
    '-m',
    '--model MODEL',
    MODEL_LIST,
    "Model for chat inference. Choices are #{MODEL_LIST.join(", ")}. Defaults to #{DEFAULT_MODEL}"
  ) do |model|
    options[:model] = model
  end

  opts.on(
    '-s',
    '--system-message MESSAGE',
    'Optional system message to prepend'
  ) do |message|
    options[:system_message] = message
  end

  opts.on(
    '-t',
    '--temperature FLOAT',
    Float,
    "Optional temperature for chat inference. Defaults to #{DEFAULT_TEMPERATURE}"
  ) do |temp|
    options[:temperature] = temp
  end

  opts.on(
    '-d',
    '--debug',
    'Enable debug logging'
  ) do
    options[:debug] = true
  end
end.parse!

# Fall back to the environment. A missing variable yields nil rather than a
# KeyError raised out here, so ChatBot#initialize reports the problem with its
# descriptive ArgumentError inside the rescue below.
api_key = options[:api_key] || ENV.fetch('MISTRAL_API_KEY', nil)
model = options[:model] || DEFAULT_MODEL
system_message = options[:system_message]
temperature = options[:temperature] || DEFAULT_TEMPERATURE

$logger.level = options[:debug] ? Logger::DEBUG : Logger::INFO

$logger.debug(
  "Starting chatbot with model: #{model}, " \
  "temperature: #{temperature}, " \
  "system message: #{system_message}"
)

begin
  bot = ChatBot.new(api_key, model, system_message, temperature)
  bot.start
rescue StandardError => e
  # Log the failure and exit with a non-zero status.
  $logger.error(e)
  exit(1)
end
|
@@ -0,0 +1,16 @@
|
|
1
|
+
#!/usr/bin/env ruby
|
2
|
+
# frozen_string_literal: true
|
3
|
+
|
4
|
+
require 'bundler/setup'
|
5
|
+
require 'dotenv/load'
|
6
|
+
require 'mistral'
|
7
|
+
|
8
|
+
# Request embeddings for ten copies of the same sentence from the
# mistral-embed model and print the response as a hash.
client = Mistral::Client.new(api_key: ENV.fetch('MISTRAL_API_KEY'))

sentences = Array.new(10) { 'What is the best French cheese?' }
embeddings_response = client.embeddings(model: 'mistral-embed', input: sentences)

puts embeddings_response.to_h
|
@@ -0,0 +1,104 @@
|
|
1
|
+
#!/usr/bin/env ruby
|
2
|
+
# frozen_string_literal: true
|
3
|
+
|
4
|
+
require 'bundler/setup'
|
5
|
+
require 'dotenv/load'
|
6
|
+
require 'mistral'
|
7
|
+
|
8
|
+
# Sample payment records, written one transaction per row and then pivoted
# into the column-oriented hash (column name => array of values) that the
# lookup functions read.
columns = %w[transaction_id customer_id payment_amount payment_date payment_status]
rows = [
  ['T1001', 'C001', 125.50, '2021-10-05', 'Paid'],
  ['T1002', 'C002', 89.99, '2021-10-06', 'Unpaid'],
  ['T1003', 'C003', 120.00, '2021-10-07', 'Paid'],
  ['T1004', 'C002', 54.30, '2021-10-05', 'Paid'],
  ['T1005', 'C001', 210.20, '2021-10-08', 'Pending']
]
data = columns.zip(rows.transpose).to_h
|
16
|
+
|
17
|
+
# Looks up the payment status of a transaction.
#
# @param data [Hash{String => Array}] column-oriented table with
#   'transaction_id' and 'payment_status' columns
# @param transaction_id [String] id to look up
# @return [String] JSON object with a "status" key: the payment status of the
#   first matching row, or an error text when the id is not present
def retrieve_payment_status(data, transaction_id)
  row = data['transaction_id'].index(transaction_id)
  return { status: 'Error - transaction id not found' }.to_json if row.nil?

  { status: data['payment_status'][row] }.to_json
end
|
24
|
+
|
25
|
+
# Looks up the payment date of a transaction.
#
# @param data [Hash{String => Array}] column-oriented table with
#   'transaction_id' and 'payment_date' columns
# @param transaction_id [String] id to look up
# @return [String] JSON object with a "date" key for the first matching row,
#   or a "status" key carrying an error text when the id is not present
def retrieve_payment_date(data, transaction_id)
  row = data['transaction_id'].index(transaction_id)
  return { status: 'Error - transaction id not found' }.to_json if row.nil?

  { date: data['payment_date'][row] }.to_json
end
|
32
|
+
|
33
|
+
# Local implementations the model may ask for, keyed by the function name
# advertised in the tool definitions below.
names_to_functions = {
  'retrieve_payment_status' => ->(transaction_id) { retrieve_payment_status(data, transaction_id) },
  'retrieve_payment_date' => ->(transaction_id) { retrieve_payment_date(data, transaction_id) }
}

# Builds one tool definition. Both tools take a single required string
# parameter, the transaction id, so only name and description differ.
transaction_tool = lambda do |name, description|
  {
    'type' => 'function',
    'function' => Mistral::Function.new(
      name: name,
      description: description,
      parameters: {
        'type' => 'object',
        'required' => ['transaction_id'],
        'properties' => {
          'transaction_id' => {
            'type' => 'string',
            'description' => 'The transaction id.'
          }
        }
      }
    )
  }
end

tools = [
  transaction_tool.call('retrieve_payment_status', 'Get payment status of a transaction id'),
  transaction_tool.call('retrieve_payment_date', 'Get payment date of a transaction id')
]

api_key = ENV.fetch('MISTRAL_API_KEY')
model = 'mistral-large-latest'

client = Mistral::Client.new(api_key: api_key)

# Turn 1: the question lacks a transaction id; print the model's reply.
messages = [Mistral::ChatMessage.new(role: 'user', content: "What's the status of my transaction?")]

response = client.chat(model: model, messages: messages, tools: tools)

puts response.choices[0].message.content

# Turn 2: supply the id so the model can request a function call.
messages << Mistral::ChatMessage.new(role: 'assistant', content: response.choices[0].message.content)
messages << Mistral::ChatMessage.new(role: 'user', content: 'My transaction ID is T1001.')

response = client.chat(model: model, messages: messages, tools: tools)

# The model is not obliged to call a tool; stop with a clear message instead of
# failing with NoMethodError on nil.
# NOTE(review): assumes tool_calls is nil or array-like when no call was made.
tool_call = response.choices[0].message.tool_calls&.first
abort 'The model did not request a function call.' if tool_call.nil?

function_name = tool_call.function.name
function_params = JSON.parse(tool_call.function.arguments)

puts "calling function_name: #{function_name}, with function_params: #{function_params}"

function = names_to_functions.fetch(function_name) do
  abort "The model requested an unknown function: #{function_name}"
end
function_result = function.call(function_params['transaction_id'])

# Turn 3: hand the function result back and print the final answer.
messages << response.choices[0].message
messages << Mistral::ChatMessage.new(role: 'tool', name: function_name, content: function_result)

response = client.chat(model: model, messages: messages, tools: tools)

puts response.choices[0].message.content
|
@@ -0,0 +1,21 @@
|
|
1
|
+
#!/usr/bin/env ruby
|
2
|
+
# frozen_string_literal: true
|
3
|
+
|
4
|
+
require 'bundler/setup'
|
5
|
+
require 'dotenv/load'
|
6
|
+
require 'mistral'
|
7
|
+
|
8
|
+
# Ask a single question with the response format set to a JSON object and
# print the content of the first choice.
client = Mistral::Client.new(api_key: ENV.fetch('MISTRAL_API_KEY'))

question = Mistral::ChatMessage.new(
  role: 'user',
  content: 'What is the best French cheese? Answer shortly in JSON.'
)

chat_response = client.chat(
  model: 'mistral-large-latest',
  response_format: { type: 'json_object' },
  messages: [question]
)

puts chat_response.choices[0].message.content
|
@@ -0,0 +1,13 @@
|
|
1
|
+
#!/usr/bin/env ruby
|
2
|
+
# frozen_string_literal: true
|
3
|
+
|
4
|
+
require 'bundler/setup'
|
5
|
+
require 'dotenv/load'
|
6
|
+
require 'mistral'
|
7
|
+
|
8
|
+
# Fetch the model list and print the response as a hash.
client = Mistral::Client.new(api_key: ENV.fetch('MISTRAL_API_KEY'))

puts client.list_models.to_h
|
@@ -0,0 +1,35 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module HTTP
  module Features
    # HTTP feature that replaces each response body with a wrapper offering
    # line-by-line iteration (#each_line).
    class LineIterableBody < Feature
      # Rebuilds the response with the same status, version, headers, proxy
      # headers, connection and request, but with the body wrapped in an
      # IterableBodyWrapper.
      #
      # @param response [HTTP::Response]
      # @return [HTTP::Response]
      def wrap_response(response)
        body = response.body

        HTTP::Response.new(
          status: response.status,
          version: response.version,
          headers: response.headers,
          proxy_headers: response.proxy_headers,
          connection: response.connection,
          # The original body's encoding is read from its instance variable and
          # passed on to the wrapper.
          body: IterableBodyWrapper.new(body, body.instance_variable_get(:@encoding)),
          request: response.request
        )
      end

      # Response body that can additionally be iterated line by line.
      class IterableBodyWrapper < HTTP::Response::Body
        # @param body [HTTP::Response::Body] body to read from
        # @param encoding [Encoding] encoding forwarded to the parent class
        def initialize(body, encoding)
          super(body, encoding: encoding)
        end

        # Yields the body one line at a time, each line keeping its terminator.
        #
        # A line may be split across two chunks yielded by #each, so the
        # unterminated tail of every chunk is held back and prepended to the
        # next one instead of being yielded as a fragment. Whatever is left
        # once the body ends is yielded as a final, unterminated line.
        #
        # @yieldparam line [String]
        # @return [Enumerator] when called without a block
        def each_line(&block)
          return to_enum(:each_line) unless block

          pending = nil
          each do |chunk|
            pending = pending ? pending + chunk : chunk
            lines = pending.lines
            # Keep an unterminated last line for the next chunk.
            pending = lines.last&.end_with?("\n") ? nil : lines.pop
            lines.each(&block)
          end

          block.call(pending) if pending
        end
      end
    end
  end

  HTTP::Options.register_feature(:line_iterable_body, Features::LineIterableBody)
end
|