mistral 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.env.example +3 -0
- data/.rubocop.yml +60 -0
- data/.tool-versions +1 -0
- data/CHANGELOG.md +12 -0
- data/CODE_OF_CONDUCT.md +84 -0
- data/LICENSE.txt +21 -0
- data/PYTHON_CLIENT_COMPARISON.md +184 -0
- data/README.md +145 -0
- data/Rakefile +12 -0
- data/examples/chat_no_streaming.rb +18 -0
- data/examples/chat_with_streaming.rb +18 -0
- data/examples/chatbot_with_streaming.rb +289 -0
- data/examples/embeddings.rb +16 -0
- data/examples/function_calling.rb +104 -0
- data/examples/json_format.rb +21 -0
- data/examples/list_models.rb +13 -0
- data/lib/http/features/line_iterable_body.rb +35 -0
- data/lib/mistral/client.rb +229 -0
- data/lib/mistral/client_base.rb +126 -0
- data/lib/mistral/constants.rb +6 -0
- data/lib/mistral/exceptions.rb +38 -0
- data/lib/mistral/models/chat_completion.rb +95 -0
- data/lib/mistral/models/common.rb +11 -0
- data/lib/mistral/models/embeddings.rb +21 -0
- data/lib/mistral/models/models.rb +39 -0
- data/lib/mistral/version.rb +5 -0
- data/lib/mistral.rb +24 -0
- metadata +172 -0
data/lib/mistral/client.rb
ADDED
@@ -0,0 +1,229 @@
# frozen_string_literal: true

require 'mistral/client_base'

module Mistral
  # Synchronous wrapper around the async client
  class Client < ClientBase
    def initialize(
      api_key: ENV['MISTRAL_API_KEY'],
      endpoint: ENDPOINT,
      max_retries: 5,
      timeout: 120
    )
      super(endpoint: endpoint, api_key: api_key, max_retries: max_retries, timeout: timeout)

      @client = HTTP.persistent(ENDPOINT)
                    .follow
                    .timeout(timeout)
                    .use(:line_iterable_body)
                    .headers('Accept' => 'application/json',
                             'User-Agent' => "mistral-client-ruby/#{VERSION}",
                             'Authorization' => "Bearer #{@api_key}",
                             'Content-Type' => 'application/json'
                    )
    end

    # A chat endpoint that returns a single response.
    #
    # @param messages [Array<ChatMessage>] An array of messages to chat with, e.g.
    #   [{role: 'user', content: 'What is the best French cheese?'}]
    # @param model [String] The name of the model to chat with, e.g. mistral-tiny
    # @param tools [Array<Hash>] A list of tools to use.
    # @param temperature [Float] The temperature to use for sampling, e.g. 0.5.
    # @param max_tokens [Integer] The maximum number of tokens to generate, e.g. 100.
    # @param top_p [Float] The cumulative probability of tokens to generate, e.g. 0.9.
    # @param random_seed [Integer] The random seed to use for sampling, e.g. 42.
    # @param safe_mode [Boolean] Deprecated, use safe_prompt instead.
    # @param safe_prompt [Boolean] Whether to use safe prompt, e.g. true.
    # @param tool_choice [String, ToolChoice] The tool choice.
    # @param response_format [Hash<String, String>, ResponseFormat] The response format.
    # @return [ChatCompletionResponse] A response object containing the generated text.
    #
    def chat(
      messages:,
      model: nil,
      tools: nil,
      temperature: nil,
      max_tokens: nil,
      top_p: nil,
      random_seed: nil,
      safe_mode: false,
      safe_prompt: false,
      tool_choice: nil,
      response_format: nil
    )
      request = make_chat_request(
        messages: messages,
        model: model,
        tools: tools,
        temperature: temperature,
        max_tokens: max_tokens,
        top_p: top_p,
        random_seed: random_seed,
        stream: false,
        safe_prompt: safe_mode || safe_prompt,
        tool_choice: tool_choice,
        response_format: response_format
      )

      single_response = request('post', 'v1/chat/completions', json: request)

      single_response.each do |response|
        return ChatCompletionResponse.new(response)
      end

      raise Mistral::Error.new(message: 'No response received')
    end

    # A chat endpoint that streams responses.
    #
    # @param messages [Array<Any>] An array of messages to chat with, e.g.
    #   [{role: 'user', content: 'What is the best French cheese?'}]
    # @param model [String] The name of the model to chat with, e.g. mistral-tiny
    # @param tools [Array<Hash>] A list of tools to use.
    # @param temperature [Float] The temperature to use for sampling, e.g. 0.5.
    # @param max_tokens [Integer] The maximum number of tokens to generate, e.g. 100.
    # @param top_p [Float] The cumulative probability of tokens to generate, e.g. 0.9.
    # @param random_seed [Integer] The random seed to use for sampling, e.g. 42.
    # @param safe_mode [Boolean] Deprecated, use safe_prompt instead.
    # @param safe_prompt [Boolean] Whether to use safe prompt, e.g. true.
    # @param tool_choice [String, ToolChoice] The tool choice.
    # @param response_format [Hash<String, String>, ResponseFormat] The response format.
    # @return [Enumerator<ChatCompletionStreamResponse>] A generator that yields ChatCompletionStreamResponse objects.
    #
    def chat_stream(
      messages:,
      model: nil,
      tools: nil,
      temperature: nil,
      max_tokens: nil,
      top_p: nil,
      random_seed: nil,
      safe_mode: false,
      safe_prompt: false,
      tool_choice: nil,
      response_format: nil
    )
      request = make_chat_request(
        messages: messages,
        model: model,
        tools: tools,
        temperature: temperature,
        max_tokens: max_tokens,
        top_p: top_p,
        random_seed: random_seed,
        stream: true,
        safe_prompt: safe_mode || safe_prompt,
        tool_choice: tool_choice,
        response_format: response_format
      )

      Enumerator.new do |yielder|
        request('post', 'v1/chat/completions', json: request, stream: true).each do |json_response|
          yielder << ChatCompletionStreamResponse.new(**json_response)
        end
      end
    end

    # An embeddings endpoint that returns embeddings for a single, or batch of inputs
    #
    # @param model [String] The embedding model to use, e.g. mistral-embed
    # @param input [String, Array<String>] The input to embed, e.g. ['What is the best French cheese?']
    #
    # @return [EmbeddingResponse] A response object containing the embeddings.
    #
    def embeddings(model:, input:)
      request = { model: model, input: input }
      singleton_response = request('post', 'v1/embeddings', json: request)

      singleton_response.each do |response|
        return EmbeddingResponse.new(response)
      end

      raise Mistral::Error.new(message: 'No response received')
    end

    # Returns a list of the available models
    #
    # @return [ModelList] A response object containing the list of models.
    #
    def list_models
      singleton_response = request('get', 'v1/models')

      singleton_response.each do |response|
        return ModelList.new(response)
      end

      raise Mistral::Error.new(message: 'No response received')
    end

    private

    def request(method, path, json: nil, stream: false, attempt: 1)
      url = File.join(@endpoint, path)
      headers = {}
      headers['Accept'] = 'text/event-stream' if stream

      @logger.debug("Sending request: #{method.upcase} #{url} #{json}")

      Enumerator.new do |yielder|
        response = @client.headers(headers).request(method.downcase.to_sym, url, json: json)
        check_response_status_codes(response)

        if stream
          response.body.each_line do |line|
            processed_line = process_line(line)
            next if processed_line.nil?

            yielder << processed_line
          end
        else
          yielder << check_response(response)
        end
      rescue HTTP::ConnectionError => e
        raise Mistral::ConnectionError, e.message
      rescue HTTP::RequestError => e
        raise Mistral::Error, "Unexpected exception (#{e.class}): #{e.message}"
      rescue JSON::ParserError
        raise Mistral::APIError.from_response(response, message: "Failed to decode json body: #{response.body}")
      rescue Mistral::APIStatusError => e
        attempt += 1

        raise Mistral::APIStatusError.from_response(response, message: e.message) if attempt > @max_retries

        backoff = 2.0**attempt # exponential backoff
        sleep(backoff)

        # Retry and yield the response
        request(method, path, json: json, stream: stream, attempt: attempt).each do |r|
          yielder << r
        end
      end
    end

    def check_response(response)
      check_response_status_codes(response)

      json_response = JSON.parse(response.body.to_s)

      if !json_response.key?('object')
        raise Mistral::Error, "Unexpected response: #{json_response}"
      elsif json_response['object'] == 'error' # has errors
        raise Mistral::APIError.from_response(response, message: json_response['message'])
      end

      json_response
    end

    def check_response_status_codes(response)
      if RETRY_STATUS_CODES.include?(response.code)
        raise APIStatusError.from_response(response, message: "Status: #{response.code}. Message: #{response.body}")
      elsif response.code >= 400 && response.code < 500
        raise APIError.from_response(response, message: "Status: #{response.code}. Message: #{response.body}")
      elsif response.code >= 500
        raise Mistral::Error.new(message: "Status: #{response.code}. Message: #{response.body}")
      end
    end
  end
end
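A minimal usage sketch for the client above, in the spirit of the bundled data/examples/chat_no_streaming.rb and chat_with_streaming.rb scripts (not copied from them). It assumes MISTRAL_API_KEY is set in the environment; the model name mistral-tiny is taken from the doc comments:

require 'mistral'

client = Mistral::Client.new # api_key defaults to ENV['MISTRAL_API_KEY']

# Blocking call: drains the internal Enumerator and returns one ChatCompletionResponse
response = client.chat(
  model: 'mistral-tiny',
  messages: [Mistral::ChatMessage.new(role: 'user', content: 'What is the best French cheese?')]
)
puts response.choices.first.message.content

# Streaming call: lazily yields ChatCompletionStreamResponse chunks as SSE lines arrive
client.chat_stream(model: 'mistral-tiny', messages: [{ role: 'user', content: 'Hello!' }]).each do |chunk|
  print chunk.choices.first.delta.content.to_s
end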
data/lib/mistral/client_base.rb
ADDED
@@ -0,0 +1,126 @@
# frozen_string_literal: true

module Mistral
  class ClientBase
    attr_reader :endpoint, :api_key, :max_retries, :timeout

    def initialize(endpoint:, api_key: nil, max_retries: 5, timeout: 120)
      @endpoint = endpoint
      @api_key = api_key
      @max_retries = max_retries
      @timeout = timeout

      @logger = config_logger

      # For azure endpoints, we default to the mistral model
      @default_model = 'mistral' if endpoint.include?('inference.azure.com')
    end

    protected

    def parse_tools(tools)
      parsed_tools = []

      tools.each do |tool|
        next unless tool['type'] == 'function'

        parsed_function = {}
        parsed_function['type'] = tool['type']
        parsed_function['function'] = if tool['function'].is_a?(Function)
                                        tool['function'].to_h
                                      else
                                        tool['function']
                                      end

        parsed_tools << parsed_function
      end

      parsed_tools
    end

    def parse_tool_choice(tool_choice)
      tool_choice.is_a?(ToolChoice) ? tool_choice.to_s : tool_choice
    end

    def parse_response_format(response_format)
      if response_format.is_a?(ResponseFormat)
        response_format.to_h
      else
        response_format
      end
    end

    def parse_messages(messages)
      parsed_messages = []

      messages.each do |message|
        parsed_messages << if message.is_a?(ChatMessage)
                             message.to_h
                           else
                             message
                           end
      end

      parsed_messages
    end

    def make_chat_request(
      messages:,
      model: nil,
      tools: nil,
      temperature: nil,
      max_tokens: nil,
      top_p: nil,
      random_seed: nil,
      stream: nil,
      safe_prompt: false,
      tool_choice: nil,
      response_format: nil
    )
      request_data = {
        messages: parse_messages(messages),
        safe_prompt: safe_prompt
      }

      request_data[:model] = if model.nil?
                               raise Mistral::Error.new(message: 'model must be provided') if @default_model.nil?

                               @default_model
                             else
                               model
                             end

      request_data[:tools] = parse_tools(tools) unless tools.nil?
      request_data[:temperature] = temperature unless temperature.nil?
      request_data[:max_tokens] = max_tokens unless max_tokens.nil?
      request_data[:top_p] = top_p unless top_p.nil?
      request_data[:random_seed] = random_seed unless random_seed.nil?
      request_data[:stream] = stream unless stream.nil?
      request_data[:tool_choice] = parse_tool_choice(tool_choice) unless tool_choice.nil?
      request_data[:response_format] = parse_response_format(response_format) unless response_format.nil?

      @logger.debug("Chat request: #{request_data}")

      request_data
    end

    def process_line(line)
      return unless line.start_with?('data: ')

      line = line[6..].to_s.strip
      return if line == '[DONE]'

      JSON.parse(line)
    end

    def config_logger
      Logger.new($stdout).tap do |logger|
        logger.level = ENV.fetch('LOG_LEVEL', 'ERROR')

        logger.formatter = proc do |severity, datetime, progname, msg|
          "#{datetime.strftime("%Y-%m-%d %H:%M:%S")} #{severity} #{progname}: #{msg}\n"
        end
      end
    end
  end
end
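To make the SSE framing concrete, here is a hedged sketch of the three cases process_line distinguishes. The method is protected, so send is used below purely for illustration:

base = Mistral::ClientBase.new(endpoint: 'https://api.mistral.ai')

base.send(:process_line, ': keep-alive')       # => nil (not a "data: " line)
base.send(:process_line, 'data: [DONE]')       # => nil (end-of-stream sentinel)
base.send(:process_line, 'data: {"id":"abc"}') # => {"id"=>"abc"} (parsed JSON chunk)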
data/lib/mistral/exceptions.rb
ADDED
@@ -0,0 +1,38 @@
# frozen_string_literal: true

module Mistral
  # Base error class, returned when nothing more specific applies
  class Error < StandardError
  end

  class APIError < Error
    attr_reader :http_status, :headers

    def initialize(message: nil, http_status: nil, headers: nil)
      super(message: message)

      @http_status = http_status
      @headers = headers
    end

    def self.from_response(response, message: nil)
      new(
        message: message || response.to_s,
        http_status: response.code,
        headers: response.headers.to_h
      )
    end

    def to_s
      "#{self.class.name}(message=#{super}, http_status=#{http_status})"
    end
  end

  # Returned when we receive a non-200 response from the API that we should retry
  class APIStatusError < APIError
  end

  # Returned when the SDK can not reach the API server for any reason
  class ConnectionError < Error
  end
end
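A sketch of how this hierarchy is meant to be consumed. APIStatusError is retried inside Client#request, so application code mostly sees the other classes; client below is assumed to be a Mistral::Client:

begin
  client.chat(model: 'mistral-tiny', messages: [{ role: 'user', content: 'Hi' }])
rescue Mistral::ConnectionError => e
  warn "network failure: #{e.message}"
rescue Mistral::APIError => e
  warn "API error (HTTP #{e.http_status}): #{e.message}"
rescue Mistral::Error => e
  warn "unexpected client error: #{e.message}"
end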
data/lib/mistral/models/chat_completion.rb
ADDED
@@ -0,0 +1,95 @@
# frozen_string_literal: true

module Mistral
  class Function < Dry::Struct
    transform_keys(&:to_sym)

    attribute :name, Types::Strict::String
    attribute :description, Types::Strict::String
    attribute :parameters, Types::Strict::Hash
  end

  ToolType = Types::Strict::String.default('function').enum('function')

  class FunctionCall < Dry::Struct
    transform_keys(&:to_sym)

    attribute :name, Types::Strict::String
    attribute :arguments, Types::Strict::String
  end

  class ToolCall < Dry::Struct
    transform_keys(&:to_sym)

    attribute :id, Types::Strict::String.default('null')
    attribute :type, ToolType
    attribute :function, FunctionCall
  end

  ResponseFormats = Types::Strict::String.default('text').enum('text', 'json_object')

  ToolChoice = Types::Strict::String.enum('auto', 'any', 'none')

  class ResponseFormat < Dry::Struct
    transform_keys(&:to_sym)

    attribute :type, ResponseFormats
  end

  class ChatMessage < Dry::Struct
    transform_keys(&:to_sym)

    attribute :role, Types::Strict::String
    attribute :content, Types::Strict::Array.of(Types::Strict::String) | Types::Strict::String
    attribute? :name, Types::String.optional
    attribute? :tool_calls, Types::Strict::Array.of(ToolCall).optional
  end

  class DeltaMessage < Dry::Struct
    transform_keys(&:to_sym)

    attribute? :role, Types::Strict::String.optional
    attribute? :content, Types::Strict::String.optional
    attribute? :tool_calls, Types::Strict::Array.of(ToolCall).optional
  end

  FinishReason = Types::Strict::String.enum('stop', 'length', 'error', 'tool_calls')

  class ChatCompletionResponseStreamChoice < Dry::Struct
    transform_keys(&:to_sym)

    attribute :index, Types::Strict::Integer
    attribute :delta, DeltaMessage
    attribute? :finish_reason, FinishReason.optional
  end

  class ChatCompletionStreamResponse < Dry::Struct
    transform_keys(&:to_sym)

    attribute :id, Types::Strict::String
    attribute :model, Types::Strict::String
    attribute :choices, Types::Strict::Array.of(ChatCompletionResponseStreamChoice)
    attribute? :created, Types::Strict::Integer.optional
    attribute? :object, Types::Strict::String.optional
    attribute? :usage, UsageInfo.optional
  end

  class ChatCompletionResponseChoice < Dry::Struct
    transform_keys(&:to_sym)

    attribute :index, Types::Strict::Integer
    attribute :message, ChatMessage
    attribute? :finish_reason, FinishReason.optional
  end

  class ChatCompletionResponse < Dry::Struct
    transform_keys(&:to_sym)

    attribute :id, Types::Strict::String
    attribute :object, Types::Strict::String
    attribute :created, Types::Strict::Integer
    attribute :model, Types::Strict::String
    attribute :choices, Types::Strict::Array.of(ChatCompletionResponseChoice)
    attribute :usage, UsageInfo
  end
end
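Because these are Dry::Structs with transform_keys(&:to_sym), raw JSON hashes (string keys included) coerce directly into typed objects — a small sketch with made-up values:

message = Mistral::ChatMessage.new(
  'role' => 'assistant',               # string keys are symbolized by transform_keys
  'content' => 'Camembert, of course.'
)
message.role       # => "assistant"
message.tool_calls # => nil (omitted attribute? keys read as nil)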
data/lib/mistral/models/common.rb
ADDED
@@ -0,0 +1,11 @@
# frozen_string_literal: true

module Mistral
  class UsageInfo < Dry::Struct
    transform_keys(&:to_sym)

    attribute :prompt_tokens, Types::Strict::Integer
    attribute :total_tokens, Types::Strict::Integer
    attribute? :completion_tokens, Types::Strict::Integer.optional
  end
end
data/lib/mistral/models/embeddings.rb
ADDED
@@ -0,0 +1,21 @@
# frozen_string_literal: true

module Mistral
  class EmbeddingObject < Dry::Struct
    transform_keys(&:to_sym)

    attribute :object, Types::Strict::String
    attribute :embedding, Types::Strict::Array.of(Types::Strict::Float)
    attribute :index, Types::Strict::Integer
  end

  class EmbeddingResponse < Dry::Struct
    transform_keys(&:to_sym)

    attribute :id, Types::Strict::String
    attribute :object, Types::Strict::String
    attribute :data, Types::Strict::Array.of(EmbeddingObject)
    attribute :model, Types::Strict::String
    attribute :usage, UsageInfo
  end
end
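A usage sketch in the spirit of the bundled data/examples/embeddings.rb (not copied from it); mistral-embed is the model name given in Client#embeddings' doc comments:

response = client.embeddings(
  model: 'mistral-embed',
  input: ['What is the best French cheese?']
)
vector = response.data.first.embedding # Array of Floats
puts "#{response.model}: #{vector.length} dimensions, #{response.usage.total_tokens} tokens"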
data/lib/mistral/models/models.rb
ADDED
@@ -0,0 +1,39 @@
# frozen_string_literal: true

module Mistral
  class ModelPermission < Dry::Struct
    transform_keys(&:to_sym)

    attribute :id, Types::Strict::String
    attribute :object, Types::Strict::String
    attribute :created, Types::Strict::Integer
    attribute :allow_create_engine, Types::Strict::Bool.default(false)
    attribute :allow_sampling, Types::Strict::Bool.default(true)
    attribute :allow_logprobs, Types::Strict::Bool.default(true)
    attribute :allow_search_indices, Types::Strict::Bool.default(false)
    attribute :allow_view, Types::Strict::Bool.default(true)
    attribute :allow_fine_tuning, Types::Strict::Bool.default(false)
    attribute :organization, Types::Strict::String.default('*')
    attribute? :group, Types::Strict::String.optional
    attribute :is_blocking, Types::Strict::Bool.default(false)
  end

  class ModelCard < Dry::Struct
    transform_keys(&:to_sym)

    attribute :id, Types::Strict::String
    attribute :object, Types::Strict::String
    attribute :created, Types::Strict::Integer
    attribute :owned_by, Types::Strict::String
    attribute? :root, Types::Strict::String.optional
    attribute? :parent, Types::Strict::String.optional
    attribute :permission, Types::Strict::Array.of(ModelPermission).default([].freeze)
  end

  class ModelList < Dry::Struct
    transform_keys(&:to_sym)

    attribute :object, Types::Strict::String
    attribute :data, Types::Strict::Array.of(ModelCard)
  end
end
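And a matching sketch for the model listing (cf. the bundled data/examples/list_models.rb):

client.list_models.data.each do |model|
  puts "#{model.id} (owned by #{model.owned_by})"
end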
data/lib/mistral.rb
ADDED
@@ -0,0 +1,24 @@
# frozen_string_literal: true

require 'dry-struct'
require 'http'
require 'json'
require 'logger'
require 'time'

require 'http/features/line_iterable_body'

module Mistral
  module Types
    include Dry.Types()
  end
end

require 'mistral/constants'
require 'mistral/exceptions'
require 'mistral/models/models'
require 'mistral/models/common'
require 'mistral/models/embeddings'
require 'mistral/models/chat_completion'
require 'mistral/version'
require 'mistral/client'