claude_agent 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.claude/commands/spec/complete.md +105 -0
- data/.claude/commands/spec/update.md +95 -0
- data/.claude/rules/conventions.md +622 -0
- data/.claude/rules/git.md +86 -0
- data/.claude/rules/pull-requests.md +31 -0
- data/.claude/rules/releases.md +177 -0
- data/.claude/rules/testing.md +267 -0
- data/.claude/settings.json +49 -0
- data/CHANGELOG.md +13 -0
- data/CLAUDE.md +94 -0
- data/LICENSE.txt +21 -0
- data/README.md +679 -0
- data/Rakefile +63 -0
- data/SPEC.md +558 -0
- data/lib/claude_agent/abort_controller.rb +113 -0
- data/lib/claude_agent/client.rb +298 -0
- data/lib/claude_agent/content_blocks.rb +163 -0
- data/lib/claude_agent/control_protocol.rb +717 -0
- data/lib/claude_agent/errors.rb +103 -0
- data/lib/claude_agent/hooks.rb +228 -0
- data/lib/claude_agent/mcp/server.rb +166 -0
- data/lib/claude_agent/mcp/tool.rb +137 -0
- data/lib/claude_agent/message_parser.rb +262 -0
- data/lib/claude_agent/messages.rb +421 -0
- data/lib/claude_agent/options.rb +264 -0
- data/lib/claude_agent/permissions.rb +164 -0
- data/lib/claude_agent/query.rb +90 -0
- data/lib/claude_agent/sandbox_settings.rb +139 -0
- data/lib/claude_agent/spawn.rb +235 -0
- data/lib/claude_agent/transport/base.rb +61 -0
- data/lib/claude_agent/transport/subprocess.rb +432 -0
- data/lib/claude_agent/types.rb +193 -0
- data/lib/claude_agent/version.rb +5 -0
- data/lib/claude_agent.rb +28 -0
- data/sig/claude_agent.rbs +912 -0
- data/sig/manifest.yaml +5 -0
- metadata +97 -0
|
@@ -0,0 +1,717 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "json"
|
|
4
|
+
require "securerandom"
|
|
5
|
+
|
|
6
|
+
module ClaudeAgent
  # Handles the control protocol for bidirectional communication with Claude Code CLI
  #
  # The control protocol enables:
  # - Initialization handshake with hook registration
  # - Tool permission callbacks (can_use_tool)
  # - Hook callbacks (PreToolUse, PostToolUse, etc.)
  # - MCP message routing for SDK servers
  # - Dynamic permission mode and model changes
  # - Interrupt and file rewind operations
  #
  # Threading: {#start} spawns a background reader thread. That thread answers
  # incoming control requests (so permission and hook callbacks run on it),
  # delivers control responses to callers blocked in a control request, and
  # queues every other message for {#each_message}. Because responses are only
  # delivered by the reader thread, callbacks must not issue control requests
  # themselves: such a request could not be answered until it times out.
  #
  # @example Basic usage
  #   protocol = ControlProtocol.new(transport: transport, options: options)
  #   protocol.start
  #   protocol.each_message { |msg| process(msg) }
  #
  class ControlProtocol
    DEFAULT_TIMEOUT = 60
    REQUEST_ID_PREFIX = "req"

    # Mapping of Ruby keys to CLI keys for hook responses.
    # Both plain keys (continue, async) and trailing-underscore spellings
    # (continue_, async_) are accepted; when both are present the plain key wins
    # because it is processed later.
    HOOK_RESPONSE_KEYS = {
      continue_: "continue",
      continue: "continue",
      async_: "async",
      async: "async",
      async_timeout: "asyncTimeout",
      suppress_output: "suppressOutput",
      stop_reason: "stopReason",
      decision: "decision",
      system_message: "systemMessage",
      reason: "reason"
    }.freeze

    attr_reader :transport, :options, :server_info

    # @param transport [Transport::Base] Transport for communication
    # @param options [Options, nil] Configuration options (defaults to Options.new)
    def initialize(transport:, options: nil)
      @transport = transport
      @options = options || Options.new
      @parser = MessageParser.new
      @server_info = nil

      # Control protocol state (request tables are guarded by @mutex)
      @request_counter = 0
      @pending_requests = {}
      @pending_results = {}
      @hook_callbacks = {}

      # Threading primitives
      @mutex = Mutex.new
      @condition = ConditionVariable.new

      # Reader thread and the queue it fills with SDK (non-control) messages
      @reader_thread = nil
      @message_queue = Queue.new
      @running = false

      # Read from the resolved options so a defaulted Options is honoured too
      @abort_signal = @options.abort_signal
    end

    # Start the control protocol (connect the transport and begin reading)
    #
    # In streaming mode an initialize handshake is sent when hooks or SDK MCP
    # servers are configured.
    #
    # @param streaming [Boolean] Whether to use streaming mode
    # @param prompt [String, nil] Initial prompt for non-streaming mode
    # @return [Hash, nil] Server info from initialization, or nil if none was sent
    def start(streaming: true, prompt: nil)
      @transport.connect(streaming: streaming, prompt: prompt)
      @running = true

      # Start background reader thread
      @reader_thread = Thread.new { reader_loop }

      # Initialize if we have hooks or SDK MCP servers
      if streaming && (options.has_hooks? || options.has_sdk_mcp_servers?)
        @server_info = send_initialize
      end

      @server_info
    end

    # Stop the control protocol: end input, give the reader thread up to
    # 5 seconds to finish, then close the transport.
    # @return [void]
    def stop
      @running = false
      @transport.end_input
      @reader_thread&.join(5)
      @transport.close
    end

    # Abort all pending operations (TypeScript SDK parity)
    #
    # This method:
    # 1. Stops the reader loop
    # 2. Fails all pending control requests with AbortError
    # 3. Terminates the transport (when it responds to #terminate)
    #
    # @return [void]
    def abort!
      @running = false

      # Fail all pending requests
      @mutex.synchronize do
        @pending_requests.each_key do |request_id|
          # The Symbol key marks an abort; it cannot collide with JSON-parsed
          # (String-keyed) responses from the CLI.
          @pending_results[request_id] = {
            "subtype" => "error",
            "error" => "Operation aborted",
            aborted: true
          }
        end
        @condition.broadcast
      end

      # Terminate the transport
      @transport.terminate if @transport.respond_to?(:terminate)
    end

    # Send a user message
    # @param content [String, Array] Message content
    # @param session_id [String] Session ID
    # @param uuid [String, nil] Message UUID for file checkpointing
    # @return [void]
    def send_user_message(content, session_id: "default", uuid: nil)
      message = {
        type: "user",
        message: { role: "user", content: content },
        session_id: session_id
      }
      message[:uuid] = uuid if uuid
      write_message(message)
    end

    # Iterate over incoming messages (SDK messages only, not control)
    #
    # Runs until the reader thread has stopped and the queue is drained.
    # Messages the parser rejects are skipped (logged when CLAUDE_AGENT_DEBUG is
    # set); exceptions raised by the caller's block propagate to the caller.
    #
    # @yield [Message] Parsed message objects
    # @return [Enumerator] If no block given
    # @raise [AbortError] If abort signal is triggered
    def each_message
      return enum_for(:each_message) unless block_given?

      while @running || !@message_queue.empty?
        # Check abort signal
        @abort_signal&.check!

        begin
          raw = @message_queue.pop(true)
        rescue ThreadError
          # Queue empty, wait a bit
          sleep 0.01
          next
        end

        # Only parsing is guarded: errors from the caller's block must not be swallowed
        begin
          message = @parser.parse(raw)
        rescue AbortError
          raise
        rescue => e
          warn "[ClaudeAgent] Message parse error: #{e.message}" if ENV["CLAUDE_AGENT_DEBUG"]
          next
        end

        yield message
      end
    end

    # Receive messages until a ResultMessage is received
    # @yield [Message] Parsed message objects
    # @return [Enumerator] If no block given
    def receive_response
      return enum_for(:receive_response) unless block_given?

      each_message do |message|
        yield message
        break if message.is_a?(ResultMessage)
      end
    end

    # Stream user input from an enumerable (TypeScript SDK parity)
    #
    # Sends each message from the input stream to Claude. Messages can be:
    # - String: Sent as user message content
    # - Hash: Must have :content key, optionally :session_id and :uuid
    # - UserMessage: Sent directly
    #
    # @param stream [Enumerable] Input stream of messages
    # @param session_id [String] Default session ID for messages
    # @return [void]
    # @raise [AbortError] If abort signal is triggered
    # @raise [ArgumentError] If the stream yields an unsupported message type
    #
    # @example With strings
    #   protocol.stream_input(["Hello", "How are you?"])
    #
    # @example With hashes
    #   protocol.stream_input([
    #     { content: "Hello", uuid: "msg-1" },
    #     { content: "Follow up", session_id: "custom" }
    #   ])
    #
    def stream_input(stream, session_id: "default")
      stream.each do |message|
        # Check abort signal before each message
        @abort_signal&.check!

        case message
        when String
          send_user_message(message, session_id: session_id)
        when Hash
          content = message[:content] || message["content"]
          msg_session = message[:session_id] || message["session_id"] || session_id
          uuid = message[:uuid] || message["uuid"]
          send_user_message(content, session_id: msg_session, uuid: uuid)
        when UserMessage, UserMessageReplay
          send_user_message(message.content, session_id: message.session_id || session_id, uuid: message.uuid)
        else
          raise ArgumentError, "Unknown message type in stream: #{message.class}"
        end
      end
    end

    # Stream user input and receive responses (TypeScript SDK parity)
    #
    # Sends messages from the input stream in a background thread while
    # yielding responses in the foreground. This enables concurrent input/output.
    # An error in the input stream is re-raised here: AbortError as-is, anything
    # else wrapped in Error ("Stream input error: ...").
    #
    # @param stream [Enumerable] Input stream of messages
    # @param session_id [String] Default session ID for messages
    # @yield [Message] Received messages
    # @return [Enumerator] If no block given
    # @raise [AbortError] If abort signal is triggered
    # @raise [Error] If the input stream fails
    #
    # @example
    #   messages = ["Hello", "Tell me more"]
    #   protocol.stream_conversation(messages) do |msg|
    #     case msg
    #     when ClaudeAgent::AssistantMessage
    #       puts msg.text
    #     when ClaudeAgent::ResultMessage
    #       puts "Done!"
    #     end
    #   end
    #
    def stream_conversation(stream, session_id: "default", &block)
      return enum_for(:stream_conversation, stream, session_id: session_id) unless block_given?

      # Set by the sender thread; the main thread re-raises it
      sender_error = nil

      sender_thread = Thread.new do
        stream_input(stream, session_id: session_id)
      rescue AbortError, StandardError => e
        sender_error = e
      end

      # Yield responses until we get a ResultMessage or error
      begin
        each_message do |message|
          raise_stream_input_error(sender_error) if sender_error

          yield message
          break if message.is_a?(ResultMessage)
        end
      ensure
        # Wait briefly for the sender to finish
        sender_thread.join(1)
      end

      # Check for sender errors after loop
      raise_stream_input_error(sender_error) if sender_error
    end

    # Send an interrupt request
    # @return [Hash] Response
    def interrupt
      send_control_request(subtype: "interrupt")
    end

    # Change the permission mode
    # @param mode [String] New permission mode
    # @return [Hash] Response
    def set_permission_mode(mode)
      send_control_request(subtype: "set_permission_mode", mode: mode)
    end

    # Change the model
    # @param model [String, nil] New model name
    # @return [Hash] Response
    def set_model(model)
      send_control_request(subtype: "set_model", model: model)
    end

    # Rewind files to a previous state
    # @param user_message_id [String] UUID of user message to rewind to
    # @param dry_run [Boolean] If true, preview changes without modifying files
    # @return [RewindFilesResult] Result with rewind information
    def rewind_files(user_message_id, dry_run: false)
      request = { user_message_id: user_message_id }
      request[:dry_run] = dry_run if dry_run

      response = send_control_request(subtype: "rewind_files", **request)

      # The CLI may answer with camelCase or snake_case keys; accept both
      RewindFilesResult.new(
        can_rewind: response["canRewind"] || response["can_rewind"] || false,
        error: response["error"],
        files_changed: response["filesChanged"] || response["files_changed"],
        insertions: response["insertions"],
        deletions: response["deletions"]
      )
    end

    # Set maximum thinking tokens (TypeScript SDK parity)
    # @param tokens [Integer, nil] Max thinking tokens (nil to reset)
    # @return [Hash] Response
    def set_max_thinking_tokens(tokens)
      send_control_request(subtype: "set_max_thinking_tokens", max_thinking_tokens: tokens)
    end

    # Get available slash commands (TypeScript SDK parity)
    # @return [Array<SlashCommand>]
    def supported_commands
      response = send_control_request(subtype: "supported_commands")
      (response["commands"] || []).map do |cmd|
        SlashCommand.new(
          name: cmd["name"],
          description: cmd["description"],
          argument_hint: cmd["argumentHint"]
        )
      end
    end

    # Get available models (TypeScript SDK parity)
    # @return [Array<ModelInfo>]
    def supported_models
      response = send_control_request(subtype: "supported_models")
      (response["models"] || []).map do |model|
        ModelInfo.new(
          value: model["value"],
          display_name: model["displayName"],
          description: model["description"]
        )
      end
    end

    # Get MCP server status (TypeScript SDK parity)
    # @return [Array<McpServerStatus>]
    def mcp_server_status
      response = send_control_request(subtype: "mcp_server_status")
      (response["servers"] || []).map do |server|
        McpServerStatus.new(
          name: server["name"],
          status: server["status"],
          server_info: server["serverInfo"]
        )
      end
    end

    # Get account information (TypeScript SDK parity)
    # @return [AccountInfo]
    def account_info
      response = send_control_request(subtype: "account_info")
      AccountInfo.new(
        email: response["email"],
        organization: response["organization"],
        subscription_type: response["subscriptionType"],
        token_source: response["tokenSource"],
        api_key_source: response["apiKeySource"]
      )
    end

    # Dynamically set MCP servers for this session (TypeScript SDK parity)
    #
    # This replaces the current set of dynamically-added MCP servers.
    # Servers that are removed will be disconnected, and new servers will be connected.
    # SDK servers (type "sdk") and nil configs are not sent to the CLI.
    #
    # @param servers [Hash] Map of server name to configuration
    # @return [McpSetServersResult] Result with added, removed, and errors
    #
    # @example
    #   result = protocol.set_mcp_servers({
    #     "my-server" => { type: "stdio", command: "node", args: ["server.js"] }
    #   })
    #   puts "Added: #{result.added}"
    #   puts "Removed: #{result.removed}"
    #
    def set_mcp_servers(servers)
      # SDK servers are handled locally - only process-based servers go to the CLI
      servers_config = servers.reject { |_name, config| config.nil? || sdk_server_config?(config) }

      response = send_control_request(subtype: "mcp_set_servers", servers: servers_config)

      McpSetServersResult.new(
        added: response["added"] || [],
        removed: response["removed"] || [],
        errors: response["errors"] || {}
      )
    end

    private

    # Background thread body: reads messages from the transport and routes them.
    # Control requests are answered inline, control responses wake the waiting
    # caller, and everything else is queued for #each_message.
    def reader_loop
      @transport.read_messages do |raw|
        # Check abort signal on each iteration
        if @abort_signal&.aborted?
          @running = false
          break
        end

        break unless @running

        case raw["type"]
        when "control_request"
          handle_control_request(raw)
        when "control_response"
          handle_control_response(raw)
        else
          # SDK message - queue for consumer
          @message_queue.push(raw)
        end
      end
    rescue IOError, Errno::EPIPE
      # Transport closed
      nil
    rescue AbortError
      # Abort signal raised
      nil
    ensure
      # However the loop ends (EOF, abort, unexpected error), consumers must stop
      # waiting for new messages; otherwise #each_message would poll forever.
      @running = false
    end

    # Send initialization request
    # @return [Hash] Server info
    def send_initialize
      hooks_config = build_hooks_config

      request = { subtype: "initialize" }
      request[:hooks] = hooks_config if hooks_config

      send_control_request(**request)
    end

    # Build hooks configuration for initialization, registering each callback
    # under a generated ID so later hook_callback requests can find it.
    # @return [Hash, nil]
    def build_hooks_config
      return nil unless options.has_hooks?

      config = {}

      options.hooks.each do |event, matchers|
        config[event] = matchers.map.with_index do |matcher, idx|
          callback_ids = matcher.callbacks.map.with_index do |callback, cidx|
            callback_id = "hook_#{event}_#{idx}_#{cidx}"
            @hook_callbacks[callback_id] = callback
            callback_id
          end

          entry = {
            matcher: matcher.matcher,
            hookCallbackIds: callback_ids
          }
          entry[:timeout] = matcher.timeout if matcher.timeout
          entry
        end
      end

      config
    end

    # Handle incoming control request from CLI (runs on the reader thread).
    # Any error raised while handling is reported back to the CLI as an error response.
    # @param raw [Hash] Raw control request
    def handle_control_request(raw)
      request = raw["request"] || {}
      request_id = raw["request_id"]
      subtype = request["subtype"]

      response = case subtype
      when "can_use_tool"
        handle_can_use_tool(request)
      when "hook_callback"
        handle_hook_callback(request)
      when "mcp_message"
        handle_mcp_message(request)
      else
        { error: "Unknown control request subtype: #{subtype}" }
      end

      send_control_response(request_id, response)
    rescue => e
      send_control_response(request_id, { error: e.message })
    end

    # Handle can_use_tool permission request
    #
    # The callback's Hash result may use Symbol or String keys; behavior may be
    # given as "allow" or :allow. Any other behavior value produces a deny.
    #
    # @param request [Hash] Request data
    # @return [Hash] Response
    def handle_can_use_tool(request)
      return { behavior: "allow" } unless options.can_use_tool

      tool_name = request["tool_name"]
      input = request["input"] || {}
      context = {
        permission_suggestions: request["permission_suggestions"],
        blocked_path: request["blocked_path"],
        decision_reason: request["decision_reason"],
        tool_use_id: request["tool_use_id"],
        agent_id: request["agent_id"]
      }

      result = options.can_use_tool.call(tool_name, input, context)

      # NOTE(review): a non-Hash result (including nil/false) allows the tool;
      # this preserves existing behaviour - confirm it is intended.
      return { behavior: "allow" } unless result.is_a?(Hash)

      result = symbolize_keys(result)
      unless result[:behavior].to_s == "allow"
        return {
          behavior: "deny",
          message: result[:message] || "",
          interrupt: result[:interrupt] || false
        }
      end

      response = { behavior: "allow" }
      response[:updatedInput] = result[:updated_input] if result[:updated_input]
      if result[:updated_permissions]
        response[:updatedPermissions] = result[:updated_permissions].map do |permission|
          permission.respond_to?(:to_h) ? permission.to_h : permission
        end
      end
      response
    end

    # Handle hook callback request
    #
    # The callback may return a Hash, nil/false (treated as an empty response),
    # or an object responding to #to_h.
    #
    # @param request [Hash] Request data
    # @return [Hash] Response ({} when the callback ID is unknown)
    def handle_hook_callback(request)
      callback = @hook_callbacks[request["callback_id"]]
      return {} unless callback

      context = { tool_use_id: request["tool_use_id"] }
      result = callback.call(request["input"] || {}, context)

      result = case result
      when Hash then result
      when nil, false then {}
      else result.respond_to?(:to_h) ? result.to_h : result
      end

      # Normalize result - convert Ruby field names to CLI field names
      normalize_hook_response(result)
    end

    # Handle MCP message routing to an in-process (SDK) MCP server
    # @param request [Hash] Request data
    # @return [Hash] Response
    def handle_mcp_message(request)
      server_name = request["server_name"]

      server_config = find_mcp_server_config(server_name)
      return { error: "Unknown MCP server: #{server_name}" } unless server_config
      return { error: "Not an SDK MCP server" } unless sdk_server_config?(server_config)

      server_instance = server_config[:instance] || server_config["instance"]
      return { error: "No server instance" } unless server_instance

      # Route message to server
      { mcp_response: server_instance.handle_message(request["message"]) }
    end

    # Look up an MCP server config by the name the CLI sent. Exact key match
    # first, then a match on the key's string form (so Symbol keys work).
    # @param server_name [String, nil]
    # @return [Object, nil] The config, or nil when not found
    def find_mcp_server_config(server_name)
      servers = options.mcp_servers || {}
      servers[server_name] || servers.find { |name, _config| name.to_s == server_name.to_s }&.last
    end

    # @param config [Object] MCP server configuration
    # @return [Boolean] true when config is a Hash whose type is "sdk"
    def sdk_server_config?(config)
      config.is_a?(Hash) && (config[:type] == "sdk" || config["type"] == "sdk")
    end

    # Normalize hook response for CLI
    # @param result [Hash] Raw result from callback (Symbol or String keys)
    # @return [Hash] Normalized response with CLI (String, camelCase) keys
    def normalize_hook_response(result)
      result = symbolize_keys(result)

      response = HOOK_RESPONSE_KEYS.each_with_object({}) do |(ruby_key, json_key), acc|
        acc[json_key] = result[ruby_key] if result.key?(ruby_key)
      end

      if result[:hook_specific_output]
        response["hookSpecificOutput"] = normalize_hook_specific_output(result[:hook_specific_output])
      end

      response
    end

    # Normalize hookSpecificOutput top-level keys to lowerCamelCase
    # @param hso [Hash] Hook-specific output
    # @return [Hash] Normalized output (String keys)
    def normalize_hook_specific_output(hso)
      hso.to_h.each_with_object({}) do |(key, value), normalized|
        normalized[lower_camelize(key)] = value
      end
    end

    # Convert a snake_case key to lowerCamelCase ("permission_decision" =>
    # "permissionDecision"). Keys already in lowerCamelCase pass through.
    # Implemented locally because String#camelize is ActiveSupport, not core Ruby.
    # @param key [String, Symbol]
    # @return [String]
    def lower_camelize(key)
      key.to_s
         .gsub(/_([a-z\d])/) { Regexp.last_match(1).upcase }
         .sub(/\A[A-Z]/, &:downcase)
    end

    # Shallow copy of +hash+ with String keys converted to Symbols, so callbacks
    # may return either key style.
    # @param hash [Hash]
    # @return [Hash]
    def symbolize_keys(hash)
      hash.transform_keys { |key| key.is_a?(String) ? key.to_sym : key }
    end

    # Re-raise an error captured from the stream_conversation sender thread:
    # AbortError as-is, anything else wrapped in Error.
    # @param error [Exception]
    def raise_stream_input_error(error)
      raise error if error.is_a?(AbortError)

      raise Error, "Stream input error: #{error.message}"
    end

    # Handle control response from CLI: hand it to the waiting caller, if any.
    # Responses for unknown (e.g. timed-out) requests are dropped.
    # @param raw [Hash] Raw control response
    def handle_control_response(raw)
      response = raw["response"] || {}
      request_id = response["request_id"]

      @mutex.synchronize do
        if @pending_requests.key?(request_id)
          @pending_results[request_id] = response
          @condition.broadcast
        end
      end
    end

    # Send a control request and wait for response
    # @param subtype [String] Request subtype
    # @param kwargs [Hash] Additional request data
    # @param timeout [Integer] Timeout in seconds
    # @return [Hash] Response data
    # @raise [AbortError] If abort signal is triggered or #abort! is called
    # @raise [TimeoutError] If no response arrives within +timeout+ seconds
    # @raise [Error] If the CLI answers with an error response
    def send_control_request(subtype:, timeout: DEFAULT_TIMEOUT, **kwargs)
      # Check abort signal before sending
      @abort_signal&.check!

      request_id = generate_request_id

      request = {
        type: "control_request",
        request_id: request_id,
        request: { subtype: subtype, **kwargs }
      }

      @mutex.synchronize { @pending_requests[request_id] = true }

      begin
        write_message(request)
      rescue StandardError
        # A request that was never sent will never be answered
        @mutex.synchronize { forget_request(request_id) }
        raise
      end

      response = await_control_response(request_id, timeout)

      raise AbortError, response["error"] if response[:aborted]
      raise Error, response["error"] || "Unknown error" if response["subtype"] == "error"

      response["response"] || response
    end

    # Block until the reader delivers the response for +request_id+.
    # @param request_id [String]
    # @param timeout [Numeric] Seconds to wait
    # @return [Hash] The raw response (the entry is removed from the tables)
    # @raise [AbortError] If the abort signal fires while waiting
    # @raise [TimeoutError] If the deadline passes
    def await_control_response(request_id, timeout)
      @mutex.synchronize do
        deadline = Process.clock_gettime(Process::CLOCK_MONOTONIC) + timeout

        until @pending_results.key?(request_id)
          # Checked while holding @mutex; the periodic wake-up below makes this poll
          if @abort_signal&.aborted?
            forget_request(request_id)
            raise AbortError, @abort_signal.reason
          end

          remaining = deadline - Process.clock_gettime(Process::CLOCK_MONOTONIC)
          if remaining <= 0
            forget_request(request_id)
            raise TimeoutError.new("Control request timed out", request_id: request_id, timeout_seconds: timeout)
          end

          @condition.wait(@mutex, [ remaining, 0.1 ].min) # Wake up periodically to check abort
        end

        response = @pending_results.delete(request_id)
        @pending_requests.delete(request_id)
        response
      end
    end

    # Drop all bookkeeping for a request. Caller must hold @mutex.
    # @param request_id [String]
    def forget_request(request_id)
      @pending_requests.delete(request_id)
      @pending_results.delete(request_id)
    end

    # Send a control response
    # @param request_id [String] Request ID to respond to
    # @param data [Hash] Response data; an :error key produces an error response
    def send_control_response(request_id, data)
      response = {
        type: "control_response",
        response: {
          subtype: data[:error] ? "error" : "success",
          request_id: request_id
        }
      }

      if data[:error]
        response[:response][:error] = data[:error]
      else
        response[:response][:response] = data
      end

      write_message(response)
    end

    # Write a message to the transport as JSON
    # @param message [Hash] Message to write
    def write_message(message)
      json = JSON.generate(message)
      @transport.write(json)
    end

    # Generate a unique request ID ("req_<counter>_<8 hex chars>")
    # @return [String]
    def generate_request_id
      @mutex.synchronize do
        @request_counter += 1
        "#{REQUEST_ID_PREFIX}_#{@request_counter}_#{SecureRandom.hex(4)}"
      end
    end
  end
end
|