rasti-ai 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.github/workflows/ci.yml +44 -0
- data/.gitignore +10 -0
- data/.ruby-gemset +1 -0
- data/.ruby-version +1 -0
- data/Gemfile +4 -0
- data/LICENSE.txt +21 -0
- data/README.md +107 -0
- data/Rakefile +25 -0
- data/lib/rasti/ai/errors.rb +39 -0
- data/lib/rasti/ai/open_ai/assistant.rb +96 -0
- data/lib/rasti/ai/open_ai/assistant_state.rb +27 -0
- data/lib/rasti/ai/open_ai/client.rb +57 -0
- data/lib/rasti/ai/open_ai/roles.rb +14 -0
- data/lib/rasti/ai/open_ai/tool_serializer.rb +111 -0
- data/lib/rasti/ai/version.rb +5 -0
- data/lib/rasti/ai.rb +24 -0
- data/lib/rasti-ai.rb +1 -0
- data/rasti-ai.gemspec +33 -0
- data/spec/coverage_helper.rb +2 -0
- data/spec/minitest_helper.rb +26 -0
- data/spec/open_ai/assistant_spec.rb +281 -0
- data/spec/open_ai/client_spec.rb +156 -0
- data/spec/open_ai/tool_serializer_spec.rb +297 -0
- data/spec/resources/open_ai/basic_request.json +1 -0
- data/spec/resources/open_ai/basic_response.json +36 -0
- data/spec/resources/open_ai/tool_request.json +1 -0
- data/spec/resources/open_ai/tool_response.json +46 -0
- data/spec/support/helpers/erb.rb +17 -0
- data/spec/support/helpers/resources.rb +23 -0
- metadata +243 -0
@@ -0,0 +1,281 @@
|
|
1
|
+
require 'minitest_helper'
|
2
|
+
|
3
|
+
describe Rasti::AI::OpenAI::Assistant do
|
4
|
+
|
5
|
+
# Chat completions endpoint that the examples below stub with stub_request.
let(:api_url) { 'https://api.openai.com/v1/chat/completions' }
|
6
|
+
|
7
|
+
# Prompt passed to Assistant#call in the examples of this spec.
let(:question) { 'How many goals has Messi scored for Barca?' }
|
8
|
+
|
9
|
+
# Reply text rendered into the canned basic_response.json resource.
let(:answer) { 'Lionel Messi scored 672 goals in 778 official matches for FC Barcelona.' }
|
10
|
+
|
11
|
+
# Registers an HTTP stub for a POST to api_url.
#
# The expected request body is rendered from open_ai/basic_request.json with
# the given model (Rasti::AI.openai_default_model when omitted) and question;
# the stubbed response body is open_ai/basic_response.json rendered with the
# given answer.
def stub_open_ai_chat_completions(model:nil, question:, answer:)
  expected_body = read_resource 'open_ai/basic_request.json',
                                model: model || Rasti::AI.openai_default_model,
                                prompt: question
  canned_body = read_resource 'open_ai/basic_response.json', content: answer

  stub_request(:post, api_url).with(body: expected_body).to_return(body: canned_body)
end
|
18
|
+
|
19
|
+
|
20
|
+
# An assistant built without arguments answers with the content of the
# stubbed chat completion.
it 'Default' do
  stub_open_ai_chat_completions question: question, answer: answer

  reply = Rasti::AI::OpenAI::Assistant.new.call(question)

  assert_equal answer, reply
end
|
29
|
+
|
30
|
+
describe 'Customized' do
|
31
|
+
|
32
|
+
# An injected client receives a single chat_completions call carrying a nil
# model, an empty tool list and the user question as the only message.
it 'Client' do
  expected_params = {
    model: nil,
    tools: [],
    messages: [
      {role: Rasti::AI::OpenAI::Roles::USER, content: question}
    ]
  }

  client = Minitest::Mock.new
  client.expect :chat_completions,
                read_json_resource('open_ai/basic_response.json', content: answer),
                [expected_params]

  assistant = Rasti::AI::OpenAI::Assistant.new client: client

  assert_equal answer, assistant.call(question)

  client.verify
end
|
59
|
+
|
60
|
+
# With a custom AssistantState the request carries the state's context as a
# SYSTEM message before the USER question, and the assistant reply is
# appended to the state as its third message.
it 'State' do
  system_prompt = 'Act as sports journalist'
  state = Rasti::AI::OpenAI::AssistantState.new context: system_prompt

  # Key order matters here: the hash is dumped to the JSON string the stub matches on.
  payload = {
    model: Rasti::AI.openai_default_model,
    messages: [
      {role: Rasti::AI::OpenAI::Roles::SYSTEM, content: system_prompt},
      {role: Rasti::AI::OpenAI::Roles::USER, content: question}
    ],
    tools: [],
    tool_choice: 'none'
  }

  stub_request(:post, api_url)
    .with(body: JSON.dump(payload))
    .to_return(body: read_resource('open_ai/basic_response.json', content: answer))

  assistant = Rasti::AI::OpenAI::Assistant.new state: state

  reply = assistant.call question

  assert_equal answer, reply
  assert_equal 3, state.messages.count
  assert_equal({role: Rasti::AI::OpenAI::Roles::ASSISTANT, content: answer}, state.messages.last)
end
|
97
|
+
|
98
|
+
# A model given to the assistant is the one sent in the request body.
it 'Model' do
  custom_model = SecureRandom.uuid

  stub_open_ai_chat_completions model: custom_model, question: question, answer: answer

  assistant = Rasti::AI::OpenAI::Assistant.new model: custom_model

  assert_equal answer, assistant.call(question)
end
|
109
|
+
|
110
|
+
end
|
111
|
+
|
112
|
+
describe 'Tools' do
|
113
|
+
|
114
|
+
# Fake tool used by the 'Tools' examples.
#
# It exposes a form with string attributes player and team, and its #call
# always answers '672' regardless of the params it receives.
class GoalsByPlayer
  class << self
    # Form class describing the parameters accepted by the tool.
    def form
      Rasti::Form[player: Rasti::Types::String, team: Rasti::Types::String]
    end
  end

  # Ignores params and returns a fixed result.
  def call(params={})
    '672'
  end
end
|
123
|
+
|
124
|
+
# Mocked client shared by the examples below; stub_client_request queues its expectations.
let(:client) { Minitest::Mock.new }
|
125
|
+
|
126
|
+
# Canned response rendered from open_ai/tool_response.json, naming the
# goals_by_player tool with player and team arguments.
let(:tool_response) do
  tool_arguments = {player: 'Lionel Messi', team: 'Barcelona'}

  read_json_resource 'open_ai/tool_response.json', name: 'goals_by_player', arguments: tool_arguments
end
|
136
|
+
|
137
|
+
# Same value GoalsByPlayer#call returns; expected as the content of the TOOL message.
let(:tool_result) { '672' }
|
138
|
+
|
139
|
+
# Canned final reply used by the 'Tool failure' and 'Undefined tool' examples.
let(:error_message) { 'There was an error using a tool' }
|
140
|
+
|
141
|
+
# Parsed open_ai/basic_response.json resource rendered with the given content.
def basic_response(content)
  read_json_resource 'open_ai/basic_response.json', content: content
end
|
147
|
+
|
148
|
+
# Queues one chat_completions expectation on the mocked client, answering
# with response.
#
# The expectation is satisfied only when the last message of the request has
# the given role and content, and the request tools equal the serialization
# (via ToolSerializer) of each given tool's class.
def stub_client_request(role:, content:, response:, tools:[])
  client.expect :chat_completions, response do |request|
    message = request[:messages].last
    next false unless message[:role] == role
    next false unless message[:content] == content

    serialized_tools = tools.map { |tool| Rasti::AI::OpenAI::ToolSerializer.serialize tool.class }
    request[:tools] == serialized_tools
  end
end
|
156
|
+
|
157
|
+
# When the model asks for a tool, the assistant runs it, sends the result
# back as a TOOL message and returns the final answer.
# (Example description fixed: it previously read 'Call funcion'.)
it 'Call function' do
  tool = GoalsByPlayer.new

  # First round trip: the user question is answered with a tool call.
  stub_client_request role: Rasti::AI::OpenAI::Roles::USER,
                      content: question,
                      tools: [tool],
                      response: tool_response

  # Second round trip: the tool result is answered with the final content.
  stub_client_request role: Rasti::AI::OpenAI::Roles::TOOL,
                      content: tool_result,
                      tools: [tool],
                      response: basic_response(answer)

  assistant = Rasti::AI::OpenAI::Assistant.new client: client, tools: [tool]

  response = assistant.call question

  assert_equal answer, response

  client.verify
end
|
178
|
+
|
179
|
+
# A tool that raises is reported back to the model as a TOOL message with
# content 'Error: Broken tool'; the assistant returns the model's reply.
it 'Tool failure' do
  broken_tool = GoalsByPlayer.new
  broken_tool.define_singleton_method(:call) { |*_args| raise 'Broken tool' }

  stub_client_request role: Rasti::AI::OpenAI::Roles::USER,
                      content: question,
                      tools: [broken_tool],
                      response: tool_response

  stub_client_request role: Rasti::AI::OpenAI::Roles::TOOL,
                      content: 'Error: Broken tool',
                      tools: [broken_tool],
                      response: basic_response(error_message)

  assistant = Rasti::AI::OpenAI::Assistant.new client: client, tools: [broken_tool]

  assert_equal error_message, assistant.call(question)

  client.verify
end
|
203
|
+
|
204
|
+
# A tool call naming a tool the assistant was not given is reported back as
# 'Error: Undefined tool goals_by_player'.
it 'Undefined tool' do
  stub_client_request role: Rasti::AI::OpenAI::Roles::USER,
                      content: question,
                      response: tool_response

  stub_client_request role: Rasti::AI::OpenAI::Roles::TOOL,
                      content: 'Error: Undefined tool goals_by_player',
                      response: basic_response(error_message)

  assistant = Rasti::AI::OpenAI::Assistant.new client: client, tools: []

  assert_equal error_message, assistant.call(question)

  client.verify
end
|
221
|
+
|
222
|
+
# The same question is asked five times, but the tool mock expects exactly
# one call with the string-keyed arguments: later rounds must be served
# without invoking the tool again.
it 'Cached result' do
  mock = Minitest::Mock.new
  mock.expect :call, tool_result, [{'player' => 'Lionel Messi', 'team' => 'Barcelona'}]

  # Route the tool's #call through the mock so invocations are counted.
  tool = GoalsByPlayer.new
  tool.define_singleton_method :call do |*args|
    mock.call(*args)
  end

  assistant = Rasti::AI::OpenAI::Assistant.new client: client, tools: [tool]

  5.times do
    stub_client_request role: Rasti::AI::OpenAI::Roles::USER,
                        content: question,
                        tools: [tool],
                        response: tool_response

    stub_client_request role: Rasti::AI::OpenAI::Roles::TOOL,
                        content: tool_result,
                        tools: [tool],
                        response: basic_response(answer)

    response = assistant.call question

    assert_equal answer, response
  end

  client.verify
  # Also verify the tool mock: its expectation was previously set but never checked.
  mock.verify
end
|
251
|
+
|
252
|
+
# A logger handed to the assistant receives output during a tool round trip.
it 'Custom logger' do
  log_io = StringIO.new
  tool = GoalsByPlayer.new

  stub_client_request role: Rasti::AI::OpenAI::Roles::USER,
                      content: question,
                      tools: [tool],
                      response: tool_response

  stub_client_request role: Rasti::AI::OpenAI::Roles::TOOL,
                      content: tool_result,
                      tools: [tool],
                      response: basic_response(answer)

  assistant = Rasti::AI::OpenAI::Assistant.new client: client,
                                               tools: [tool],
                                               logger: Logger.new(log_io)

  assert_equal answer, assistant.call(question)

  refute_empty log_io.string

  client.verify
end
|
278
|
+
|
279
|
+
end
|
280
|
+
|
281
|
+
end
|
@@ -0,0 +1,156 @@
|
|
1
|
+
require 'minitest_helper'
|
2
|
+
|
3
|
+
describe Rasti::AI::OpenAI::Client do
|
4
|
+
|
5
|
+
# Chat completions endpoint that the examples below stub with stub_request.
let(:api_url) { 'https://api.openai.com/v1/chat/completions' }
|
6
|
+
|
7
|
+
# Builds a chat message hash with the USER role and the given content.
def user_message(content)
  {role: Rasti::AI::OpenAI::Roles::USER, content: content}
end
|
13
|
+
|
14
|
+
describe 'Basic message' do
|
15
|
+
|
16
|
+
# Prompt rendered into the expected request body of the 'Basic message' examples.
let(:question) { 'who is Messi?' }
|
17
|
+
|
18
|
+
# Content rendered into the stubbed response of the 'Basic message' examples.
let(:answer) { 'Lionel Messi is the best player ever' }
|
19
|
+
|
20
|
+
# Registers an HTTP stub for a POST to api_url that must carry a Bearer
# Authorization header and the basic_request.json body rendered with the
# model and question.
#
# api_key and model fall back to Rasti::AI.openai_api_key and
# Rasti::AI.openai_default_model when omitted. The stub answers with
# basic_response.json rendered with answer.
def stub_open_ai_chat_completions(api_key:nil, model:nil)
  token = api_key || Rasti::AI.openai_api_key
  model_name = model || Rasti::AI.openai_default_model

  expected_body = read_resource 'open_ai/basic_request.json', model: model_name, prompt: question
  canned_body = read_resource 'open_ai/basic_response.json', content: answer

  stub_request(:post, api_url)
    .with(headers: {'Authorization' => "Bearer #{token}"}, body: expected_body)
    .to_return(body: canned_body)
end
|
31
|
+
|
32
|
+
# Asserts that the first choice's message content in response equals
# expected_content.
def assert_response_content(response, expected_content)
  actual_content = response.dig 'choices', 0, 'message', 'content'

  assert_equal expected_content, actual_content
end
|
35
|
+
|
36
|
+
# A client built without arguments sends the default API key and model.
it 'Default API key, model and logger' do
  stub_open_ai_chat_completions

  client = Rasti::AI::OpenAI::Client.new
  completion = client.chat_completions messages: [user_message(question)]

  assert_response_content completion, answer
end
|
45
|
+
|
46
|
+
# An API key given to the client is the one sent in the Authorization header.
it 'Custom API key' do
  api_key = SecureRandom.uuid

  stub_open_ai_chat_completions api_key: api_key

  client = Rasti::AI::OpenAI::Client.new api_key: api_key
  completion = client.chat_completions messages: [user_message(question)]

  assert_response_content completion, answer
end
|
57
|
+
|
58
|
+
# A model passed to chat_completions is the one sent in the request body.
it 'Custom model' do
  model_name = SecureRandom.uuid

  stub_open_ai_chat_completions model: model_name

  client = Rasti::AI::OpenAI::Client.new
  completion = client.chat_completions messages: [user_message(question)], model: model_name

  assert_response_content completion, answer
end
|
70
|
+
|
71
|
+
# A logger handed to the client receives output during a request.
it 'Custom logger' do
  log_io = StringIO.new

  stub_open_ai_chat_completions

  client = Rasti::AI::OpenAI::Client.new logger: Logger.new(log_io)
  completion = client.chat_completions messages: [user_message(question)]

  assert_response_content completion, answer

  refute_empty log_io.string
end
|
85
|
+
|
86
|
+
end
|
87
|
+
|
88
|
+
# A 400 response makes the client raise RequestFail with a message that
# includes 'Response: 400'.
it 'Request error' do
  error_body = '{"error": {"message": "Test error"}}'
  stub_request(:post, api_url).to_return(status: 400, body: error_body)

  client = Rasti::AI::OpenAI::Client.new

  failure = assert_raises(Rasti::AI::Errors::RequestFail) do
    client.chat_completions messages: ['invalid message']
  end

  assert_includes failure.message, 'Response: 400'
end
|
100
|
+
|
101
|
+
# A request carrying a tool definition is sent as rendered by
# tool_request.json, and the tool call in the stubbed response (name plus
# JSON-encoded arguments) is readable from the returned hash.
it 'Tool call' do
  prompt = 'how many goals did messi for barca'
  tool_name = 'player_goals'

  # Key order is kept as declared: the definition is rendered into the expected request body.
  tool_definition = {
    type: 'function',
    function: {
      name: tool_name,
      description: 'Gets the number of goals scored by a player for a specific team',
      parameters: {
        type: 'object',
        properties: {
          name: {type: 'string', description: 'Full name of the player'},
          team: {type: 'string', description: 'Name of the team the player was part of'}
        },
        required: ['name', 'team']
      }
    }
  }

  tool_arguments = {name: 'Lionel Messi', team: 'FC Barcelona'}

  expected_body = read_resource 'open_ai/tool_request.json',
                                model: Rasti::AI.openai_default_model,
                                prompt: prompt,
                                tools: [tool_definition]

  stub_request(:post, api_url)
    .with(headers: {'Authorization' => "Bearer #{Rasti::AI.openai_api_key}"}, body: expected_body)
    .to_return(body: read_resource('open_ai/tool_response.json', name: tool_name, arguments: tool_arguments))

  client = Rasti::AI::OpenAI::Client.new
  completion = client.chat_completions messages: [user_message(prompt)], tools: [tool_definition]

  called_function = completion.dig 'choices', 0, 'message', 'tool_calls', 0, 'function'

  assert_equal tool_name, called_function['name']
  assert_equal JSON.dump(tool_arguments), called_function['arguments']
end
|
155
|
+
|
156
|
+
end
|