mistral 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.env.example +3 -0
- data/.rubocop.yml +60 -0
- data/.tool-versions +1 -0
- data/CHANGELOG.md +12 -0
- data/CODE_OF_CONDUCT.md +84 -0
- data/LICENSE.txt +21 -0
- data/PYTHON_CLIENT_COMPARISON.md +184 -0
- data/README.md +145 -0
- data/Rakefile +12 -0
- data/examples/chat_no_streaming.rb +18 -0
- data/examples/chat_with_streaming.rb +18 -0
- data/examples/chatbot_with_streaming.rb +289 -0
- data/examples/embeddings.rb +16 -0
- data/examples/function_calling.rb +104 -0
- data/examples/json_format.rb +21 -0
- data/examples/list_models.rb +13 -0
- data/lib/http/features/line_iterable_body.rb +35 -0
- data/lib/mistral/client.rb +229 -0
- data/lib/mistral/client_base.rb +126 -0
- data/lib/mistral/constants.rb +6 -0
- data/lib/mistral/exceptions.rb +38 -0
- data/lib/mistral/models/chat_completion.rb +95 -0
- data/lib/mistral/models/common.rb +11 -0
- data/lib/mistral/models/embeddings.rb +21 -0
- data/lib/mistral/models/models.rb +39 -0
- data/lib/mistral/version.rb +5 -0
- data/lib/mistral.rb +24 -0
- metadata +172 -0
@@ -0,0 +1,289 @@
|
|
1
|
+
#!/usr/bin/env ruby
|
2
|
+
# frozen_string_literal: true
|
3
|
+
|
4
|
+
# Simple chatbot example -- run with -h argument to see options.
|
5
|
+
|
6
|
+
require 'bundler/setup'
|
7
|
+
require 'dotenv/load'
|
8
|
+
require 'readline'
|
9
|
+
require 'optparse'
|
10
|
+
require 'mistral'
|
11
|
+
|
12
|
+
# Models that can be selected with -m/--model or the /model command.
MODEL_LIST = %w[mistral-tiny mistral-small mistral-medium].freeze
DEFAULT_MODEL = 'mistral-small'
DEFAULT_TEMPERATURE = 0.7
# Python logging-style format string; nothing in this script reads it (the
# Logger formatter builds an equivalent layout by hand).
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'

# Every slash command mapped to the hash of words that may follow it; the
# nesting drives tab completion. Only /model takes completable arguments.
COMMAND_LIST = {
  '/new' => {},
  '/help' => {},
  '/model' => MODEL_LIST.each_with_object({}) { |name, nested| nested[name] = {} },
  '/system' => {},
  '/temperature' => {},
  '/config' => {},
  '/quit' => {},
  '/exit' => {}
}.freeze
|
31
|
+
|
32
|
+
# Logger is a default gem but is not loaded automatically; require it here so
# this script does not depend on some other library having loaded it first.
require 'logger'

# Console logger used throughout the chatbot. It starts at INFO; the --debug
# flag switches it to DEBUG once options are parsed.
$logger = Logger.new($stdout)
$logger.level = Logger::INFO
# Lines look like "2024-01-02 03:04:05 - INFO - message".
$logger.formatter = proc do |severity, datetime, _progname, msg|
  "#{datetime.strftime('%Y-%m-%d %H:%M:%S')} - #{severity} - #{msg}\n"
end
|
37
|
+
|
38
|
+
# Computes tab-completion candidates by walking a nested command hash.
#
# Each fully typed word that is an exact key descends one level. The first
# word that is not a key stops the walk, and the keys starting with that word
# are returned instead.
#
# @param command_dict [Hash{String=>Hash}] command words mapped to their sub-completions
# @param parts [Array<String>] words to walk through
# @return [Array<String>] candidate words at the level reached
def find_completions(command_dict, parts)
  return command_dict.keys if parts.empty?

  head, *rest = parts
  return command_dict.keys.select { |cmd| cmd.start_with?(head) } unless command_dict.key?(head)

  find_completions(command_dict[head], rest)
end
|
47
|
+
|
48
|
+
# Enable tab completion.
#
# Readline calls this proc with the word under the cursor. The whole line is
# taken from Readline.line_buffer so that earlier words (e.g. "/model") decide
# which nested COMMAND_LIST entries are offered for the current word.
Readline.completion_proc = proc do |_input|
  buffer = Readline.line_buffer.to_s.lstrip
  line_parts = buffer.split(' ')
  # String#split drops trailing empty fields. Without this, a blank line would
  # pass nil to start_with? (TypeError). A line ending in whitespace would treat
  # the already-finished previous word as the prefix, so "/model <TAB>" never
  # listed the model names.
  line_parts << '' if buffer.empty? || buffer.match?(/\s\z/)

  prefix = line_parts.last
  options = find_completions(COMMAND_LIST, line_parts[0..-2])
  options.select { |option| option.start_with?(prefix) }
end
|
54
|
+
|
55
|
+
# Interactive terminal chat session against the Mistral chat API.
#
# Lines whose first word is a COMMAND_LIST key are handled locally. Any other
# non-blank line is sent to the model, and the reply is streamed to stdout.
class ChatBot
  # @param api_key [String] Mistral API key; must be non-blank
  # @param model [String] model used for inference
  # @param system_message [String, nil] optional system prompt placed first in every new chat
  # @param temperature [Float] sampling temperature
  # @raise [ArgumentError] when no usable API key is given
  def initialize(api_key, model, system_message = nil, temperature = DEFAULT_TEMPERATURE)
    # An empty key (e.g. `MISTRAL_API_KEY=` in .env) is as unusable as a missing one.
    raise ArgumentError, 'An API key must be provided to use the Mistral API.' if api_key.to_s.strip.empty?

    @client = Mistral::Client.new(api_key: api_key)
    @model = model
    @temperature = temperature
    @system_message = system_message
  end

  # Prints the list of available commands.
  def opening_instructions
    puts <<~HELP

      To chat: type your message and hit enter
      To start a new chat: /new
      To switch model: /model <model name>
      To switch system message: /system <message>
      To switch temperature: /temperature <temperature>
      To see current config: /config
      To exit: /exit, /quit, or hit CTRL+C
      To see this help: /help
    HELP
  end

  # Clears the conversation history, re-seeding it with the system message if one is set.
  def new_chat
    puts ''
    puts "Starting new chat with model: #{@model}, temperature: #{@temperature}"
    puts ''
    @messages = []
    @messages << Mistral::ChatMessage.new(role: 'system', content: @system_message) if @system_message
  end

  # Handles "/model <name>"; only names in MODEL_LIST are accepted.
  def switch_model(input)
    model = get_arguments(input)

    if MODEL_LIST.include?(model)
      @model = model
      $logger.info("Switching model: #{model}")
    else
      $logger.error("Invalid model name: #{model}")
    end
  end

  # Handles "/system <message>"; a non-empty message replaces the system prompt
  # and starts a new chat.
  def switch_system_message(input)
    system_message = get_arguments(input)

    # get_arguments returns '' (not nil) when nothing follows the command, and ''
    # is truthy in Ruby, so emptiness must be tested explicitly.
    if system_message.empty?
      $logger.error("Invalid system message: #{system_message}")
      return
    end

    @system_message = system_message
    $logger.info("Switching system message: #{system_message}")
    new_chat
  end

  # Handles "/temperature <value>"; accepts a float between 0 and 1 inclusive.
  def switch_temperature(input)
    temperature = get_arguments(input)

    begin
      # Float() raises ArgumentError on text that is not a number.
      temperature = Float(temperature)

      raise ArgumentError if temperature.negative? || temperature > 1

      @temperature = temperature
      $logger.info("Switching temperature: #{temperature}")
    rescue ArgumentError
      $logger.error("Invalid temperature: #{temperature}")
    end
  end

  # Prints the current model, temperature and system message.
  def show_config
    puts ''
    puts "Current model: #{@model}"
    puts "Current temperature: #{@temperature}"
    puts "Current system message: #{@system_message}"
    puts ''
  end

  # Prompts for one line of input and ends the session on end-of-input (CTRL+D).
  #
  # Reads through Readline so the completion proc installed for this script is
  # used and typed lines go into history. Kernel#gets would bypass Readline, and
  # it would read from a file if any non-option argument remained in ARGV.
  #
  # @return [String] the typed line without its trailing newline
  def collect_user_input
    puts ''
    input = Readline.readline('YOU: ', true)
    exit if input.nil?

    input
  end

  # Sends +content+ as a user message, streams the reply to stdout and records
  # it in the conversation history.
  def run_inference(content)
    puts ''
    puts 'MISTRAL:'
    puts ''

    @messages << Mistral::ChatMessage.new(role: 'user', content: content)

    # Mutable buffer: appending avoids allocating a new String per chunk.
    assistant_response = +''

    $logger.debug("Running inference with model: #{@model}, temperature: #{@temperature}")
    $logger.debug("Sending messages: #{@messages}")

    @client.chat_stream(model: @model, temperature: @temperature, messages: @messages).each do |chunk|
      response = chunk.choices[0].delta.content
      next unless response

      print response
      assistant_response << response
    end

    puts ''

    # '' is truthy in Ruby, so record the reply only when it actually has content.
    @messages << Mistral::ChatMessage.new(role: 'assistant', content: assistant_response) unless assistant_response.empty?

    $logger.debug("Current messages: #{@messages}")
  end

  # @return [String] the first whitespace-separated word of +input+, or '' for blank input
  def get_command(input)
    input.split.first.to_s
  end

  # @return [String] every word after the command, joined by single spaces ('' if none)
  def get_arguments(input)
    input.split.drop(1).join(' ')
  end

  # @return [Boolean] whether the first word of +input+ is a known slash command
  def is_command?(input)
    COMMAND_LIST.key?(get_command(input))
  end

  # Dispatches a recognised slash command to its handler.
  def execute_command(input)
    command = get_command(input)
    case command
    when '/exit', '/quit'
      exit
    when '/help'
      opening_instructions
    when '/new'
      new_chat
    when '/model'
      switch_model(input)
    when '/system'
      switch_system_message(input)
    when '/temperature'
      switch_temperature(input)
    when '/config'
      show_config
    end
  end

  # Runs the read/dispatch loop until the user exits or presses CTRL+C.
  def start
    opening_instructions
    new_chat

    loop do
      input = collect_user_input
      # A bare Enter is ignored instead of being sent to the API as an empty message.
      next if input.strip.empty?

      if is_command?(input)
        execute_command(input)
      else
        run_inference(input)
      end
    rescue Interrupt
      exit
    end
  end

  # Says goodbye and terminates the process with status 0.
  def exit
    $logger.debug('Exiting chatbot')
    puts 'Goodbye!'
    Kernel.exit(0)
  end
end
|
223
|
+
|
224
|
+
options = {}
parser = OptionParser.new do |opts|
  # Name the script actually being run instead of a hard-coded file name.
  opts.banner = "Usage: #{File.basename($PROGRAM_NAME)} [options]"

  opts.on(
    '--api-key KEY',
    'Mistral API key. Defaults to environment variable MISTRAL_API_KEY'
  ) do |key|
    options[:api_key] = key
  end

  # OptionParser rejects any value not in MODEL_LIST.
  opts.on(
    '-m',
    '--model MODEL',
    MODEL_LIST,
    "Model for chat inference. Choices are #{MODEL_LIST.join(', ')}. Defaults to #{DEFAULT_MODEL}"
  ) do |model|
    options[:model] = model
  end

  opts.on(
    '-s',
    '--system-message MESSAGE',
    'Optional system message to prepend'
  ) do |message|
    options[:system_message] = message
  end

  opts.on(
    '-t',
    '--temperature FLOAT',
    Float,
    "Optional temperature for chat inference. Defaults to #{DEFAULT_TEMPERATURE}"
  ) do |temp|
    # Enforce the same 0..1 range that the interactive /temperature command accepts.
    raise OptionParser::InvalidArgument, temp.to_s if temp.negative? || temp > 1

    options[:temperature] = temp
  end

  opts.on(
    '-d',
    '--debug',
    'Enable debug logging'
  ) do
    options[:debug] = true
  end
end

# Report bad flags (unknown option, model not in MODEL_LIST, malformed or
# out-of-range temperature) with the usage text rather than a stack trace.
begin
  parser.parse!
rescue OptionParser::ParseError => e
  warn e.message
  warn parser.help
  exit(1)
end
|
269
|
+
|
270
|
+
# ENV[] instead of ENV.fetch: a missing key reaches ChatBot.new as nil. There
# it raises a descriptive ArgumentError inside the begin/rescue below, rather
# than an uncaught KeyError before the rescue is reached.
api_key = options[:api_key] || ENV['MISTRAL_API_KEY']
model = options[:model] || DEFAULT_MODEL
system_message = options[:system_message]
temperature = options[:temperature] || DEFAULT_TEMPERATURE

$logger.level = options[:debug] ? Logger::DEBUG : Logger::INFO

$logger.debug(
  "Starting chatbot with model: #{model}, " \
  "temperature: #{temperature}, " \
  "system message: #{system_message}"
)

begin
  bot = ChatBot.new(api_key, model, system_message, temperature)
  bot.start
rescue StandardError => e
  # Any error escaping the session is logged and turned into exit status 1.
  $logger.error(e)
  exit(1)
end
|
@@ -0,0 +1,16 @@
|
|
1
|
+
#!/usr/bin/env ruby
|
2
|
+
# frozen_string_literal: true
|
3
|
+
|
4
|
+
require 'bundler/setup'
|
5
|
+
require 'dotenv/load'
|
6
|
+
require 'mistral'
|
7
|
+
|
8
|
+
# Embed the same sentence ten times in one request and print the raw response as a Hash.
client = Mistral::Client.new(api_key: ENV.fetch('MISTRAL_API_KEY'))

sentences = Array.new(10, 'What is the best French cheese?')
embeddings_response = client.embeddings(model: 'mistral-embed', input: sentences)

puts embeddings_response.to_h
|
@@ -0,0 +1,104 @@
|
|
1
|
+
#!/usr/bin/env ruby
|
2
|
+
# frozen_string_literal: true
|
3
|
+
|
4
|
+
require 'bundler/setup'
|
5
|
+
require 'dotenv/load'
|
6
|
+
require 'mistral'
|
7
|
+
|
8
|
+
# Sample payment records. They are written row by row for readability, then
# stored column-wise: each key holds one column, and index i in every column
# describes the same transaction.
payment_columns = %w[transaction_id customer_id payment_amount payment_date payment_status]
payment_rows = [
  ['T1001', 'C001', 125.50, '2021-10-05', 'Paid'],
  ['T1002', 'C002', 89.99, '2021-10-06', 'Unpaid'],
  ['T1003', 'C003', 120.00, '2021-10-07', 'Paid'],
  ['T1004', 'C002', 54.30, '2021-10-05', 'Paid'],
  ['T1005', 'C001', 210.20, '2021-10-08', 'Pending']
]
data = payment_columns.zip(payment_rows.transpose).to_h
|
16
|
+
|
17
|
+
# Looks up the payment status of a transaction.
#
# @param data [Hash{String=>Array}] column-wise records with 'transaction_id' and 'payment_status'
# @param transaction_id [String]
# @return [String] JSON object with a "status" key; on a miss, "status" holds an error message
def retrieve_payment_status(data, transaction_id)
  row = data['transaction_id'].index(transaction_id)
  return { status: 'Error - transaction id not found' }.to_json if row.nil?

  { status: data['payment_status'][row] }.to_json
end
|
24
|
+
|
25
|
+
# Looks up the payment date of a transaction.
#
# @param data [Hash{String=>Array}] column-wise records with 'transaction_id' and 'payment_date'
# @param transaction_id [String]
# @return [String] JSON object with a "date" key, or a "status" error message when not found
def retrieve_payment_date(data, transaction_id)
  row = data['transaction_id'].find_index { |id| id == transaction_id }
  result = row ? { date: data['payment_date'][row] } : { status: 'Error - transaction id not found' }
  result.to_json
end
|
32
|
+
|
33
|
+
# Dispatch table from the tool names advertised to the model to callables that
# take only the transaction id; the sample `data` is bound in each closure.
names_to_functions = {}
names_to_functions['retrieve_payment_status'] = ->(transaction_id) { retrieve_payment_status(data, transaction_id) }
names_to_functions['retrieve_payment_date'] = ->(transaction_id) { retrieve_payment_date(data, transaction_id) }
|
37
|
+
|
38
|
+
# Tool definitions sent to the model. Both functions take a single required
# string argument, so they are built from one template; each gets its own
# parameters hash.
tools = {
  'retrieve_payment_status' => 'Get payment status of a transaction id',
  'retrieve_payment_date' => 'Get payment date of a transaction id'
}.map do |function_name, function_description|
  {
    'type' => 'function',
    'function' => Mistral::Function.new(
      name: function_name,
      description: function_description,
      parameters: {
        'type' => 'object',
        'required' => ['transaction_id'],
        'properties' => {
          'transaction_id' => {
            'type' => 'string',
            'description' => 'The transaction id.'
          }
        }
      }
    )
  }
end
|
74
|
+
|
75
|
+
api_key = ENV.fetch('MISTRAL_API_KEY')
model = 'mistral-large-latest'

client = Mistral::Client.new(api_key: api_key)

messages = [Mistral::ChatMessage.new(role: 'user', content: "What's the status of my transaction?")]

# First turn: the question carries no transaction id; print the reply and keep it as context.
response = client.chat(model: model, messages: messages, tools: tools)

puts response.choices[0].message.content

messages << Mistral::ChatMessage.new(role: 'assistant', content: response.choices[0].message.content)
messages << Mistral::ChatMessage.new(role: 'user', content: 'My transaction ID is T1001.')

# Second turn: now that the id is known, a tool call is expected.
response = client.chat(model: model, messages: messages, tools: tools)

# The model may reply with plain text instead of a tool call, or name a
# function that is not registered. Stop with a clear message rather than
# crashing with NoMethodError on nil.
tool_call = Array(response.choices[0].message.tool_calls).first
abort('The model did not request a tool call.') if tool_call.nil?

function_name = tool_call.function.name
function = names_to_functions.fetch(function_name) do
  abort("The model requested an unknown function: #{function_name}")
end
function_params = JSON.parse(tool_call.function.arguments)

puts "calling function_name: #{function_name}, with function_params: #{function_params}"

function_result = function.call(function_params['transaction_id'])

# Send back the assistant's tool-call message followed by the tool's output.
messages << response.choices[0].message
messages << Mistral::ChatMessage.new(role: 'tool', name: function_name, content: function_result)

response = client.chat(model: model, messages: messages, tools: tools)

puts response.choices[0].message.content
|
@@ -0,0 +1,21 @@
|
|
1
|
+
#!/usr/bin/env ruby
|
2
|
+
# frozen_string_literal: true
|
3
|
+
|
4
|
+
require 'bundler/setup'
|
5
|
+
require 'dotenv/load'
|
6
|
+
require 'mistral'
|
7
|
+
|
8
|
+
client = Mistral::Client.new(api_key: ENV.fetch('MISTRAL_API_KEY'))

model = 'mistral-large-latest'
question = Mistral::ChatMessage.new(
  role: 'user', content: 'What is the best French cheese? Answer shortly in JSON.'
)

# response_format requests a JSON object as the reply content.
chat_response = client.chat(model: model, response_format: { type: 'json_object' }, messages: [question])

puts chat_response.choices[0].message.content
|
@@ -0,0 +1,13 @@
|
|
1
|
+
#!/usr/bin/env ruby
|
2
|
+
# frozen_string_literal: true
|
3
|
+
|
4
|
+
require 'bundler/setup'
|
5
|
+
require 'dotenv/load'
|
6
|
+
require 'mistral'
|
7
|
+
|
8
|
+
client = Mistral::Client.new(api_key: ENV.fetch('MISTRAL_API_KEY'))

# Print the model listing returned by the API as a plain Hash.
puts client.list_models.to_h
|
@@ -0,0 +1,35 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module HTTP
  module Features
    # http.rb feature that re-wraps every response body in IterableBodyWrapper,
    # adding #each_line for line-oriented reading of a streamed body.
    # Registered below under the name :line_iterable_body.
    class LineIterableBody < Feature
      # Rebuilds +response+ with identical attributes but a line-iterable body.
      #
      # @param response [HTTP::Response]
      # @return [HTTP::Response]
      def wrap_response(response)
        options = {
          status: response.status,
          version: response.version,
          headers: response.headers,
          proxy_headers: response.proxy_headers,
          connection: response.connection,
          # NOTE(review): the encoding is read from Body's @encoding ivar,
          # presumably because Body has no public reader for it; confirm against
          # the http.rb version in use.
          body: IterableBodyWrapper.new(response.body, response.body.instance_variable_get(:@encoding)),
          request: response.request
        }

        HTTP::Response.new(options)
      end

      # Response body that can also be consumed line by line.
      class IterableBodyWrapper < HTTP::Response::Body
        # @param body [HTTP::Response::Body] original body, used as the read stream
        # @param encoding [Encoding, String] encoding forwarded to Body
        def initialize(body, encoding)
          super(body, encoding: encoding)
        end

        # Yields each complete line of the body, terminator included, reading
        # chunks lazily through #each.
        #
        # Network chunks do not respect line boundaries. A partial line (or a
        # multibyte character split across chunks) is therefore kept in a buffer
        # until its "\n" arrives. Whatever remains after the last chunk is
        # yielded as a final, unterminated line.
        #
        # @yieldparam line [String]
        # @return [Enumerator, nil] an Enumerator when called without a block, else nil
        def each_line(&block)
          return enum_for(:each_line) unless block

          buffer = nil
          each do |chunk|
            # Start the buffer in the chunks' own encoding so appends never mix encodings.
            buffer ||= String.new(encoding: chunk.encoding)
            buffer << chunk
            while (newline_at = buffer.index("\n"))
              block.call(buffer.slice!(0..newline_at))
            end
          end
          block.call(buffer) unless buffer.nil? || buffer.empty?
          nil
        end
      end
    end
  end

  HTTP::Options.register_feature(:line_iterable_body, Features::LineIterableBody)
end
|