graph-agent 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.github/workflows/ci.yml +50 -0
- data/.github/workflows/release.yml +49 -0
- data/.gitignore +6 -0
- data/.rspec +3 -0
- data/.rubocop.yml +126 -0
- data/CHANGELOG.md +26 -0
- data/CLAUDE.md +128 -0
- data/Gemfile +11 -0
- data/Gemfile.lock +94 -0
- data/LICENSE +21 -0
- data/Makefile +114 -0
- data/README.md +464 -0
- data/Rakefile +15 -0
- data/docs/README.md +55 -0
- data/docs/api_reference.md +832 -0
- data/docs/concepts.md +216 -0
- data/docs/edges.md +265 -0
- data/docs/error_handling.md +241 -0
- data/docs/human_in_the_loop.md +231 -0
- data/docs/persistence.md +276 -0
- data/docs/quickstart.md +154 -0
- data/docs/send_and_command.md +218 -0
- data/docs/state.md +181 -0
- data/docs/streaming.md +172 -0
- data/graph-agent.gemspec +48 -0
- data/lib/graph_agent/channels/base_channel.rb +52 -0
- data/lib/graph_agent/channels/binary_operator_aggregate.rb +56 -0
- data/lib/graph_agent/channels/ephemeral_value.rb +59 -0
- data/lib/graph_agent/channels/last_value.rb +49 -0
- data/lib/graph_agent/channels/topic.rb +58 -0
- data/lib/graph_agent/checkpoint/base_saver.rb +38 -0
- data/lib/graph_agent/checkpoint/in_memory_saver.rb +145 -0
- data/lib/graph_agent/constants.rb +9 -0
- data/lib/graph_agent/errors.rb +41 -0
- data/lib/graph_agent/graph/compiled_state_graph.rb +362 -0
- data/lib/graph_agent/graph/conditional_edge.rb +57 -0
- data/lib/graph_agent/graph/edge.rb +23 -0
- data/lib/graph_agent/graph/mermaid_visualizer.rb +154 -0
- data/lib/graph_agent/graph/message_graph.rb +18 -0
- data/lib/graph_agent/graph/node.rb +61 -0
- data/lib/graph_agent/graph/state_graph.rb +197 -0
- data/lib/graph_agent/reducers.rb +34 -0
- data/lib/graph_agent/state/schema.rb +54 -0
- data/lib/graph_agent/types/cache_policy.rb +12 -0
- data/lib/graph_agent/types/command.rb +26 -0
- data/lib/graph_agent/types/interrupt.rb +28 -0
- data/lib/graph_agent/types/retry_policy.rb +42 -0
- data/lib/graph_agent/types/send.rb +26 -0
- data/lib/graph_agent/types/state_snapshot.rb +28 -0
- data/lib/graph_agent/version.rb +5 -0
- data/lib/graph_agent.rb +29 -0
- metadata +158 -0
|
@@ -0,0 +1,362 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "securerandom"
|
|
4
|
+
|
|
5
|
+
module GraphAgent
|
|
6
|
+
module Graph
|
|
7
|
+
# Runtime for a compiled StateGraph, executed as a Pregel-style loop of
# supersteps.
#
# Each superstep runs every active node against one deep-copied snapshot of
# the state, merges the resulting updates into the state (through schema
# reducers when the builder's schema is a State::Schema), emits :updates and
# :values events, and then resolves the next active nodes from plain edges,
# conditional branches, waiting (multi-source) edges and Command gotos.
# Send objects -- returned by nodes, carried in Command gotos, or produced by
# conditional branches -- run their target node against a per-send copy of
# the state once the superstep's updates have been applied.
#
# When a checkpointer is configured and the config carries
# +configurable: { thread_id: ... }+, a checkpoint is written on input, after
# each superstep, on interrupt and on exit.
class CompiledStateGraph # rubocop:disable Metrics/ClassLength
  DEFAULT_RECURSION_LIMIT = 25

  attr_reader :builder, :checkpointer

  # @param builder the graph definition; must expose +nodes+, +edges+,
  #   +branches+, +waiting_edges+ and +schema+
  # @param checkpointer [#get_tuple, #put, nil] optional checkpoint saver
  # @param interrupt_before [Array, String, Symbol, nil] node names to pause
  #   before ("*" matches every node)
  # @param interrupt_after [Array, String, Symbol, nil] node names to pause
  #   after ("*" matches every node)
  # @param debug [Boolean] stored on the instance; not read by this class
  def initialize(builder:, checkpointer: nil, interrupt_before: [], interrupt_after: [], debug: false)
    @builder = builder
    @checkpointer = checkpointer
    @interrupt_before = _normalize_interrupt(interrupt_before)
    @interrupt_after = _normalize_interrupt(interrupt_after)
    @debug = debug
  end

  # Runs the graph to completion.
  # @return [Hash, nil] the state carried by the last :values event
  def invoke(input, config: {}, recursion_limit: DEFAULT_RECURSION_LIMIT)
    last_state = nil

    _run_pregel(input, config: config, recursion_limit: recursion_limit) do |event|
      last_state = event[:state] if event[:type] == :values
    end

    last_state
  end

  # Runs the graph and yields according to +stream_mode+:
  # * :values  - a shallow copy of the state after each superstep, plus the
  #              final state once more when the run exits
  # * :updates - a Hash of node name (String) => updates for each superstep
  # * :debug   - every raw event Hash
  # Any other mode yields nothing. Without a block an Enumerator is returned.
  def stream(input, config: {}, recursion_limit: DEFAULT_RECURSION_LIMIT, stream_mode: :values, &block)
    unless block
      return enum_for(:stream, input, config: config, recursion_limit: recursion_limit, stream_mode: stream_mode)
    end

    _run_pregel(input, config: config, recursion_limit: recursion_limit) do |event|
      _emit_stream_event(event, stream_mode, &block)
    end
  end

  # @return [StateSnapshot, nil] snapshot of the latest checkpoint for
  #   +config+, or nil without a checkpointer or checkpoint
  def get_state(config)
    return nil unless @checkpointer

    tuple = @checkpointer.get_tuple(config)
    return nil unless tuple

    StateSnapshot.new(
      values: tuple.checkpoint[:channel_values] || {},
      config: tuple.config,
      metadata: tuple.metadata || {},
      parent_config: tuple.parent_config,
      next_nodes: tuple.checkpoint[:next_nodes] || []
    )
  end

  # Applies +values+ (through schema reducers) on top of the latest
  # checkpoint for +config+ and stores the result as a new checkpoint with
  # no pending nodes. +as_node+ is accepted but not used by this method.
  # @return the checkpointer's +put+ result, or nil without a checkpointer
  #   or checkpoint
  def update_state(config, values, as_node: nil)
    return nil unless @checkpointer

    tuple = @checkpointer.get_tuple(config)
    return nil unless tuple

    # Deep copy so reducers never mutate the previously stored checkpoint.
    current_state = _deep_dup(tuple.checkpoint[:channel_values] || {})
    new_state = _apply_updates(current_state, values)

    checkpoint = _build_checkpoint(new_state, [])
    metadata = { source: :update, step: (tuple.metadata || {})[:step].to_i + 1, writes: values }

    @checkpointer.put(tuple.config, checkpoint, metadata, {})
  end

  # @return [Hash] +{ nodes: [...], edges: [[source, target], ...] }+ for the
  #   plain edges of the builder
  def get_graph
    { nodes: @builder.nodes.keys, edges: @builder.edges.to_a.map { |e| [e.source, e.target] } }
  end

  private

  # Drives the superstep loop, passing event Hashes
  # (+{ type: :updates | :values, step:, ... }+) to +event_handler+.
  # @raise [GraphRecursionError] when +recursion_limit+ supersteps run
  #   without reaching END
  # @raise [GraphInterrupt] on a configured interrupt
  def _run_pregel(input, config:, recursion_limit:, &event_handler)
    state = _initialize_state(input, config)
    step = 0
    current_nodes, entry_sends = _resolve_entry_nodes(state, config)
    # Sends produced by START branches run before the first superstep.
    state = _execute_sends(entry_sends, state, config)

    _save_checkpoint(config, state, step, :input, current_nodes)

    while current_nodes.any? { |n| !_is_terminal?(n) }
      _check_recursion_limit!(step, recursion_limit)
      runnable_nodes = current_nodes.reject { |n| _is_terminal?(n) }
      _check_interrupts_before!(runnable_nodes, config, state, step)

      state, next_nodes = _execute_superstep(runnable_nodes, state, config, step, &event_handler)

      _check_interrupts_after!(runnable_nodes, config, state, step)

      current_nodes = next_nodes.uniq
      step += 1
      _save_checkpoint(config, state, step, :loop, current_nodes)
    end

    _save_checkpoint(config, state, step, :exit, [])
    # Always emitted so invoke has a result even when no superstep ran; in
    # :values streaming this repeats the last superstep's state.
    event_handler.call({ type: :values, step: step, state: state })
  end

  # Runs one superstep.
  # @return [Array(Hash, Array<String>)] the new state and next node names
  def _execute_superstep(runnable_nodes, state, config, step, &event_handler)
    # Every node in the step sees the same pre-step state.
    state_snapshot = _deep_dup(state)
    step_updates = {}
    next_nodes_from_commands = []
    sends = []

    runnable_nodes.each do |node_name|
      result = _execute_node(node_name, state_snapshot, config)
      _collect_results(result, node_name, step_updates, next_nodes_from_commands, sends)
    end

    step_updates.each_value { |updates| state = _apply_updates(state, updates) }

    event_handler.call({ type: :updates, step: step, updates: step_updates.transform_keys(&:to_s) })
    event_handler.call({ type: :values, step: step, state: state })

    next_nodes, branch_sends = _resolve_next_nodes_for_step(runnable_nodes, state, config)
    next_nodes.concat(next_nodes_from_commands)

    state = _execute_sends(sends + branch_sends, state, config)

    [state, next_nodes]
  end

  # Looks up and runs +node_name+.
  # @raise [InvalidGraphError] when the node does not exist (raised outside
  #   the wrapping rescue so it is not reported as a node failure)
  def _execute_node(node_name, state_snapshot, config)
    node = @builder.nodes[node_name]
    raise InvalidGraphError.new("Node '#{node_name}' not found") unless node

    _call_node(node, node_name, state_snapshot, config)
  end

  # Calls +node+, wrapping unexpected errors in NodeExecutionError while
  # letting interrupt/recursion signals propagate unchanged.
  def _call_node(node, node_name, state, config)
    node.call(state, config)
  rescue GraphInterrupt, GraphRecursionError
    raise
  rescue => e
    raise NodeExecutionError.new(node_name, e)
  end

  # Records a node's updates, Command gotos and Sends into the superstep
  # accumulators.
  def _collect_results(result, node_name, step_updates, next_nodes_from_commands, sends)
    node_updates, node_commands, node_sends = _process_result(result)

    step_updates[node_name] = node_updates if node_updates && !node_updates.empty?

    node_commands.each { |cmd| _collect_command(cmd, next_nodes_from_commands, sends) }

    sends.concat(node_sends)
  end

  # Routes a Command's goto targets: Send objects are queued, anything else is
  # treated as a node name. The Command's update is not handled here because
  # _process_result already merged it into the node's updates; recording it
  # again would apply it (and its reducers) twice.
  def _collect_command(cmd, next_nodes_from_commands, sends)
    cmd.goto.each do |target|
      if target.is_a?(GraphAgent::Send)
        sends << target
      else
        next_nodes_from_commands << target.to_s
      end
    end
  end

  # Runs each Send's target node against a copy of +state+ (with the Send's
  # Hash arg applied) and merges the resulting updates into +state+.
  # Sends to unknown nodes are skipped.
  # @return [Hash] the updated state
  def _execute_sends(sends, state, config)
    sends.each do |send_obj|
      target = send_obj.node.to_s
      next unless @builder.nodes.key?(target)

      send_state = _deep_dup(state)
      send_state = _apply_updates(send_state, send_obj.arg) if send_obj.arg.is_a?(Hash)
      result = _call_node(@builder.nodes[target], target, send_state, config)
      node_updates, = _process_result(result)
      state = _apply_updates(state, node_updates) if node_updates
    end
    state
  end

  # @raise [GraphRecursionError] once +step+ reaches +limit+
  def _check_recursion_limit!(step, limit)
    return unless step >= limit

    raise GraphRecursionError.new("Recursion limit of #{limit} reached without hitting END node")
  end

  # Saves a checkpoint pointing at the interrupted node and raises
  # GraphInterrupt for the first node configured to interrupt before running.
  def _check_interrupts_before!(runnable_nodes, config, state, step)
    runnable_nodes.each do |node_name|
      next unless _should_interrupt_before?(node_name)

      _save_checkpoint(config, state, step, :interrupt, [node_name])
      raise GraphInterrupt.new([Interrupt.new("Interrupted before '#{node_name}'")])
    end
  end

  # Saves a checkpoint (with no pending nodes) and raises GraphInterrupt for
  # the first node configured to interrupt after running.
  def _check_interrupts_after!(runnable_nodes, config, state, step)
    runnable_nodes.each do |node_name|
      next unless _should_interrupt_after?(node_name)

      _save_checkpoint(config, state, step, :interrupt, [])
      raise GraphInterrupt.new([Interrupt.new("Interrupted after '#{node_name}'")])
    end
  end

  # Translates an internal event into what +stream_mode+ yields.
  def _emit_stream_event(event, stream_mode)
    case stream_mode
    when :values
      yield event[:state].dup if event[:type] == :values
    when :updates
      yield event[:updates] if event[:type] == :updates && event[:updates]
    when :debug
      yield event
    end
  end

  # Splits a node result into [updates Hash, Commands, Sends]. A Command's
  # Hash update is merged into the updates here.
  def _process_result(result)
    updates = {}
    commands = []
    sends = []

    case result
    when Hash then updates = result
    when Command then commands << result
                      updates = result.update if result.update.is_a?(Hash)
    when Send then sends << result
    when Array then _process_array_result(result, updates, commands, sends)
    end

    [updates, commands, sends]
  end

  # Element-wise version of _process_result; Hash updates are merged.
  def _process_array_result(result, updates, commands, sends)
    result.each do |item|
      case item
      when Hash then updates.merge!(item)
      when Command then commands << item
                        updates.merge!(item.update) if item.update.is_a?(Hash)
      when Send then sends << item
      end
    end
  end

  # Restores state from a checkpoint when possible; otherwise builds it from
  # the schema's initial state (or the input Hash with symbolized keys).
  def _initialize_state(input, config)
    state = _restore_from_checkpoint(input, config)
    return state if state

    if @builder.schema.is_a?(State::Schema)
      state = @builder.schema.initial_state
      input.is_a?(Hash) ? _apply_updates(state, input) : state
    elsif input.is_a?(Hash)
      input.transform_keys(&:to_sym)
    else
      {}
    end
  end

  # @return [Hash, nil] a deep copy of the checkpointed state with +input+
  #   applied, or nil without a checkpointer, thread_id or checkpoint
  def _restore_from_checkpoint(input, config)
    return nil unless @checkpointer && config.dig(:configurable, :thread_id)

    tuple = @checkpointer.get_tuple(config)
    return nil unless tuple

    # Deep copy so later reducer/node mutations never reach the stored checkpoint.
    state = _deep_dup(tuple.checkpoint[:channel_values] || {})
    input.is_a?(Hash) ? _apply_updates(state, input) : state
  end

  # Merges +updates+ into +state+ (mutating and returning it), using the
  # schema field's reducer when one is defined for the key.
  def _apply_updates(state, updates)
    return state unless updates.is_a?(Hash) && state.is_a?(Hash)

    updates.each do |key, value|
      key = key.to_sym
      field = @builder.schema.is_a?(State::Schema) ? @builder.schema.fields[key] : nil

      if field&.reducer
        state[key] = field.reducer.call(state[key], value)
      else
        state[key] = value
      end
    end

    state
  end

  # @return [Array(Array<String>, Array<Send>)] entry node names and Sends
  #   reached from START edges and START branches
  def _resolve_entry_nodes(state, config)
    targets = @builder.edges.select { |e| e.source == START.to_s }.map(&:target)
    _resolve_branches(START.to_s, state, config, targets)
    _split_targets(targets)
  end

  # @return [Array(Array<String>, Array<Send>)] next node names and Sends
  #   reached from the executed nodes' edges, branches and waiting edges
  def _resolve_next_nodes_for_step(executed_nodes, state, config)
    targets = []

    executed_nodes.each do |node_name|
      @builder.edges.each { |edge| targets << edge.target if edge.source == node_name }
      _resolve_branches(node_name, state, config, targets)
    end

    _resolve_waiting_edges(executed_nodes, targets)
    _split_targets(targets)
  end

  # Appends every branch target of +node_name+ to +next_nodes+. Send objects
  # are kept as-is so they can be executed; other targets become Strings.
  def _resolve_branches(node_name, state, config, next_nodes)
    return unless @builder.branches.key?(node_name)

    @builder.branches[node_name].each_value do |branch|
      result = branch.resolve(state, config)
      # Array() would destructure a Struct-based Send, so wrap it explicitly.
      targets = result.is_a?(GraphAgent::Send) ? [result] : Array(result)
      targets.each do |target|
        next_nodes << (target.is_a?(GraphAgent::Send) ? target : target.to_s)
      end
    end
  end

  # A waiting edge fires only when all of its sources ran in this superstep.
  def _resolve_waiting_edges(executed_nodes, next_nodes)
    @builder.waiting_edges.each do |sources, target|
      next_nodes << target if sources.all? { |s| executed_nodes.include?(s) }
    end
  end

  # Separates Sends from node names; names are de-duplicated, Sends are not,
  # so identical fan-out payloads each run.
  def _split_targets(targets)
    sends, names = targets.partition { |t| t.is_a?(GraphAgent::Send) }
    [names.uniq, sends]
  end

  def _is_terminal?(node_name)
    node_name == END_NODE.to_s
  end

  def _should_interrupt_before?(node_name)
    @interrupt_before.include?("*") || @interrupt_before.include?(node_name)
  end

  def _should_interrupt_after?(node_name)
    @interrupt_after.include?("*") || @interrupt_after.include?(node_name)
  end

  # Coerces an interrupt option into an Array of Strings.
  def _normalize_interrupt(value)
    case value
    when Array then value.map(&:to_s)
    when String, Symbol then [value.to_s]
    when nil then []
    else Array(value).map(&:to_s)
    end
  end

  # Persists a deep copy of +state+ when a checkpointer and thread_id exist.
  def _save_checkpoint(config, state, step, source, next_nodes)
    return unless @checkpointer && config.dig(:configurable, :thread_id)

    checkpoint = _build_checkpoint(_deep_dup(state), next_nodes)
    metadata = { source: source, step: step, parents: {} }
    @checkpointer.put(config, checkpoint, metadata, {})
  end

  def _build_checkpoint(channel_values, next_nodes)
    { id: SecureRandom.uuid, channel_values: channel_values, channel_versions: {},
      versions_seen: {}, next_nodes: next_nodes }
  end

  # Recursively copies Hashes, Arrays and Strings; immediates are returned
  # as-is and other objects are shallow-dup'ed.
  def _deep_dup(obj)
    case obj
    when Hash then obj.transform_values { |v| _deep_dup(v) }
    when Array then obj.map { |v| _deep_dup(v) }
    when String then obj.dup
    when Integer, Float, Symbol, TrueClass, FalseClass, NilClass then obj
    else obj.respond_to?(:dup) ? obj.dup : obj
    end
  end
end
|
|
361
|
+
end
|
|
362
|
+
end
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module GraphAgent
|
|
4
|
+
module Graph
|
|
5
|
+
# A conditional (branching) edge leaving +source+.
#
# +path+ is any callable. It is invoked with (), (state) or (state, config)
# depending on its arity, and its return value becomes the routing target(s):
# * a GraphAgent::Send is wrapped in a one-element Array;
# * an Array has its non-Send entries translated through +path_map+ (when
#   given), while Send entries pass through unchanged;
# * any other value is translated through +path_map+ (when given) or is
#   returned as-is.
class ConditionalEdge
  attr_reader :source, :path, :path_map

  # @param source [#to_s] name of the node the branch leaves from
  # @param path [#call] routing callable
  # @param path_map [Hash, nil] maps path results to targets; a +:default+
  #   key is used for results whose mapping is missing or nil
  def initialize(source, path, path_map: nil)
    @source = source.to_s
    @path = path
    @path_map = path_map
  end

  # Evaluates the branch for +state+.
  # @return the routing target(s) described in the class comment
  # @raise [GraphError] if +path+ is not callable
  # @raise [InvalidGraphError] if a result has no mapping in +path_map+
  def resolve(state, config = {})
    result = invoke_path(state, config)
    _resolve_result(result)
  end

  private

  def invoke_path(state, config)
    raise GraphError.new("Conditional edge path is not callable") unless @path.respond_to?(:call)

    case _path_arity
    when 0 then @path.call
    when 1, -1 then @path.call(state)
    else @path.call(state, config)
    end
  end

  # Arity of the routing callable. Objects that implement #call but not
  # #arity (plain service objects) are measured through their #call method.
  def _path_arity
    @path.respond_to?(:arity) ? @path.arity : @path.method(:call).arity
  end

  def _resolve_result(result)
    case result
    when GraphAgent::Send then [result]
    when Array then _resolve_array(result)
    else @path_map ? _map_result(result) : result
    end
  end

  # Maps only the non-Send entries, so Sends can be mixed with node names.
  def _resolve_array(result)
    return result unless @path_map

    result.map { |item| item.is_a?(GraphAgent::Send) ? item : _map_result(item) }
  end

  def _map_result(result)
    mapped = @path_map[result]
    mapped = @path_map[:default] if mapped.nil? && @path_map.key?(:default)
    raise InvalidGraphError.new("Unknown path result: #{result.inspect}") if mapped.nil?

    mapped
  end
end
|
|
56
|
+
end
|
|
57
|
+
end
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module GraphAgent
|
|
4
|
+
module Graph
|
|
5
|
+
# A plain directed edge between two node names, both stored as Strings.
# Two edges are equal (and hash alike) when their endpoints match, so edges
# can be de-duplicated in Sets and used as Hash keys.
class Edge
  attr_reader :source, :target

  def initialize(source, target)
    @source = source.to_s
    @target = target.to_s
  end

  def ==(other)
    return false unless other.is_a?(Edge)

    endpoints == other.endpoints
  end
  alias eql? ==

  def hash
    endpoints.hash
  end

  protected

  # [source, target] pair used for equality and hashing.
  def endpoints
    [source, target]
  end
end
|
|
22
|
+
end
|
|
23
|
+
end
|
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "set"
|
|
4
|
+
|
|
5
|
+
module GraphAgent
|
|
6
|
+
module Graph
|
|
7
|
+
# Renders a StateGraph as a Mermaid "graph TD" flowchart.
#
# The graph object must expose +nodes+ (name => node with an +action+),
# +edges+ (objects with +source+/+target+), +branches+
# (source => { name => branch with +path+/+path_map+ }) and +waiting_edges+
# ([sources, target] pairs).
class MermaidVisualizer
  START_LABEL = "START"
  END_LABEL = "END"
  CONDITION_PREFIX = "cond_"

  # Generate Mermaid diagram for a StateGraph.
  # @param options [Hash] +show_node_names: true+ labels nodes with their
  #   graph name instead of a name derived from the node's action
  # @return [String]
  def self.render(state_graph, options = {})
    new(state_graph, options).render
  end

  def initialize(state_graph, options = {})
    @graph = state_graph
    @options = options
  end

  # @return [String] the Mermaid source, one statement per line
  def render
    lines = ["graph TD"]
    lines << _style_definitions
    lines << ""

    # Render entry point (START)
    lines << _node_definition(START.to_s, START_LABEL, :start)

    # Render all graph nodes
    @graph.nodes.each do |name, node|
      lines << _node_definition(name, _node_label(name, node), :node)
    end

    # Render exit point (END)
    lines << _node_definition(END_NODE.to_s, END_LABEL, :end)

    lines << ""

    # Render regular edges
    @graph.edges.each do |edge|
      lines << _edge_definition(edge)
    end

    # Render conditional edges (branches)
    @graph.branches.each do |source, branches|
      branches.each_with_index do |(_name, branch), idx|
        cond_id = "#{source}_#{CONDITION_PREFIX}#{idx}"
        lines << _node_definition(cond_id, _condition_label(branch), :condition)

        # Edge from source to condition node
        lines << " #{_safe_id(source)} --> #{_safe_id(cond_id)}"

        # Edges from condition to targets
        if branch.path_map
          _render_path_map_edges(branch, cond_id, lines)
        else
          # Simple condition - render with note
          lines << " #{_safe_id(cond_id)} -.->|condition| #{_safe_id(source)}_next"
          lines << " #{_safe_id(source)}_next[\"?\"]"
        end
      end
    end

    # Render waiting edges (multi-source edges)
    @graph.waiting_edges.each do |(sources, target)|
      sources.each do |source|
        lines << " #{_safe_id(source)} --> #{_safe_id(target)}"
      end
    end

    lines.join("\n")
  end

  private

  # Style classes. A lowercase "end" is a reserved keyword in Mermaid
  # flowcharts, so the END node's class is named "finish" instead.
  def _style_definitions
    <<~STYLES
      classDef start fill:#e1f5e1,stroke:#4caf50,stroke-width:2px
      classDef finish fill:#ffebee,stroke:#f44336,stroke-width:2px
      classDef node fill:#e3f2fd,stroke:#2196f3,stroke-width:2px,rx:5px
      classDef condition fill:#fff9c4,stroke:#ffc107,stroke-width:2px

      class #{_safe_id(START.to_s)} start
      class #{_safe_id(END_NODE.to_s)} finish
    STYLES
  end

  # Conditions render as a rhombus ({...}); everything else as a box ([...]).
  def _node_definition(id, label, type)
    safe_id = _safe_id(id)
    case type
    when :condition
      " #{safe_id}{#{_quote(label)}}"
    else
      " #{safe_id}[#{_quote(label)}]"
    end
  end

  def _edge_definition(edge)
    " #{_safe_id(edge.source)} --> #{_safe_id(edge.target)}"
  end

  # One dotted edge per path_map entry; :default is rendered last.
  def _render_path_map_edges(branch, cond_id, lines)
    branch.path_map.each do |condition, target|
      next if condition == :default

      lines << " #{_safe_id(cond_id)} -.->|#{condition}| #{_safe_id(target)}"
    end
    # Handle default path
    if branch.path_map.key?(:default)
      lines << " #{_safe_id(cond_id)} -.->|default| #{_safe_id(branch.path_map[:default])}"
    end
  end

  def _safe_id(id)
    # Escape special characters for Mermaid
    id.to_s.gsub(/[^a-zA-Z0-9_]/, "_")
  end

  # Wraps +label+ in double quotes for a Mermaid node shape. Embedded double
  # quotes become Mermaid's #quot; entity because Mermaid does not understand
  # backslash escapes (which Ruby's #inspect would produce).
  def _quote(label)
    "\"#{label.to_s.gsub('"', '#quot;')}\""
  end

  def _node_label(name, node)
    # Use node name or try to extract a readable name from action
    if @options[:show_node_names]
      name
    else
      _extract_readable_name(name, node)
    end
  end

  def _extract_readable_name(name, node)
    # Try to get a readable name from the action
    action = node.action
    if action.respond_to?(:name) && !action.name.nil? && !action.name.to_s.empty?
      action.name.to_s.split("::").last
    else
      name
    end
  end

  # Label for a condition node: the path's name (e.g. a Method's name), else
  # the path_map keys other than :default, else "?".
  def _condition_label(branch)
    path = branch.path
    if path.respond_to?(:name) && !path.name.to_s.empty?
      path.name.to_s
    elsif branch.path_map
      # Show the possible conditions
      (branch.path_map.keys - [:default]).join(" / ")
    else
      "?"
    end
  end
end
|
|
153
|
+
end
|
|
154
|
+
end
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module GraphAgent
|
|
4
|
+
module Graph
|
|
5
|
+
# Schema with a single +messages+ field: an Array (default []) whose updates
# are combined using GraphAgent::Reducers.add_messages.
class MessagesState < State::Schema
  def initialize
    super
    reducer = GraphAgent::Reducers.method(:add_messages)
    field(:messages, type: Array, reducer: reducer, default: [])
  end
end
|
|
11
|
+
|
|
12
|
+
# A StateGraph whose schema is a fresh MessagesState instance.
class MessageGraph < StateGraph
  def initialize
    schema = MessagesState.new
    super(schema)
  end
end
|
|
17
|
+
end
|
|
18
|
+
end
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module GraphAgent
|
|
4
|
+
module Graph
|
|
5
|
+
# A named graph node wrapping a callable action, with optional retry and
# cache policies.
class Node
  attr_reader :name, :action, :metadata, :retry_policy, :cache_policy

  # @param name [#to_s] node name (stored as a String)
  # @param action [#call] invoked with (), (state) or (state, config)
  #   depending on its arity
  # @param metadata [Hash, nil] arbitrary metadata (defaults to {})
  # @param retry_policy optional policy responding to +max_attempts+,
  #   +should_retry?(error)+ and +interval_for(attempt)+
  # @param cache_policy stored only; not read by this class
  def initialize(name, action, metadata: nil, retry_policy: nil, cache_policy: nil)
    @name = name.to_s
    @action = action
    @metadata = metadata || {}
    @retry_policy = retry_policy
    @cache_policy = cache_policy
  end

  # Runs the action for +state+, retrying per +retry_policy+ when present.
  # @return whatever the action returns
  # @raise [GraphError] if +action+ is not callable
  def call(state, config = {})
    result = execute_with_retry(state, config)
    normalize_result(result)
  end

  private

  # Retries failed attempts while the policy allows it, sleeping
  # +interval_for(attempt)+ between attempts, then re-raises the last error.
  def execute_with_retry(state, config)
    return invoke_action(state, config) unless @retry_policy

    attempt = 0
    begin
      attempt += 1
      invoke_action(state, config)
    rescue GraphInterrupt, GraphRecursionError
      # Control-flow signals, not failures: never retry them.
      raise
    rescue => e
      raise unless attempt < @retry_policy.max_attempts && @retry_policy.should_retry?(e)

      sleep(@retry_policy.interval_for(attempt))
      retry
    end
  end

  def invoke_action(state, config)
    raise GraphError.new("Node '#{name}' action is not callable") unless @action.respond_to?(:call)

    case action_arity
    when 0 then @action.call
    when 1, -1 then @action.call(state)
    else @action.call(state, config)
    end
  end

  # Arity of the action. Objects that implement #call but not #arity (plain
  # service objects) are measured through their #call method.
  def action_arity
    @action.respond_to?(:arity) ? @action.arity : @action.method(:call).arity
  end

  # Hook for post-processing action results; currently returns +result+
  # unchanged.
  def normalize_result(result)
    result
  end
end
|
|
60
|
+
end
|
|
61
|
+
end
|