rasti-ai 1.2.1 → 2.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +123 -42
- data/lib/rasti/ai/assistant.rb +161 -0
- data/lib/rasti/ai/assistant_state.rb +24 -0
- data/lib/rasti/ai/client.rb +81 -0
- data/lib/rasti/ai/gemini/assistant.rb +112 -0
- data/lib/rasti/ai/gemini/client.rb +35 -0
- data/lib/rasti/ai/gemini/roles.rb +13 -0
- data/lib/rasti/ai/open_ai/assistant.rb +57 -93
- data/lib/rasti/ai/open_ai/client.rb +11 -56
- data/lib/rasti/ai/usage.rb +14 -0
- data/lib/rasti/ai/version.rb +1 -1
- data/lib/rasti/ai.rb +6 -0
- data/rasti-ai.gemspec +1 -0
- data/spec/gemini/assistant_spec.rb +384 -0
- data/spec/gemini/client_spec.rb +155 -0
- data/spec/minitest_helper.rb +13 -0
- data/spec/open_ai/assistant_spec.rb +68 -10
- data/spec/resources/gemini/basic_request.json +1 -0
- data/spec/resources/gemini/basic_response.json +22 -0
- data/spec/resources/gemini/tool_request.json +1 -0
- data/spec/resources/gemini/tool_response.json +25 -0
- metadata +35 -3
- data/lib/rasti/ai/open_ai/assistant_state.rb +0 -27
checksums.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
---
|
|
2
2
|
SHA256:
|
|
3
|
-
metadata.gz:
|
|
4
|
-
data.tar.gz:
|
|
3
|
+
metadata.gz: dfad832d41e30e53127315f47de79582f43316453accba99c261bc0cf158902a
|
|
4
|
+
data.tar.gz: 0dfb5550e2e5732af21b49e9317ad844b1d962c860fee711c72e4b2e1c83a55d
|
|
5
5
|
SHA512:
|
|
6
|
-
metadata.gz:
|
|
7
|
-
data.tar.gz:
|
|
6
|
+
metadata.gz: 30031fc6f74b996c9d39b72dc12027dc7ffa765e8227ea9d252dc46d92a951a2a29b06bf8ddebe33d75a17c4b71827992a078d4dc3034501f5254583f1e080b8
|
|
7
|
+
data.tar.gz: ebc68bbfeca92dd5f929ea6bb93bcb5077faca18a592ddc1e15e8f500e154c666d8cfe0de2828108d4c8d84f7640203b125c8836cd61a78c3dbad4a9b721eb8a
|
data/README.md
CHANGED
|
@@ -27,20 +27,44 @@ Or install it yourself as:
|
|
|
27
27
|
```ruby
|
|
28
28
|
Rasti::AI.configure do |config|
|
|
29
29
|
config.logger = Logger.new 'log/development.log'
|
|
30
|
+
|
|
31
|
+
# HTTP settings
|
|
32
|
+
config.http_connect_timeout = 60 # Default 60 seconds
|
|
33
|
+
config.http_read_timeout = 60 # Default 60 seconds
|
|
34
|
+
config.http_max_retries = 3 # Default 3 retries
|
|
35
|
+
|
|
36
|
+
# OpenAI
|
|
30
37
|
config.openai_api_key = 'abcd12345' # Default ENV['OPENAI_API_KEY']
|
|
31
38
|
config.openai_default_model = 'gpt-4o-mini' # Default ENV['OPENAI_DEFAULT_MODEL']
|
|
39
|
+
|
|
40
|
+
# Gemini
|
|
41
|
+
config.gemini_api_key = 'AIza12345' # Default ENV['GEMINI_API_KEY']
|
|
42
|
+
config.gemini_default_model = 'gemini-2.0-flash' # Default ENV['GEMINI_DEFAULT_MODEL']
|
|
43
|
+
|
|
44
|
+
# Usage tracking
|
|
45
|
+
config.usage_tracker = ->(usage) { puts "#{usage.provider}: #{usage.input_tokens} in / #{usage.output_tokens} out" }
|
|
32
46
|
end
|
|
33
47
|
```
|
|
34
48
|
|
|
35
|
-
###
|
|
49
|
+
### Supported providers
|
|
50
|
+
|
|
51
|
+
- **OpenAI** - `Rasti::AI::OpenAI::Assistant`
|
|
52
|
+
- **Gemini** - `Rasti::AI::Gemini::Assistant`
|
|
53
|
+
|
|
54
|
+
All providers share the same interface. The examples below use OpenAI, but apply equally to Gemini by replacing `OpenAI` with `Gemini`.
|
|
55
|
+
|
|
56
|
+
### Assistant
|
|
36
57
|
|
|
37
|
-
#### Assistant
|
|
38
58
|
```ruby
|
|
39
59
|
assistant = Rasti::AI::OpenAI::Assistant.new
|
|
40
60
|
assistant.call 'who is the best player' # => 'The best player is Lionel Messi'
|
|
41
61
|
```
|
|
42
62
|
|
|
43
|
-
|
|
63
|
+
### Tools
|
|
64
|
+
|
|
65
|
+
Tools can be simple classes or inherit from `Rasti::AI::Tool`. Both approaches work with any provider.
|
|
66
|
+
|
|
67
|
+
#### Simple tools
|
|
44
68
|
```ruby
|
|
45
69
|
class GetCurrentTime
|
|
46
70
|
def call(params={})
|
|
@@ -54,11 +78,41 @@ class GetCurrentWeather
|
|
|
54
78
|
end
|
|
55
79
|
|
|
56
80
|
def call(params={})
|
|
57
|
-
|
|
58
|
-
response.body.to_s
|
|
81
|
+
"The wheather in #{params['location']} is sunny"
|
|
59
82
|
end
|
|
60
83
|
end
|
|
84
|
+
```
|
|
85
|
+
|
|
86
|
+
#### Tools inheriting from Rasti::AI::Tool
|
|
87
|
+
```ruby
|
|
88
|
+
class SumTool < Rasti::AI::Tool
|
|
89
|
+
class Form < Rasti::Form
|
|
90
|
+
attribute :number_a, Rasti::Types::Float, required: true, description: 'First number'
|
|
91
|
+
attribute :number_b, Rasti::Types::Float, required: true, description: 'Second number'
|
|
92
|
+
end
|
|
61
93
|
|
|
94
|
+
def self.description
|
|
95
|
+
'Sum two numbers'
|
|
96
|
+
end
|
|
97
|
+
|
|
98
|
+
def execute(form)
|
|
99
|
+
{result: form.number_a + form.number_b}
|
|
100
|
+
end
|
|
101
|
+
end
|
|
102
|
+
```
|
|
103
|
+
|
|
104
|
+
Supported form attribute types:
|
|
105
|
+
- `Rasti::Types::String` → `string`
|
|
106
|
+
- `Rasti::Types::Integer` → `integer`
|
|
107
|
+
- `Rasti::Types::Float` → `number`
|
|
108
|
+
- `Rasti::Types::Boolean` → `boolean`
|
|
109
|
+
- `Rasti::Types::Time` → `string (date)`
|
|
110
|
+
- `Rasti::Types::Enum[:a, :b]` → `string (enum)`
|
|
111
|
+
- `Rasti::Types::Array[Type]` → `array`
|
|
112
|
+
- `Rasti::Types::Model[FormClass]` → nested `object`
|
|
113
|
+
|
|
114
|
+
#### Using tools with an assistant
|
|
115
|
+
```ruby
|
|
62
116
|
tools = [
|
|
63
117
|
GetCurrentTime.new,
|
|
64
118
|
GetCurrentWeather.new
|
|
@@ -71,29 +125,72 @@ assistant.call 'what time is it' # => 'The current time is 3:03 PM on April 28,
|
|
|
71
125
|
assistant.call 'what is the weather in Buenos Aires' # => 'In Buenos Aires it is 15 degrees'
|
|
72
126
|
```
|
|
73
127
|
|
|
74
|
-
|
|
128
|
+
### Context and state
|
|
75
129
|
```ruby
|
|
76
|
-
state = Rasti::AI::
|
|
130
|
+
state = Rasti::AI::AssistantState.new context: 'Act as sports journalist'
|
|
77
131
|
|
|
78
132
|
assistant = Rasti::AI::OpenAI::Assistant.new state: state
|
|
79
133
|
|
|
80
134
|
assistant.call 'who is the best player'
|
|
81
135
|
|
|
82
|
-
state.
|
|
83
|
-
#
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
#
|
|
136
|
+
state.context # => 'Act as sports journalist'
|
|
137
|
+
state.messages # Array of provider-specific message hashes
|
|
138
|
+
```
|
|
139
|
+
|
|
140
|
+
The state keeps the conversation history, enabling multi-turn interactions. It also caches tool call results to avoid duplicate executions.
|
|
141
|
+
|
|
142
|
+
### Structured responses (JSON Schema)
|
|
143
|
+
```ruby
|
|
144
|
+
assistant = Rasti::AI::OpenAI::Assistant.new json_schema: {
|
|
145
|
+
player: 'string',
|
|
146
|
+
sport: 'string'
|
|
147
|
+
}
|
|
148
|
+
|
|
149
|
+
response = assistant.call 'who is the best player'
|
|
150
|
+
JSON.parse response # => {"player" => "Lionel Messi", "sport" => "Football"}
|
|
151
|
+
```
|
|
152
|
+
|
|
153
|
+
### Custom model and client
|
|
154
|
+
```ruby
|
|
155
|
+
# Override model
|
|
156
|
+
assistant = Rasti::AI::OpenAI::Assistant.new model: 'gpt-4o'
|
|
157
|
+
|
|
158
|
+
# Custom client with per-client HTTP settings
|
|
159
|
+
client = Rasti::AI::OpenAI::Client.new(
|
|
160
|
+
http_connect_timeout: 120,
|
|
161
|
+
http_read_timeout: 120,
|
|
162
|
+
http_max_retries: 5
|
|
163
|
+
)
|
|
164
|
+
|
|
165
|
+
assistant = Rasti::AI::OpenAI::Assistant.new client: client
|
|
166
|
+
```
|
|
167
|
+
|
|
168
|
+
### Usage tracking
|
|
169
|
+
|
|
170
|
+
Track token consumption across API calls (including tool calls):
|
|
171
|
+
|
|
172
|
+
```ruby
|
|
173
|
+
tracked_usage = []
|
|
174
|
+
tracker = ->(usage) { tracked_usage << usage }
|
|
175
|
+
|
|
176
|
+
assistant = Rasti::AI::OpenAI::Assistant.new usage_tracker: tracker
|
|
177
|
+
assistant.call 'who is the best player'
|
|
178
|
+
|
|
179
|
+
usage = tracked_usage.first
|
|
180
|
+
usage.provider # => :open_ai
|
|
181
|
+
usage.model # => 'gpt-4o-mini'
|
|
182
|
+
usage.input_tokens # => 150
|
|
183
|
+
usage.output_tokens # => 42
|
|
184
|
+
usage.cached_tokens # => 0
|
|
185
|
+
usage.reasoning_tokens # => 0
|
|
186
|
+
```
|
|
187
|
+
|
|
188
|
+
The tracker can also be configured globally:
|
|
189
|
+
|
|
190
|
+
```ruby
|
|
191
|
+
Rasti::AI.configure do |config|
|
|
192
|
+
config.usage_tracker = ->(usage) { MyMetrics.track(usage) }
|
|
193
|
+
end
|
|
97
194
|
```
|
|
98
195
|
|
|
99
196
|
### MCP (Model Context Protocol)
|
|
@@ -129,17 +226,6 @@ class HelloWorldTool < Rasti::AI::Tool
|
|
|
129
226
|
end
|
|
130
227
|
end
|
|
131
228
|
|
|
132
|
-
class SumTool < Rasti::AI::Tool
|
|
133
|
-
class Form < Rasti::Form
|
|
134
|
-
attribute :number_a, Rasti::Types::Float
|
|
135
|
-
attribute :number_b, Rasti::Types::Float
|
|
136
|
-
end
|
|
137
|
-
|
|
138
|
-
def execute(form)
|
|
139
|
-
{result: form.number_a + form.number_b}
|
|
140
|
-
end
|
|
141
|
-
end
|
|
142
|
-
|
|
143
229
|
# Register tools
|
|
144
230
|
Rasti::AI::MCP::Server.register_tool HelloWorldTool.new
|
|
145
231
|
Rasti::AI::MCP::Server.register_tool SumTool.new
|
|
@@ -224,26 +310,21 @@ client = Rasti::AI::MCP::Client.new(
|
|
|
224
310
|
)
|
|
225
311
|
```
|
|
226
312
|
|
|
227
|
-
##### Integration with
|
|
313
|
+
##### Integration with Assistants
|
|
228
314
|
|
|
229
|
-
You can use MCP clients as tools for
|
|
315
|
+
You can use MCP clients as tools for any assistant:
|
|
230
316
|
|
|
231
317
|
```ruby
|
|
232
|
-
# Create an MCP client
|
|
233
318
|
mcp_client = Rasti::AI::MCP::Client.new(
|
|
234
319
|
url: 'https://mcp.server.ai/mcp'
|
|
235
320
|
)
|
|
236
321
|
|
|
237
|
-
# Use it with the assistant
|
|
238
322
|
assistant = Rasti::AI::OpenAI::Assistant.new(
|
|
239
|
-
mcp_servers: {
|
|
240
|
-
my_mcp: mcp_client
|
|
241
|
-
}
|
|
323
|
+
mcp_servers: {my_mcp: mcp_client}
|
|
242
324
|
)
|
|
243
325
|
|
|
244
326
|
# The assistant can now call tools from the MCP server
|
|
245
327
|
assistant.call 'What is 5 plus 3?'
|
|
246
|
-
# The assistant will use the sum_tool from the MCP server
|
|
247
328
|
```
|
|
248
329
|
|
|
249
330
|
## Contributing
|
|
@@ -252,4 +333,4 @@ Bug reports and pull requests are welcome on GitHub at https://github.com/gabyna
|
|
|
252
333
|
|
|
253
334
|
## License
|
|
254
335
|
|
|
255
|
-
The gem is available as open source under the terms of the [MIT License](http://opensource.org/licenses/MIT).
|
|
336
|
+
The gem is available as open source under the terms of the [MIT License](http://opensource.org/licenses/MIT).
|
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
module Rasti
  module AI
    # Provider-agnostic assistant loop.
    #
    # Records the user prompt in the conversation state and asks the provider for a
    # completion. It executes any tool calls the provider requests and appends their
    # results, and repeats until a non-tool answer is reported as finished. Subclasses
    # supply the provider-specific pieces by overriding the template methods declared
    # at the bottom of this class.
    class Assistant

      attr_reader :state

      # @param client [Object, nil] provider client; #build_default_client is used when nil
      # @param json_schema [Object, nil] stored for subclasses (structured responses)
      # @param state [Object, nil] conversation state; a new AssistantState when nil
      # @param model [String, nil] stored for subclasses (model override)
      # @param tools [Array] tool objects, serialized from their class via ToolSerializer
      # @param mcp_servers [Hash] server name => MCP client whose tools are registered
      # @param logger [Logger, nil] defaults to Rasti::AI.logger
      # @param usage_tracker [#call, nil] receives parsed usage; defaults to Rasti::AI.usage_tracker
      def initialize(client:nil, json_schema:nil, state:nil, model:nil, tools:[], mcp_servers:{}, logger:nil, usage_tracker:nil)
        @client = client || build_default_client
        @json_schema = json_schema
        @state = state || AssistantState.new
        @model = model
        @tools = {}
        @serialized_tools = []
        @logger = logger || Rasti::AI.logger
        @usage_tracker = usage_tracker || Rasti::AI.usage_tracker

        register_tools tools
        register_mcp_servers mcp_servers
      end

      # Sends +prompt+ and returns the provider's final content.
      #
      # Every completion is reported to the usage tracker. When the provider requests
      # tools, the tool-call message plus one result message per call are appended and
      # another completion is requested. Otherwise the content is appended as an
      # assistant message. It is returned when #finished? holds; if not, the loop
      # requests another completion.
      def call(prompt)
        messages << build_user_message(prompt)

        loop do
          response = request_completion
          track_usage response

          pending_calls = parse_tool_calls(response)

          if pending_calls.none?
            answer = parse_content(response)
            messages << build_assistant_message(answer)
            return answer if finished?(response)
            next
          end

          messages << build_assistant_tool_calls_message(response)

          pending_calls.each do |tool_call|
            tool_name, tool_args = extract_tool_call_info(tool_call)
            outcome = call_tool(tool_name, tool_args)
            messages << build_tool_result_message(tool_call, tool_name, outcome)
          end
        end
      end

      private

      attr_reader :client, :json_schema, :model, :tools, :serialized_tools, :logger, :usage_tracker

      # Conversation history owned by the state object.
      def messages
        state.messages
      end

      # Forwards the usage parsed from +response+ to the tracker when both are present.
      def track_usage(response)
        if usage_tracker
          usage = parse_usage(response)
          usage_tracker.call(usage) if usage
        end
      end

      # --- Shared behavior ---

      # Registers each local tool under the name found in its provider serialization.
      def register_tools(tools)
        tools.each do |tool|
          definition = wrap_tool_serialization(ToolSerializer.serialize(tool.class))
          @tools[extract_tool_name(definition)] = tool
          @serialized_tools << definition
        end
      end

      # Registers every tool listed by each MCP server as "<server>_<tool>".
      # The handler forwards calls to the server using the tool's original name.
      def register_mcp_servers(mcp_servers)
        mcp_servers.each do |server_name, mcp|
          mcp.list_tools.each do |remote_tool|
            local_name = "#{server_name}_#{remote_tool['name']}"
            definition = wrap_tool_serialization(remote_tool.merge('name' => local_name))
            @tools[local_name] = ->(args) { mcp.call_tool remote_tool['name'], args }
            @serialized_tools << definition
          end
        end
      end

      # Runs tool +name+ with +args+, memoized in the state under "<name> -> <args>".
      # Any StandardError, including an unknown tool name, is logged as a warning and
      # turned into an "Error: ..." string so the model receives it as the tool result.
      def call_tool(name, args)
        raise Errors::UndefinedTool.new(name) unless tools.key?(name)

        state.fetch("#{name} -> #{args}") do
          logger.info(self.class) { "Calling function #{name} with #{args}" }
          tools[name].call(args).tap do |output|
            logger.info(self.class) { "Function result: #{output}" }
          end
        end
      rescue StandardError => error
        logger.warn(self.class) { "Function failed: #{error.message}\n#{error.backtrace.join("\n")}" }
        "Error: #{error.message}"
      end

      # --- Template methods (overridden by each provider) ---

      # Client used when none is injected.
      def build_default_client
        raise NotImplementedError
      end

      # Message appended for the user's prompt.
      def build_user_message(prompt)
        raise NotImplementedError
      end

      # Message appended for a non-tool answer.
      def build_assistant_message(content)
        raise NotImplementedError
      end

      # Message appended when the provider requests tool calls.
      def build_assistant_tool_calls_message(response)
        raise NotImplementedError
      end

      # Message appended with the result of one tool call.
      def build_tool_result_message(tool_call, name, result)
        raise NotImplementedError
      end

      # Performs one completion request with the current messages.
      def request_completion
        raise NotImplementedError
      end

      # Tool calls requested by +response+ (empty when none).
      def parse_tool_calls(response)
        raise NotImplementedError
      end

      # Answer content carried by +response+.
      def parse_content(response)
        raise NotImplementedError
      end

      # Whether +response+ ends the conversation turn.
      def finished?(response)
        raise NotImplementedError
      end

      # Usage object for +response+, or nil when unavailable.
      def parse_usage(response)
        raise NotImplementedError
      end

      # [name, args] pair for one tool call.
      def extract_tool_call_info(tool_call)
        raise NotImplementedError
      end

      # Provider-specific tool declaration built from a raw serialization.
      def wrap_tool_serialization(raw)
        raise NotImplementedError
      end

      # Tool name stored inside a wrapped declaration.
      def extract_tool_name(wrapped)
        raise NotImplementedError
      end

    end
  end
end
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
module Rasti
  module AI
    # Conversation state shared by an assistant across calls.
    #
    # Holds the message history (mutated in place by the assistant), an optional
    # context string, and a private memo cache used for tool call results.
    class AssistantState

      attr_reader :messages, :context

      # @param context [String, nil] optional context (e.g. a system-style instruction)
      def initialize(context:nil)
        @messages = []
        @cache = {}
        @context = context
      end

      # Returns the value memoized under +key+. On the first request the block is
      # evaluated and its result stored, even when that result is nil or false.
      def fetch(key, &block)
        cache.fetch(key) { cache[key] = block.call }
      end

      private

      attr_reader :cache

    end
  end
end
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
module Rasti
  module AI
    # Base JSON-over-HTTP client shared by the AI providers.
    #
    # Subclasses provide #default_api_key and #base_url. They may override #build_url
    # and #build_request (e.g. to add authentication headers).
    class Client

      # Gateway/availability failures that are worth retrying.
      RETRYABLE_STATUS_CODES = [502, 503, 504].freeze

      # Every argument falls back to the global Rasti::AI configuration when nil.
      def initialize(api_key:nil, logger:nil, http_connect_timeout:nil, http_read_timeout:nil, http_max_retries:nil)
        @api_key = api_key || default_api_key
        @logger = logger || Rasti::AI.logger
        @http_connect_timeout = http_connect_timeout || Rasti::AI.http_connect_timeout
        @http_read_timeout = http_read_timeout || Rasti::AI.http_read_timeout
        @http_max_retries = http_max_retries || Rasti::AI.http_max_retries
      end

      private

      attr_reader :api_key, :logger, :http_connect_timeout, :http_read_timeout, :http_max_retries

      def default_api_key
        raise NotImplementedError
      end

      def base_url
        raise NotImplementedError
      end

      def build_url(relative_url)
        "#{base_url}#{relative_url}"
      end

      def build_request(uri)
        request = Net::HTTP::Post.new uri
        request['Content-Type'] = 'application/json'
        request
      end

      # POSTs +body+ as JSON to +relative_url+ and returns the parsed JSON response.
      #
      # Connection failures (SocketError, open/read timeouts) and responses whose status
      # is in RETRYABLE_STATUS_CODES are retried up to +http_max_retries+ times. The
      # sleep before retry N is N seconds. Any other non-2xx response (e.g. 400, 401,
      # 404, 500) raises Errors::RequestFail immediately instead of being retried.
      #
      # @raise [Errors::RequestFail] on a non-success response (after retries when retryable)
      # @raise [SocketError, Net::OpenTimeout, Net::ReadTimeout] once retries are exhausted
      def post(relative_url, body)
        max_retries = http_max_retries
        retry_count = 0

        begin
          url = build_url(relative_url)
          uri = URI.parse url

          logger.info(self.class) { "POST #{url}" }
          logger.debug(self.class) { JSON.pretty_generate(body) }

          request = build_request(uri)
          request.body = JSON.dump body

          http = Net::HTTP.new uri.host, uri.port
          http.use_ssl = (uri.scheme == 'https')

          http.open_timeout = http_connect_timeout
          http.read_timeout = http_read_timeout

          response = http.request request

          logger.info(self.class) { "Response #{response.code}" }
          logger.debug(self.class) { response.body }

          # Only transient statuses go through the retry path below.
          if RETRYABLE_STATUS_CODES.include?(response.code.to_i)
            raise Errors::RequestFail.new(url, body, response)
          end

        rescue SocketError, Net::OpenTimeout, Net::ReadTimeout, Errors::RequestFail => e
          if retry_count < max_retries
            retry_count += 1
            logger.warn(self.class) { "#{e.class.name}: #{e.message} (#{retry_count}/#{max_retries})" }
            sleep retry_count
            retry
          end
          raise
        end

        # Raised outside the begin/rescue so non-retryable failures surface at once.
        raise Errors::RequestFail.new(url, body, response) unless response.is_a?(Net::HTTPSuccess)

        JSON.parse response.body
      end

    end
  end
end
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
module Rasti
  module AI
    module Gemini
      # Gemini implementation of the generic assistant loop (generateContent API).
      #
      # Messages are Gemini "contents" ({role:, parts: [...]}). Tool calls are the
      # "functionCall" parts of the first candidate. Tools are sent as
      # function_declarations.
      class Assistant < Rasti::AI::Assistant

        private

        def build_default_client
          Client.new
        end

        def build_user_message(prompt)
          {role: Roles::USER, parts: [{text: prompt}]}
        end

        def build_assistant_message(content)
          {role: Roles::MODEL, parts: [{text: content}]}
        end

        # Replays the candidate's content (with its functionCall parts) verbatim as the model turn.
        def build_assistant_tool_calls_message(response)
          response['candidates'][0]['content']
        end

        def build_tool_result_message(tool_call, name, result)
          {
            role: Roles::FUNCTION,
            parts: [{
              functionResponse: {
                name: name,
                response: {content: result}
              }
            }]
          }
        end

        # The state's context, when present, is sent as the system instruction.
        def request_completion
          system_inst = if state.context
            {parts: [{text: state.context}]}
          end

          client.generate_content contents: messages,
                                  model: model,
                                  tools: serialized_tools_payload,
                                  system_instruction: system_inst,
                                  generation_config: generation_config
        end

        def parse_tool_calls(response)
          candidate_parts(response).select { |p| p.key?('functionCall') }
        end

        # Concatenates every text part of the first candidate, skipping parts flagged
        # with "thought". A response may carry several text parts, and reading only the
        # first one would truncate the answer.
        #
        # @raise [RuntimeError] when the candidate has no text part. Previously this
        #   failed with a NoMethodError on nil.
        def parse_content(response)
          text_parts = candidate_parts(response).select { |p| p.key?('text') && !p['thought'] }

          if text_parts.empty?
            finish_reason = response.dig('candidates', 0, 'finishReason')
            raise "Gemini response without text content (finishReason: #{finish_reason.inspect})"
          end

          text_parts.map { |p| p['text'] }.join
        end

        def finished?(response)
          response.dig('candidates', 0, 'finishReason') == 'STOP'
        end

        def parse_usage(response)
          usage = response['usageMetadata']
          return unless usage
          Usage.new(
            provider: :gemini,
            model: response['modelVersion'],
            input_tokens: usage['promptTokenCount'],
            output_tokens: usage['candidatesTokenCount'],
            cached_tokens: usage['cachedContentTokenCount'] || 0,
            reasoning_tokens: usage['thoughtsTokenCount'] || 0
          )
        end

        def extract_tool_call_info(tool_call)
          fc = tool_call['functionCall']
          [fc['name'], fc['args'] || {}]
        end

        # Renames the MCP-style "inputSchema" key (symbol or string) to "parameters".
        def wrap_tool_serialization(raw)
          result = raw.dup
          if result.key?(:inputSchema)
            result[:parameters] = result.delete(:inputSchema)
          elsif result.key?('inputSchema')
            result['parameters'] = result.delete('inputSchema')
          end
          result
        end

        def extract_tool_name(wrapped)
          wrapped[:name] || wrapped['name']
        end

        # Parts of the first candidate's content, or an empty array when absent.
        def candidate_parts(response)
          response.dig('candidates', 0, 'content', 'parts') || []
        end

        def serialized_tools_payload
          return [] if serialized_tools.empty?
          [{function_declarations: serialized_tools}]
        end

        # JSON mode is requested only when a json_schema was given.
        def generation_config
          return nil if json_schema.nil?

          {
            response_mime_type: 'application/json',
            response_schema: json_schema
          }
        end

      end
    end
  end
end
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
module Rasti
  module AI
    module Gemini
      # HTTP client for the Gemini generateContent endpoint.
      class Client < Rasti::AI::Client

        # Calls models/<model>:generateContent.
        #
        # @param contents [Array<Hash>] conversation turns
        # @param model [String, nil] defaults to Rasti::AI.gemini_default_model
        # @param tools [Array] tool declarations; omitted from the body when empty
        # @param system_instruction [Hash, nil] omitted from the body when nil
        # @param generation_config [Hash, nil] omitted from the body when nil
        # @return [Hash] parsed JSON response
        def generate_content(contents:, model:nil, tools:[], system_instruction:nil, generation_config:nil)
          model_name = model || Rasti::AI.gemini_default_model

          body = {contents: contents}

          body[:tools] = tools unless tools.empty?
          body[:system_instruction] = system_instruction unless system_instruction.nil?
          body[:generation_config] = generation_config unless generation_config.nil?

          post "/models/#{model_name}:generateContent", body
        end

        private

        def default_api_key
          Rasti::AI.gemini_api_key
        end

        def base_url
          'https://generativelanguage.googleapis.com/v1beta'
        end

        # Sends the API key in the x-goog-api-key header rather than a ?key= query
        # parameter. The base client logs the request URL ("POST <url>") at info level
        # and passes it to Errors::RequestFail, so a key embedded in the URL would leak
        # into logs and error messages.
        def build_request(uri)
          request = super
          request['x-goog-api-key'] = api_key
          request
        end

      end
    end
  end
end
|