mistral 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,289 @@
1
+ #!/usr/bin/env ruby
2
+ # frozen_string_literal: true
3
+
4
+ # Simple chatbot example -- run with -h argument to see options.
5
+
6
require 'bundler/setup'
require 'dotenv/load'
require 'logger'
require 'optparse'
require 'readline'
require 'mistral'
11
+
12
# Models selectable via the /model command and the -m CLI flag.
MODEL_LIST = %w[
  mistral-tiny
  mistral-small
  mistral-medium
].freeze
DEFAULT_MODEL = 'mistral-small'
DEFAULT_TEMPERATURE = 0.7
# NOTE: the Python-style LOG_FORMAT constant ('%(asctime)s - ...') was removed:
# it was never referenced — the $logger formatter proc below defines the format.
# A hash of all commands and their arguments, used for tab completion.
COMMAND_LIST = {
  '/new' => {},
  '/help' => {},
  '/model' => MODEL_LIST.to_h { |model| [model, {}] }, # Nested completions for models
  '/system' => {},
  '/temperature' => {},
  '/config' => {},
  '/quit' => {},
  '/exit' => {}
}.freeze
31
+
32
# Global logger printing timestamped "LEVEL - message" lines to stdout.
$logger = Logger.new($stdout)
$logger.level = Logger::INFO
$logger.formatter = proc do |severity, datetime, _progname, message|
  format("%s - %s - %s\n", datetime.strftime('%Y-%m-%d %H:%M:%S'), severity, message)
end
37
+
38
# Recursively resolve tab-completion candidates for the typed command parts.
#
# Walks the nested command hash: when the next part matches a key exactly,
# descend into its sub-hash; otherwise return the keys at this level that
# start with the partial part. With no parts left, every key is a candidate.
def find_completions(command_dict, parts)
  return command_dict.keys if parts.empty?

  head, *tail = parts
  return find_completions(command_dict[head], tail) if command_dict.key?(head)

  command_dict.keys.select { |candidate| candidate.start_with?(head) }
end
47
+
48
# Enable tab completion over COMMAND_LIST (nested completions for /model).
Readline.completion_proc = proc do |_input|
  line_parts = Readline.line_buffer.lstrip.split(' ')
  options = find_completions(COMMAND_LIST, line_parts[0..-2])
  # BUG FIX: line_parts[-1] is nil when the buffer is empty, and
  # String#start_with?(nil) raises TypeError — coerce to '' instead.
  prefix = line_parts[-1].to_s
  options.select { |option| option.start_with?(prefix) }
end
54
+
55
# Interactive REPL-style chat client for the Mistral API.
#
# Holds the conversation state (@messages) plus the current model, temperature
# and optional system message, and dispatches the slash commands in COMMAND_LIST.
class ChatBot
  # @param api_key [String] Mistral API key; must not be nil.
  # @param model [String] chat model name (see MODEL_LIST).
  # @param system_message [String, nil] optional system prompt for new chats.
  # @param temperature [Float] sampling temperature, expected in 0..1.
  # @raise [ArgumentError] when api_key is nil.
  def initialize(api_key, model, system_message = nil, temperature = DEFAULT_TEMPERATURE)
    raise ArgumentError, 'An API key must be provided to use the Mistral API.' if api_key.nil?

    @client = Mistral::Client.new(api_key: api_key)
    @model = model
    @temperature = temperature
    @system_message = system_message
  end

  # Print the command reference shown at startup and for /help.
  def opening_instructions
    puts '
To chat: type your message and hit enter
To start a new chat: /new
To switch model: /model <model name>
To switch system message: /system <message>
To switch temperature: /temperature <temperature>
To see current config: /config
To exit: /exit, /quit, or hit CTRL+C
To see this help: /help
'
  end

  # Reset the conversation, seeding it with the system message if one is set.
  def new_chat
    puts ''
    puts "Starting new chat with model: #{@model}, temperature: #{@temperature}"
    puts ''
    @messages = []
    @messages << Mistral::ChatMessage.new(role: 'system', content: @system_message) if @system_message
  end

  # Handle "/model <name>": switch models, rejecting unknown names.
  def switch_model(input)
    model = get_arguments(input)

    if MODEL_LIST.include?(model)
      @model = model
      $logger.info("Switching model: #{model}")
    else
      $logger.error("Invalid model name: #{model}")
    end
  end

  # Handle "/system <message>": set the system prompt and restart the chat.
  def switch_system_message(input)
    system_message = get_arguments(input)

    # BUG FIX: '' is truthy in Ruby (unlike Python), so the old
    # `if system_message` accepted a bare "/system" and silently set
    # an empty system prompt.
    if system_message.empty?
      $logger.error("Invalid system message: #{system_message}")
    else
      @system_message = system_message
      $logger.info("Switching system message: #{system_message}")
      new_chat
    end
  end

  # Handle "/temperature <t>": accept floats in [0, 1] only.
  def switch_temperature(input)
    temperature = get_arguments(input)

    begin
      temperature = Float(temperature) # raises ArgumentError on non-numeric input

      raise ArgumentError if temperature.negative? || temperature > 1

      @temperature = temperature
      $logger.info("Switching temperature: #{temperature}")
    rescue ArgumentError
      $logger.error("Invalid temperature: #{temperature}")
    end
  end

  # Handle "/config": print the current session settings.
  def show_config
    puts ''
    puts "Current model: #{@model}"
    puts "Current temperature: #{@temperature}"
    puts "Current system message: #{@system_message}"
    puts ''
  end

  # Prompt for and return one line of user input.
  #
  # BUG FIX: the original used `gets`, which bypasses Readline entirely, so
  # the tab completion configured via Readline.completion_proc never applied,
  # and `gets.chomp` crashed with NoMethodError on EOF (CTRL+D).
  def collect_user_input
    puts ''
    input = Readline.readline('YOU: ', true)
    exit if input.nil? # EOF
    input
  end

  # Stream one chat completion for +content+, echoing tokens as they arrive,
  # and record both the user message and the assistant reply in @messages.
  def run_inference(content)
    puts ''
    puts 'MISTRAL:'
    puts ''

    @messages << Mistral::ChatMessage.new(role: 'user', content: content)

    assistant_response = +'' # mutable buffer despite frozen_string_literal

    $logger.debug("Running inference with model: #{@model}, temperature: #{@temperature}")
    $logger.debug("Sending messages: #{@messages}")

    @client.chat_stream(model: @model, temperature: @temperature, messages: @messages).each do |chunk|
      token = chunk.choices[0].delta.content

      if token
        print token
        assistant_response << token
      end
    end

    puts ''

    # Only record a non-empty reply: the old `if assistant_response` guard was
    # always true because '' is truthy in Ruby.
    unless assistant_response.empty?
      @messages << Mistral::ChatMessage.new(role: 'assistant', content: assistant_response)
    end

    $logger.debug("Current messages: #{@messages}")
  end

  # First whitespace-delimited token of +input+, or '' for blank input
  # (the old `input.split[0].strip` raised NoMethodError on empty lines).
  def get_command(input)
    input.split.first.to_s
  end

  # Everything after the command token, re-joined with single spaces.
  #
  # BUG FIX: `input.split[1..]` returns nil for blank input and `.join` then
  # raises NoMethodError — Ruby never raises IndexError here, so the old
  # `rescue IndexError` was a dead Python-ism. Default to [] instead.
  def get_arguments(input)
    (input.split[1..] || []).join(' ')
  end

  # True when the input starts with a known slash command.
  def is_command?(input)
    COMMAND_LIST.key?(get_command(input))
  end

  # Route a slash command to its handler. Unknown commands are ignored
  # (is_command? filters them before we get here).
  def execute_command(input)
    case get_command(input)
    when '/exit', '/quit'
      exit
    when '/help'
      opening_instructions
    when '/new'
      new_chat
    when '/model'
      switch_model(input)
    when '/system'
      switch_system_message(input)
    when '/temperature'
      switch_temperature(input)
    when '/config'
      show_config
    end
  end

  # Main loop: read input, dispatch commands, otherwise run inference.
  # CTRL+C exits cleanly via our exit method.
  def start
    opening_instructions
    new_chat

    loop do
      input = collect_user_input
      next if input.strip.empty? # ignore blank lines instead of crashing

      if is_command?(input)
        execute_command(input)
      else
        run_inference(input)
      end
    rescue Interrupt
      exit
    end
  end

  # Clean shutdown shared by /exit, /quit, EOF and CTRL+C.
  # Shadows Kernel#exit inside this class on purpose.
  def exit
    $logger.debug('Exiting chatbot')
    puts 'Goodbye!'
    Kernel.exit(0)
  end
end
223
+
224
# Command-line flags; every option is optional and falls back to a
# default (or the environment) further down in the script.
options = {}
parser = OptionParser.new do |o|
  o.banner = 'Usage: chatbot.rb [options]'

  o.on('--api-key KEY',
       'Mistral API key. Defaults to environment variable MISTRAL_API_KEY') do |key|
    options[:api_key] = key
  end

  o.on('-m', '--model MODEL', MODEL_LIST,
       "Model for chat inference. Choices are #{MODEL_LIST.join(', ')}. Defaults to #{DEFAULT_MODEL}") do |model|
    options[:model] = model
  end

  o.on('-s', '--system-message MESSAGE',
       'Optional system message to prepend') do |message|
    options[:system_message] = message
  end

  o.on('-t', '--temperature FLOAT', Float,
       "Optional temperature for chat inference. Defaults to #{DEFAULT_TEMPERATURE}") do |temp|
    options[:temperature] = temp
  end

  o.on('-d', '--debug', 'Enable debug logging') do
    options[:debug] = true
  end
end
parser.parse!
269
+
270
# Resolve the runtime configuration: CLI flags win over environment/defaults.
# BUG FIX: ENV.fetch without a default raises KeyError when the variable is
# missing, so ChatBot#initialize's descriptive ArgumentError never fired;
# a nil default restores the intended error path.
api_key = options[:api_key] || ENV.fetch('MISTRAL_API_KEY', nil)
model = options[:model] || DEFAULT_MODEL
system_message = options[:system_message]
temperature = options[:temperature] || DEFAULT_TEMPERATURE

$logger.level = options[:debug] ? Logger::DEBUG : Logger::INFO

$logger.debug(
  "Starting chatbot with model: #{model}, " \
  "temperature: #{temperature}, " \
  "system message: #{system_message}"
)

begin
  bot = ChatBot.new(api_key, model, system_message, temperature)
  bot.start
rescue StandardError => e
  # Any startup/runtime failure is logged, then the process exits non-zero.
  $logger.error(e)
  exit(1)
end
@@ -0,0 +1,16 @@
1
#!/usr/bin/env ruby
# frozen_string_literal: true

# Example: request embeddings for a batch of ten identical prompts.

require 'bundler/setup'
require 'dotenv/load'
require 'mistral'

client = Mistral::Client.new(api_key: ENV.fetch('MISTRAL_API_KEY'))

response = client.embeddings(
  model: 'mistral-embed',
  input: Array.new(10, 'What is the best French cheese?')
)

puts response.to_h
@@ -0,0 +1,104 @@
1
+ #!/usr/bin/env ruby
2
+ # frozen_string_literal: true
3
+
4
+ require 'bundler/setup'
5
+ require 'dotenv/load'
6
+ require 'mistral'
7
+
8
# Assuming we have the following data.
# Column-oriented toy dataset: parallel arrays, one entry per transaction,
# all indexed by the same row position.
data = {
  'transaction_id' => %w[T1001 T1002 T1003 T1004 T1005],
  'customer_id' => %w[C001 C002 C003 C002 C001],
  'payment_amount' => [125.50, 89.99, 120.00, 54.30, 210.20],
  'payment_date' => %w[2021-10-05 2021-10-06 2021-10-07 2021-10-05 2021-10-08],
  'payment_status' => %w[Paid Unpaid Paid Paid Pending]
}
16
+
17
# Look up the payment status for a transaction id in the parallel-array
# dataset. Returns a JSON string — {"status": ...} — with an error status
# when the id is unknown.
def retrieve_payment_status(data, transaction_id)
  idx = data['transaction_id'].index(transaction_id)
  return { status: 'Error - transaction id not found' }.to_json if idx.nil?

  { status: data['payment_status'][idx] }.to_json
end
24
+
25
# Look up the payment date for a transaction id in the parallel-array
# dataset. Returns a JSON string — {"date": ...} — or an error status
# when the id is unknown.
def retrieve_payment_date(data, transaction_id)
  idx = data['transaction_id'].index(transaction_id)
  return { status: 'Error - transaction id not found' }.to_json if idx.nil?

  { date: data['payment_date'][idx] }.to_json
end
32
+
33
# Dispatch table: tool name (as returned by the model in a tool call) ->
# lambda invoking the matching local function with our dataset already bound.
names_to_functions = {
  'retrieve_payment_status' => ->(transaction_id) { retrieve_payment_status(data, transaction_id) },
  'retrieve_payment_date' => ->(transaction_id) { retrieve_payment_date(data, transaction_id) }
}
37
+
38
# Build one function-tool declaration. Both tools share an identical schema:
# a single required string parameter, transaction_id.
def transaction_tool(name, description)
  {
    'type' => 'function',
    'function' => Mistral::Function.new(
      name: name,
      description: description,
      parameters: {
        'type' => 'object',
        'required' => ['transaction_id'],
        'properties' => {
          'transaction_id' => {
            'type' => 'string',
            'description' => 'The transaction id.'
          }
        }
      }
    )
  }
end

# Tool declarations advertised to the model for function calling.
tools = [
  transaction_tool('retrieve_payment_status', 'Get payment status of a transaction id'),
  transaction_tool('retrieve_payment_date', 'Get payment date of a transaction id')
]
74
+
75
api_key = ENV.fetch('MISTRAL_API_KEY')
model = 'mistral-large-latest'

client = Mistral::Client.new(api_key: api_key)

# Step 1: ask an underspecified question; the model is expected to ask
# for the missing transaction id rather than call a tool yet.
messages = [Mistral::ChatMessage.new(role: 'user', content: "What's the status of my transaction?")]

response = client.chat(model: model, messages: messages, tools: tools)

puts response.choices[0].message.content

# Step 2: record the assistant's reply and supply the transaction id;
# with the tools declared, the model should now respond with a tool call.
messages << Mistral::ChatMessage.new(role: 'assistant', content: response.choices[0].message.content)
messages << Mistral::ChatMessage.new(role: 'user', content: 'My transaction ID is T1001.')

response = client.chat(model: model, messages: messages, tools: tools)

# NOTE(review): assumes the model returned at least one tool call here;
# tool_calls may be nil if it answered in plain text — confirm before
# using this pattern outside a demo.
tool_call = response.choices[0].message.tool_calls[0]
function_name = tool_call.function.name
function_params = JSON.parse(tool_call.function.arguments)

puts "calling function_name: #{function_name}, with function_params: #{function_params}"

# Execute the requested function locally via the dispatch table.
function_result = names_to_functions[function_name].call(function_params['transaction_id'])

# Step 3: append the assistant's tool-call message and the tool result,
# then ask the model for the final natural-language answer.
messages << response.choices[0].message
messages << Mistral::ChatMessage.new(role: 'tool', name: function_name, content: function_result)

response = client.chat(model: model, messages: messages, tools: tools)

puts response.choices[0].message.content
@@ -0,0 +1,21 @@
1
#!/usr/bin/env ruby
# frozen_string_literal: true

# Example: force a JSON-object response from the chat endpoint via
# response_format.

require 'bundler/setup'
require 'dotenv/load'
require 'mistral'

client = Mistral::Client.new(api_key: ENV.fetch('MISTRAL_API_KEY'))

chat_response = client.chat(
  model: 'mistral-large-latest',
  response_format: { type: 'json_object' },
  messages: [
    Mistral::ChatMessage.new(
      role: 'user', content: 'What is the best French cheese? Answer shortly in JSON.'
    )
  ]
)

puts chat_response.choices[0].message.content
@@ -0,0 +1,13 @@
1
#!/usr/bin/env ruby
# frozen_string_literal: true

# Example: list the models available to this API key.

require 'bundler/setup'
require 'dotenv/load'
require 'mistral'

client = Mistral::Client.new(api_key: ENV.fetch('MISTRAL_API_KEY'))

puts client.list_models.to_h
@@ -0,0 +1,35 @@
1
+ # frozen_string_literal: true
2
+
3
module HTTP
  module Features
    # An http.rb feature that rebuilds each response with its body wrapped
    # so callers can consume a streamed response line by line via #each_line
    # (useful for server-sent-event style streaming APIs).
    class LineIterableBody < Feature
      # Return a new HTTP::Response identical to +response+ except that the
      # body is replaced by an IterableBodyWrapper. All other attributes are
      # passed through unchanged.
      def wrap_response(response)
        options = {
          status: response.status,
          version: response.version,
          headers: response.headers,
          proxy_headers: response.proxy_headers,
          connection: response.connection,
          # NOTE(review): reaches into the body's internals for its encoding
          # (@encoding is not public API) — fragile across http.rb upgrades;
          # confirm the ivar still exists when bumping the gem.
          body: IterableBodyWrapper.new(response.body, response.body.instance_variable_get(:@encoding)),
          request: response.request
        }

        HTTP::Response.new(options)
      end

      # Response body that yields lines instead of raw transport chunks.
      class IterableBodyWrapper < HTTP::Response::Body
        def initialize(body, encoding)
          super(body, encoding: encoding)
        end

        # Yield each line found in the streamed body, splitting every chunk
        # with String#each_line as it arrives.
        # NOTE(review): a logical line that straddles two chunks is yielded
        # in two pieces — acceptable only if the server flushes whole lines
        # per chunk; confirm against the streaming endpoint's behavior.
        def each_line(&block)
          each do |chunk|
            chunk.each_line(&block)
          end
        end
      end
    end
  end

  # Make the feature available as `features: { line_iterable_body: ... }`.
  HTTP::Options.register_feature(:line_iterable_body, Features::LineIterableBody)
end