mistral 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/.env.example +3 -0
- data/.rubocop.yml +60 -0
- data/.tool-versions +1 -0
- data/CHANGELOG.md +12 -0
- data/CODE_OF_CONDUCT.md +84 -0
- data/LICENSE.txt +21 -0
- data/PYTHON_CLIENT_COMPARISON.md +184 -0
- data/README.md +145 -0
- data/Rakefile +12 -0
- data/examples/chat_no_streaming.rb +18 -0
- data/examples/chat_with_streaming.rb +18 -0
- data/examples/chatbot_with_streaming.rb +289 -0
- data/examples/embeddings.rb +16 -0
- data/examples/function_calling.rb +104 -0
- data/examples/json_format.rb +21 -0
- data/examples/list_models.rb +13 -0
- data/lib/http/features/line_iterable_body.rb +35 -0
- data/lib/mistral/client.rb +229 -0
- data/lib/mistral/client_base.rb +126 -0
- data/lib/mistral/constants.rb +6 -0
- data/lib/mistral/exceptions.rb +38 -0
- data/lib/mistral/models/chat_completion.rb +95 -0
- data/lib/mistral/models/common.rb +11 -0
- data/lib/mistral/models/embeddings.rb +21 -0
- data/lib/mistral/models/models.rb +39 -0
- data/lib/mistral/version.rb +5 -0
- data/lib/mistral.rb +24 -0
- metadata +172 -0
@@ -0,0 +1,229 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'mistral/client_base'
|
4
|
+
|
5
|
+
module Mistral
  # Synchronous client for the Mistral REST API, built on a persistent HTTP
  # connection from the `http` gem.
  class Client < ClientBase
    # @param api_key [String, nil] API key sent as a Bearer token. Defaults to ENV['MISTRAL_API_KEY'].
    # @param endpoint [String] Base URL of the API. Defaults to ENDPOINT.
    # @param max_retries [Integer] Total number of attempts (including the first) for requests
    #   that fail with a retryable status code.
    # @param timeout [Numeric] Timeout passed to the HTTP client's #timeout.
    def initialize(
      api_key: ENV['MISTRAL_API_KEY'],
      endpoint: ENDPOINT,
      max_retries: 5,
      timeout: 120
    )
      super(endpoint: endpoint, api_key: api_key, max_retries: max_retries, timeout: timeout)

      # Persist against the configured endpoint rather than the ENDPOINT constant:
      # the http gem refuses requests whose origin differs from the persistent one,
      # so a custom (e.g. Azure) endpoint would otherwise be unreachable.
      @client = HTTP.persistent(endpoint)
                    .follow
                    .timeout(timeout)
                    .use(:line_iterable_body)
                    .headers(
                      'Accept' => 'application/json',
                      'User-Agent' => "mistral-client-ruby/#{VERSION}",
                      'Authorization' => "Bearer #{@api_key}",
                      'Content-Type' => 'application/json'
                    )
    end

    # A chat endpoint that returns a single response.
    #
    # @param messages [Array<ChatMessage, Hash>] An array of messages to chat with, e.g.
    #   [{role: 'user', content: 'What is the best French cheese?'}]
    # @param model [String] The name of the model to chat with, e.g. mistral-tiny
    # @param tools [Array<Hash>] A list of tools to use.
    # @param temperature [Float] The temperature to use for sampling, e.g. 0.5.
    # @param max_tokens [Integer] The maximum number of tokens to generate, e.g. 100.
    # @param top_p [Float] The cumulative probability of tokens to generate, e.g. 0.9.
    # @param random_seed [Integer] The random seed to use for sampling, e.g. 42.
    # @param safe_mode [Boolean] Deprecated, use safe_prompt instead.
    # @param safe_prompt [Boolean] Whether to use safe prompt, e.g. true.
    # @param tool_choice [String, Symbol] The tool choice.
    # @param response_format [Hash<String, String>, ResponseFormat] The response format.
    # @return [ChatCompletionResponse] A response object containing the generated text.
    # @raise [Mistral::Error] if no model is available or no response is received.
    # @raise [Mistral::APIError] on client errors or an error payload.
    # @raise [Mistral::ConnectionError] if the API cannot be reached.
    #
    def chat(
      messages:,
      model: nil,
      tools: nil,
      temperature: nil,
      max_tokens: nil,
      top_p: nil,
      random_seed: nil,
      safe_mode: false,
      safe_prompt: false,
      tool_choice: nil,
      response_format: nil
    )
      request_data = make_chat_request(
        messages: messages,
        model: model,
        tools: tools,
        temperature: temperature,
        max_tokens: max_tokens,
        top_p: top_p,
        random_seed: random_seed,
        stream: false,
        safe_prompt: safe_mode || safe_prompt,
        tool_choice: tool_choice,
        response_format: response_format
      )

      ChatCompletionResponse.new(fetch_single_response('post', 'v1/chat/completions', json: request_data))
    end

    # A chat endpoint that streams responses.
    #
    # @param messages [Array<ChatMessage, Hash>] An array of messages to chat with, e.g.
    #   [{role: 'user', content: 'What is the best French cheese?'}]
    # @param model [String] The name of the model to chat with, e.g. mistral-tiny
    # @param tools [Array<Hash>] A list of tools to use.
    # @param temperature [Float] The temperature to use for sampling, e.g. 0.5.
    # @param max_tokens [Integer] The maximum number of tokens to generate, e.g. 100.
    # @param top_p [Float] The cumulative probability of tokens to generate, e.g. 0.9.
    # @param random_seed [Integer] The random seed to use for sampling, e.g. 42.
    # @param safe_mode [Boolean] Deprecated, use safe_prompt instead.
    # @param safe_prompt [Boolean] Whether to use safe prompt, e.g. true.
    # @param tool_choice [String, Symbol] The tool choice.
    # @param response_format [Hash<String, String>, ResponseFormat] The response format.
    # @return [Enumerator<ChatCompletionStreamResponse>] A lazy enumerator. The HTTP request is
    #   only sent when it is iterated, and it yields one ChatCompletionStreamResponse per event.
    #
    def chat_stream(
      messages:,
      model: nil,
      tools: nil,
      temperature: nil,
      max_tokens: nil,
      top_p: nil,
      random_seed: nil,
      safe_mode: false,
      safe_prompt: false,
      tool_choice: nil,
      response_format: nil
    )
      request_data = make_chat_request(
        messages: messages,
        model: model,
        tools: tools,
        temperature: temperature,
        max_tokens: max_tokens,
        top_p: top_p,
        random_seed: random_seed,
        stream: true,
        safe_prompt: safe_mode || safe_prompt,
        tool_choice: tool_choice,
        response_format: response_format
      )

      Enumerator.new do |yielder|
        request('post', 'v1/chat/completions', json: request_data, stream: true).each do |json_response|
          yielder << ChatCompletionStreamResponse.new(json_response)
        end
      end
    end

    # An embeddings endpoint that returns embeddings for a single input or a batch of inputs.
    #
    # @param model [String] The embedding model to use, e.g. mistral-embed
    # @param input [String, Array<String>] The input to embed, e.g. ['What is the best French cheese?']
    #
    # @return [EmbeddingResponse] A response object containing the embeddings.
    #
    def embeddings(model:, input:)
      EmbeddingResponse.new(fetch_single_response('post', 'v1/embeddings', json: { model: model, input: input }))
    end

    # Returns a list of the available models
    #
    # @return [ModelList] A response object containing the list of models.
    #
    def list_models
      ModelList.new(fetch_single_response('get', 'v1/models'))
    end

    private

    # Performs a non-streaming request and returns its decoded JSON body.
    #
    # @raise [Mistral::Error] if the request produced no response at all
    def fetch_single_response(method, path, json: nil)
      response = request(method, path, json: json).first
      raise Mistral::Error, 'No response received' if response.nil?

      response
    end

    # Builds a lazy Enumerator that sends the request when iterated.
    #
    # In non-streaming mode it yields the single decoded JSON body. In streaming mode
    # it yields one decoded event per `data:` line. Requests failing with a retryable
    # status (APIStatusError) are retried with exponential backoff: before retry
    # number k, it sleeps 2**(k + 1) seconds, so 4 s before the first retry.
    # Once the attempt count exceeds @max_retries, the last error is re-raised.
    def request(method, path, json: nil, stream: false, attempt: 1)
      url = File.join(@endpoint, path)
      headers = {}
      headers['Accept'] = 'text/event-stream' if stream

      @logger.debug("Sending request: #{method.upcase} #{url} #{json}")

      Enumerator.new do |yielder|
        response = @client.headers(headers).request(method.downcase.to_sym, url, json: json)
        check_response_status_codes(response)

        if stream
          response.body.each_line do |line|
            processed_line = process_line(line)
            yielder << processed_line unless processed_line.nil?
          end
        else
          yielder << check_response(response)
        end
      rescue HTTP::ConnectionError => e
        raise Mistral::ConnectionError, e.message
      rescue HTTP::RequestError => e
        raise Mistral::Error, "Unexpected exception (#{e.class}): #{e.message}"
      rescue JSON::ParserError => e
        # A streamed body has already been consumed line by line and cannot be
        # re-read, so report the parser's message for streams instead.
        body = stream ? e.message : response.body
        raise Mistral::APIError.from_response(response, message: "Failed to decode json body: #{body}")
      rescue Mistral::APIStatusError
        attempt += 1
        # Re-raise the original error rather than wrapping it in a new one,
        # which would nest its formatted message inside another.
        raise if attempt > @max_retries

        sleep(2.0**attempt) # exponential backoff

        # Retry and forward everything the retried request yields
        request(method, path, json: json, stream: stream, attempt: attempt).each { |r| yielder << r }
      end
    end

    # Decodes a non-streaming response body and rejects error payloads.
    #
    # @return [Hash] the decoded JSON object
    # @raise [Mistral::Error] if the payload is not an object with an 'object' key
    # @raise [Mistral::APIError] if the payload's 'object' is 'error'
    def check_response(response)
      check_response_status_codes(response)

      json_response = JSON.parse(response.body.to_s)

      unless json_response.is_a?(Hash) && json_response.key?('object')
        raise Mistral::Error, "Unexpected response: #{json_response}"
      end

      if json_response['object'] == 'error' # has errors
        raise Mistral::APIError.from_response(response, message: json_response['message'])
      end

      json_response
    end

    # Maps HTTP status codes onto the client's exception hierarchy.
    # Retryable codes raise APIStatusError (caught by #request's retry logic),
    # other 4xx codes raise APIError, and remaining 5xx codes raise Mistral::Error.
    # The body is only read when an error is raised, so successful streams stay unread.
    def check_response_status_codes(response)
      code = response.code

      if RETRY_STATUS_CODES.include?(code)
        raise APIStatusError.from_response(response, message: status_error_message(response))
      elsif code >= 400 && code < 500
        raise APIError.from_response(response, message: status_error_message(response))
      elsif code >= 500
        raise Mistral::Error, status_error_message(response)
      end
    end

    # Formats the error message used for non-success status codes (reads the body).
    def status_error_message(response)
      "Status: #{response.code}. Message: #{response.body}"
    end
  end
end
|
@@ -0,0 +1,126 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Mistral
  # Shared configuration and request-building helpers for Mistral API clients.
  # Subclasses perform the actual HTTP calls.
  class ClientBase
    attr_reader :endpoint, :api_key, :max_retries, :timeout

    # @param endpoint [String] base URL of the API
    # @param api_key [String, nil] API key stored for subclasses
    # @param max_retries [Integer] retry budget stored for subclasses
    # @param timeout [Numeric] timeout stored for subclasses
    def initialize(endpoint:, api_key: nil, max_retries: 5, timeout: 120)
      @endpoint = endpoint
      @api_key = api_key
      @max_retries = max_retries
      @timeout = timeout

      @logger = config_logger

      # For Azure endpoints, default to the 'mistral' model so callers may omit it.
      @default_model = 'mistral' if endpoint.include?('inference.azure.com')
    end

    protected

    # Keeps only 'function' tools and normalizes each one to
    # { 'type' => 'function', 'function' => <Hash> }.
    # Tool hashes may use string or symbol keys; Function structs are converted with #to_h.
    #
    # @param tools [Array<Hash>]
    # @return [Array<Hash>]
    def parse_tools(tools)
      tools.select { |tool| tool_field(tool, 'type').to_s == 'function' }.map do |tool|
        function = tool_field(tool, 'function')
        {
          'type' => 'function',
          'function' => function.is_a?(Function) ? function.to_h : function
        }
      end
    end

    # Normalizes tool_choice for the request body. Symbols (e.g. :auto) become their
    # string name; anything else is passed through unchanged.
    # ToolChoice is a dry-types enum rather than a class, so it cannot be used
    # with is_a? (doing so raises TypeError).
    def parse_tool_choice(tool_choice)
      tool_choice.is_a?(Symbol) ? tool_choice.to_s : tool_choice
    end

    # Converts a ResponseFormat struct to a Hash; other values are passed through.
    def parse_response_format(response_format)
      response_format.is_a?(ResponseFormat) ? response_format.to_h : response_format
    end

    # Converts ChatMessage structs to hashes; plain hashes are passed through.
    #
    # @param messages [Array<ChatMessage, Hash>]
    # @return [Array<Hash>]
    def parse_messages(messages)
      messages.map { |message| message.is_a?(ChatMessage) ? message.to_h : message }
    end

    # Builds the JSON body for the chat completions endpoint. Optional parameters
    # that are nil are omitted from the body.
    #
    # @return [Hash] the request payload
    # @raise [Mistral::Error] if no model is given and there is no default model
    def make_chat_request(
      messages:,
      model: nil,
      tools: nil,
      temperature: nil,
      max_tokens: nil,
      top_p: nil,
      random_seed: nil,
      stream: nil,
      safe_prompt: false,
      tool_choice: nil,
      response_format: nil
    )
      request_data = {
        messages: parse_messages(messages),
        safe_prompt: safe_prompt
      }

      resolved_model = model.nil? ? @default_model : model
      raise Mistral::Error, 'model must be provided' if resolved_model.nil?

      request_data[:model] = resolved_model
      request_data[:tools] = parse_tools(tools) unless tools.nil?
      request_data[:temperature] = temperature unless temperature.nil?
      request_data[:max_tokens] = max_tokens unless max_tokens.nil?
      request_data[:top_p] = top_p unless top_p.nil?
      request_data[:random_seed] = random_seed unless random_seed.nil?
      request_data[:stream] = stream unless stream.nil?
      request_data[:tool_choice] = parse_tool_choice(tool_choice) unless tool_choice.nil?
      request_data[:response_format] = parse_response_format(response_format) unless response_format.nil?

      @logger.debug("Chat request: #{request_data}")

      request_data
    end

    # Parses one server-sent-events line.
    #
    # @return [Hash, Array, nil] decoded JSON for `data: ` lines; nil for the `[DONE]`
    #   sentinel and for any line not starting with `data: `
    # @raise [JSON::ParserError] if the data payload is not valid JSON
    def process_line(line)
      return unless line.start_with?('data: ')

      data = line[6..].to_s.strip
      return if data == '[DONE]'

      JSON.parse(data)
    end

    # Logger writing to stdout. The level comes from ENV['LOG_LEVEL'] and defaults to ERROR.
    def config_logger
      Logger.new($stdout).tap do |logger|
        logger.level = ENV.fetch('LOG_LEVEL', 'ERROR')

        logger.formatter = proc do |severity, datetime, progname, msg|
          "#{datetime.strftime("%Y-%m-%d %H:%M:%S")} #{severity} #{progname}: #{msg}\n"
        end
      end
    end

    private

    # Reads +key+ from a tool hash, falling back to the symbol form of the key.
    def tool_field(tool, key)
      value = tool[key]
      value.nil? ? tool[key.to_sym] : value
    end
  end
end
|
@@ -0,0 +1,38 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Mistral
  # Base error class, raised when nothing more specific applies.
  class Error < StandardError
    # Accepts the message positionally (as `raise Error, 'msg'` does) or through the
    # `message:` keyword. Without this, the keyword form would reach StandardError
    # as a Hash and the Hash's text would become the error message.
    #
    # @param msg [String, nil] positional message
    # @param message [String, nil] keyword message; takes precedence over +msg+
    def initialize(msg = nil, message: nil)
      super(message || msg)
    end
  end

  # Raised for error responses from the API; carries the HTTP status and headers.
  class APIError < Error
    attr_reader :http_status, :headers

    # @param msg [String, nil] positional message (supports `raise APIError, 'msg'`)
    # @param message [String, nil] keyword message; takes precedence over +msg+
    # @param http_status [Integer, nil] HTTP status code of the response
    # @param headers [Hash, nil] response headers
    def initialize(msg = nil, message: nil, http_status: nil, headers: nil)
      super(message || msg)

      @http_status = http_status
      @headers = headers
    end

    # Builds an error from an HTTP response object responding to #code, #headers and #to_s.
    # When +message+ is nil, the response's #to_s is used as the message.
    def self.from_response(response, message: nil)
      new(
        message: message || response.to_s,
        http_status: response.code,
        headers: response.headers.to_h
      )
    end

    # e.g. "Mistral::APIError(message=..., http_status=400)"
    def to_s
      "#{self.class.name}(message=#{super}, http_status=#{http_status})"
    end
  end

  # Raised when we receive a non-200 response from the API that we should retry
  class APIStatusError < APIError
  end

  # Raised when the SDK can not reach the API server for any reason
  class ConnectionError < Error
  end
end
|
@@ -0,0 +1,95 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Mistral
  # A function definition offered to the model as a tool.
  class Function < Dry::Struct
    transform_keys(&:to_sym)

    attribute :name, Types::Strict::String
    attribute :description, Types::Strict::String
    # Free-form Hash describing the parameters (presumably a JSON Schema object — see API docs).
    attribute :parameters, Types::Strict::Hash
  end

  # Tool type: only 'function' is allowed, and it is used when the value is omitted.
  ToolType = Types::Strict::String.default('function').enum('function')

  # The function name and its arguments as returned by the model (arguments is a String).
  class FunctionCall < Dry::Struct
    transform_keys(&:to_sym)

    attribute :name, Types::Strict::String
    attribute :arguments, Types::Strict::String
  end

  # A tool invocation requested by the model.
  class ToolCall < Dry::Struct
    transform_keys(&:to_sym)

    attribute :id, Types::Strict::String.default('null')
    attribute :type, ToolType
    attribute :function, FunctionCall
  end

  # Allowed response format types; 'text' is used when the value is omitted.
  ResponseFormats = Types::Strict::String.default('text').enum('text', 'json_object')

  # Allowed tool_choice values.
  ToolChoice = Types::Strict::String.enum('auto', 'any', 'none')

  class ResponseFormat < Dry::Struct
    transform_keys(&:to_sym)

    attribute :type, ResponseFormats
  end

  # A chat message, used both in requests and in non-streaming responses.
  class ChatMessage < Dry::Struct
    transform_keys(&:to_sym)

    attribute :role, Types::Strict::String
    attribute :content, Types::Strict::Array.of(Types::Strict::String) | Types::Strict::String
    attribute? :name, Types::String.optional
    attribute? :tool_calls, Types::Strict::Array.of(ToolCall).optional
    # Id of the tool call a 'tool' message answers. Previously this key was
    # silently dropped, because unknown keys are ignored by the struct.
    attribute? :tool_call_id, Types::Strict::String.optional
  end

  # An incremental message fragment from a streaming response; all fields are optional.
  class DeltaMessage < Dry::Struct
    transform_keys(&:to_sym)

    attribute? :role, Types::Strict::String.optional
    attribute? :content, Types::Strict::String.optional
    attribute? :tool_calls, Types::Strict::Array.of(ToolCall).optional
  end

  # Reasons a completion ended. 'model_length' is accepted so that responses
  # reporting it do not fail struct validation (and with it the whole call).
  FinishReason = Types::Strict::String.enum('stop', 'length', 'model_length', 'error', 'tool_calls')

  class ChatCompletionResponseStreamChoice < Dry::Struct
    transform_keys(&:to_sym)

    attribute :index, Types::Strict::Integer
    attribute :delta, DeltaMessage
    attribute? :finish_reason, FinishReason.optional
  end

  # One event of a streaming chat completion.
  class ChatCompletionStreamResponse < Dry::Struct
    transform_keys(&:to_sym)

    attribute :id, Types::Strict::String
    attribute :model, Types::Strict::String
    attribute :choices, Types::Strict::Array.of(ChatCompletionResponseStreamChoice)
    attribute? :created, Types::Strict::Integer.optional
    attribute? :object, Types::Strict::String.optional
    attribute? :usage, UsageInfo.optional
  end

  class ChatCompletionResponseChoice < Dry::Struct
    transform_keys(&:to_sym)

    attribute :index, Types::Strict::Integer
    attribute :message, ChatMessage
    attribute? :finish_reason, FinishReason.optional
  end

  # A complete (non-streaming) chat completion.
  class ChatCompletionResponse < Dry::Struct
    transform_keys(&:to_sym)

    attribute :id, Types::Strict::String
    attribute :object, Types::Strict::String
    attribute :created, Types::Strict::Integer
    attribute :model, Types::Strict::String
    attribute :choices, Types::Strict::Array.of(ChatCompletionResponseChoice)
    attribute :usage, UsageInfo
  end
end
|
@@ -0,0 +1,11 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Mistral
  # Token accounting attached to chat completion and embedding responses.
  class UsageInfo < Dry::Struct
    # Symbolize keys so hashes from JSON.parse (string keys) are accepted.
    transform_keys(&:to_sym)

    attribute :prompt_tokens, Types::Strict::Integer
    attribute :total_tokens, Types::Strict::Integer
    # Optional: the key may be absent or nil.
    attribute? :completion_tokens, Types::Strict::Integer.optional
  end
end
|
@@ -0,0 +1,21 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Mistral
  # A single embedding vector within an EmbeddingResponse.
  class EmbeddingObject < Dry::Struct
    # Symbolize keys so hashes from JSON.parse (string keys) are accepted.
    transform_keys(&:to_sym)

    attribute :object, Types::Strict::String
    # NOTE(review): strict Float rejects JSON integers (e.g. a bare `0`);
    # confirm the API always serializes components as floats.
    attribute :embedding, Types::Strict::Array.of(Types::Strict::Float)
    # Position of the corresponding input in the request.
    attribute :index, Types::Strict::Integer
  end

  # Response of the embeddings endpoint.
  class EmbeddingResponse < Dry::Struct
    transform_keys(&:to_sym)

    attribute :id, Types::Strict::String
    attribute :object, Types::Strict::String
    attribute :data, Types::Strict::Array.of(EmbeddingObject)
    attribute :model, Types::Strict::String
    attribute :usage, UsageInfo
  end
end
|
@@ -0,0 +1,39 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Mistral
  # Permission entry of a model card. The defaults apply when a key is absent.
  class ModelPermission < Dry::Struct
    # Symbolize keys so hashes from JSON.parse (string keys) are accepted.
    transform_keys(&:to_sym)

    attribute :id, Types::Strict::String
    attribute :object, Types::Strict::String
    attribute :created, Types::Strict::Integer
    attribute :allow_create_engine, Types::Strict::Bool.default(false)
    attribute :allow_sampling, Types::Strict::Bool.default(true)
    attribute :allow_logprobs, Types::Strict::Bool.default(true)
    attribute :allow_search_indices, Types::Strict::Bool.default(false)
    attribute :allow_view, Types::Strict::Bool.default(true)
    attribute :allow_fine_tuning, Types::Strict::Bool.default(false)
    attribute :organization, Types::Strict::String.default('*')
    attribute? :group, Types::Strict::String.optional
    attribute :is_blocking, Types::Strict::Bool.default(false)
  end

  # Description of one model returned by the models endpoint.
  class ModelCard < Dry::Struct
    transform_keys(&:to_sym)

    attribute :id, Types::Strict::String
    attribute :object, Types::Strict::String
    attribute :created, Types::Strict::Integer
    attribute :owned_by, Types::Strict::String
    attribute? :root, Types::Strict::String.optional
    attribute? :parent, Types::Strict::String.optional
    # Empty (frozen) array when the key is absent.
    attribute :permission, Types::Strict::Array.of(ModelPermission).default([].freeze)
  end

  # Response of the models endpoint: a list of ModelCard entries.
  class ModelList < Dry::Struct
    transform_keys(&:to_sym)

    attribute :object, Types::Strict::String
    attribute :data, Types::Strict::Array.of(ModelCard)
  end
end
|
data/lib/mistral.rb
ADDED
@@ -0,0 +1,24 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'dry-struct'
|
4
|
+
require 'http'
|
5
|
+
require 'json'
|
6
|
+
require 'logger'
|
7
|
+
require 'time'
|
8
|
+
|
9
|
+
require 'http/features/line_iterable_body'
|
10
|
+
|
11
|
+
module Mistral
  # dry-types container providing Types::Strict::String, Types::Strict::Integer, etc.
  # It is used by the Dry::Struct models under Mistral, so it must be defined
  # before those model files are required.
  module Types
    include Dry.Types()
  end
end
|
16
|
+
|
17
|
+
require 'mistral/constants'
|
18
|
+
require 'mistral/exceptions'
|
19
|
+
require 'mistral/models/models'
|
20
|
+
require 'mistral/models/common'
|
21
|
+
require 'mistral/models/embeddings'
|
22
|
+
require 'mistral/models/chat_completion'
|
23
|
+
require 'mistral/version'
|
24
|
+
require 'mistral/client'
|