graph-agent 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. checksums.yaml +7 -0
  2. data/.github/workflows/ci.yml +50 -0
  3. data/.github/workflows/release.yml +49 -0
  4. data/.gitignore +6 -0
  5. data/.rspec +3 -0
  6. data/.rubocop.yml +126 -0
  7. data/CHANGELOG.md +26 -0
  8. data/CLAUDE.md +128 -0
  9. data/Gemfile +11 -0
  10. data/Gemfile.lock +94 -0
  11. data/LICENSE +21 -0
  12. data/Makefile +114 -0
  13. data/README.md +464 -0
  14. data/Rakefile +15 -0
  15. data/docs/README.md +55 -0
  16. data/docs/api_reference.md +832 -0
  17. data/docs/concepts.md +216 -0
  18. data/docs/edges.md +265 -0
  19. data/docs/error_handling.md +241 -0
  20. data/docs/human_in_the_loop.md +231 -0
  21. data/docs/persistence.md +276 -0
  22. data/docs/quickstart.md +154 -0
  23. data/docs/send_and_command.md +218 -0
  24. data/docs/state.md +181 -0
  25. data/docs/streaming.md +172 -0
  26. data/graph-agent.gemspec +48 -0
  27. data/lib/graph_agent/channels/base_channel.rb +52 -0
  28. data/lib/graph_agent/channels/binary_operator_aggregate.rb +56 -0
  29. data/lib/graph_agent/channels/ephemeral_value.rb +59 -0
  30. data/lib/graph_agent/channels/last_value.rb +49 -0
  31. data/lib/graph_agent/channels/topic.rb +58 -0
  32. data/lib/graph_agent/checkpoint/base_saver.rb +38 -0
  33. data/lib/graph_agent/checkpoint/in_memory_saver.rb +145 -0
  34. data/lib/graph_agent/constants.rb +9 -0
  35. data/lib/graph_agent/errors.rb +41 -0
  36. data/lib/graph_agent/graph/compiled_state_graph.rb +362 -0
  37. data/lib/graph_agent/graph/conditional_edge.rb +57 -0
  38. data/lib/graph_agent/graph/edge.rb +23 -0
  39. data/lib/graph_agent/graph/mermaid_visualizer.rb +154 -0
  40. data/lib/graph_agent/graph/message_graph.rb +18 -0
  41. data/lib/graph_agent/graph/node.rb +61 -0
  42. data/lib/graph_agent/graph/state_graph.rb +197 -0
  43. data/lib/graph_agent/reducers.rb +34 -0
  44. data/lib/graph_agent/state/schema.rb +54 -0
  45. data/lib/graph_agent/types/cache_policy.rb +12 -0
  46. data/lib/graph_agent/types/command.rb +26 -0
  47. data/lib/graph_agent/types/interrupt.rb +28 -0
  48. data/lib/graph_agent/types/retry_policy.rb +42 -0
  49. data/lib/graph_agent/types/send.rb +26 -0
  50. data/lib/graph_agent/types/state_snapshot.rb +28 -0
  51. data/lib/graph_agent/version.rb +5 -0
  52. data/lib/graph_agent.rb +29 -0
  53. metadata +158 -0
@@ -0,0 +1,362 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "securerandom"
4
+
5
module GraphAgent
  module Graph
    # Executable form of a StateGraph.
    #
    # Execution is a loop of "supersteps". All nodes scheduled for a step run
    # against their own copy of the same pre-step state. Their updates are then
    # merged into the state, through the schema field's reducer when one is
    # defined. The edges, branches and Command gotos of the executed nodes decide
    # what runs next.
    #
    # The loop ends when nothing but END_NODE is scheduled. It raises
    # GraphRecursionError when more than `recursion_limit` supersteps would be
    # needed.
    #
    # When a checkpointer is given AND config[:configurable][:thread_id] is set,
    # a checkpoint is written on input, after every step, on interrupt and on exit.
    class CompiledStateGraph # rubocop:disable Metrics/ClassLength
      DEFAULT_RECURSION_LIMIT = 25

      attr_reader :builder, :checkpointer

      # @param builder [Object] graph definition responding to #nodes, #edges,
      #   #branches, #waiting_edges and #schema (presumably a StateGraph)
      # @param checkpointer [Object, nil] saver responding to #get_tuple and #put
      # @param interrupt_before [Array, String, Symbol, nil] node names to pause
      #   before ("*" matches every node)
      # @param interrupt_after [Array, String, Symbol, nil] node names to pause
      #   after ("*" matches every node)
      # @param debug [Boolean] stored in @debug; not read elsewhere in this class
      def initialize(builder:, checkpointer: nil, interrupt_before: [], interrupt_after: [], debug: false)
        @builder = builder
        @checkpointer = checkpointer
        @interrupt_before = _normalize_interrupt(interrupt_before)
        @interrupt_after = _normalize_interrupt(interrupt_after)
        @debug = debug
      end

      # Runs the graph to completion.
      #
      # @return [Hash, nil] the state carried by the last :values event (the final state)
      # @raise [GraphRecursionError, GraphInterrupt, NodeExecutionError]
      def invoke(input, config: {}, recursion_limit: DEFAULT_RECURSION_LIMIT)
        last_state = nil

        _run_pregel(input, config: config, recursion_limit: recursion_limit) do |event|
          last_state = event[:state] if event[:type] == :values
        end

        last_state
      end

      # Streams execution events. Without a block, returns an Enumerator.
      #
      # stream_mode :values yields a shallow copy of the state after every step.
      # The final state is yielded once more on exit. :updates yields
      # {node_name => updates} for every step. :debug yields every raw event hash.
      # Any other mode yields nothing.
      def stream(input, config: {}, recursion_limit: DEFAULT_RECURSION_LIMIT, stream_mode: :values, &block)
        unless block
          return enum_for(:stream, input, config: config, recursion_limit: recursion_limit, stream_mode: stream_mode)
        end

        _run_pregel(input, config: config, recursion_limit: recursion_limit) do |event|
          _emit_stream_event(event, stream_mode, &block)
        end
      end

      # Returns the latest checkpoint for +config+ as a StateSnapshot.
      # Returns nil when there is no checkpointer or no checkpoint.
      # The values are a deep copy, so callers cannot alter the stored checkpoint.
      def get_state(config)
        return nil unless @checkpointer

        tuple = @checkpointer.get_tuple(config)
        return nil unless tuple

        StateSnapshot.new(
          values: _deep_dup(tuple.checkpoint[:channel_values] || {}),
          config: tuple.config,
          metadata: tuple.metadata || {},
          parent_config: tuple.parent_config,
          next_nodes: tuple.checkpoint[:next_nodes] || []
        )
      end

      # Applies +values+ (through reducers) on top of the latest checkpoint and
      # stores the result as a new checkpoint with source :update.
      #
      # +as_node+ is accepted but not used here; the new checkpoint always has
      # empty next_nodes.
      #
      # @return [Object, nil] whatever the checkpointer's #put returns, or nil
      #   when there is no checkpointer or no existing checkpoint
      def update_state(config, values, as_node: nil)
        return nil unless @checkpointer

        tuple = @checkpointer.get_tuple(config)
        return nil unless tuple

        # Work on a deep copy. Applying updates in place would silently rewrite
        # the channel values of the checkpoint that was just loaded.
        current_state = _deep_dup(tuple.checkpoint[:channel_values] || {})
        new_state = _apply_updates(current_state, values)

        checkpoint = _build_checkpoint(new_state, [])
        metadata = { source: :update, step: (tuple.metadata || {})[:step].to_i + 1, writes: values }

        @checkpointer.put(tuple.config, checkpoint, metadata, {})
      end

      # Returns {nodes: [...], edges: [[source, target], ...]}.
      # Only plain edges are listed; branches and waiting edges are omitted.
      def get_graph
        { nodes: @builder.nodes.keys, edges: @builder.edges.to_a.map { |e| [e.source, e.target] } }
      end

      private

      # Core execution loop.
      #
      # Each superstep emits { type: :updates, step:, updates: } followed by
      # { type: :values, step:, state: }. A final :values event is emitted on exit.
      def _run_pregel(input, config:, recursion_limit:, &event_handler)
        state = _initialize_state(input, config)
        step = 0
        # NOTE(review): entry nodes are always resolved from START, even when the
        # state was restored from a checkpoint. A resumed thread therefore
        # re-runs from the beginning, not from the checkpoint's next_nodes.
        current_nodes = _resolve_entry_nodes(state, config)

        _save_checkpoint(config, state, step, :input, current_nodes)

        while current_nodes.any? { |n| !_is_terminal?(n) }
          _check_recursion_limit!(step, recursion_limit)
          runnable_nodes = current_nodes.reject { |n| _is_terminal?(n) }
          _check_interrupts_before!(runnable_nodes, config, state, step)

          state, next_nodes = _execute_superstep(runnable_nodes, state, config, step, &event_handler)

          _check_interrupts_after!(runnable_nodes, config, state, step)

          current_nodes = next_nodes.uniq
          step += 1
          _save_checkpoint(config, state, step, :loop, current_nodes)
        end

        _save_checkpoint(config, state, step, :exit, [])
        # If at least one superstep ran, this repeats the state that step already emitted.
        event_handler.call({ type: :values, step: step, state: state })
      end

      # Runs one superstep and returns [new_state, next_node_names].
      def _execute_superstep(runnable_nodes, state, config, step, &event_handler)
        step_updates = {}
        next_nodes_from_commands = []
        sends = []

        runnable_nodes.each do |node_name|
          # Every node gets its own deep copy of the pre-step state. An in-place
          # mutation by one node is then invisible to its siblings and to the
          # live state. `state` is not modified until all nodes have run.
          result = _execute_node(node_name, _deep_dup(state), config)
          _collect_results(result, node_name, step_updates, next_nodes_from_commands, sends)
        end

        step_updates.each_value { |updates| state = _apply_updates(state, updates) }

        event_handler.call({ type: :updates, step: step, updates: step_updates.transform_keys(&:to_s) })
        event_handler.call({ type: :values, step: step, state: state })

        next_nodes = _resolve_next_nodes_for_step(runnable_nodes, state, config)
        next_nodes.concat(next_nodes_from_commands)

        state = _execute_sends(sends, state, config)

        [state, next_nodes]
      end

      # Calls the named node.
      #
      # The lookup happens outside the rescue so that a missing node surfaces as
      # InvalidGraphError. Only failures raised by the node itself are wrapped
      # in NodeExecutionError. Interrupts and recursion errors pass through untouched.
      def _execute_node(node_name, state_snapshot, config)
        node = @builder.nodes[node_name]
        raise InvalidGraphError.new("Node '#{node_name}' not found") unless node

        begin
          node.call(state_snapshot, config)
        rescue GraphInterrupt, GraphRecursionError
          raise
        rescue => e
          raise NodeExecutionError.new(node_name, e)
        end
      end

      # Splits a node's result into state updates, goto targets and queued Sends.
      def _collect_results(result, node_name, step_updates, next_nodes_from_commands, sends)
        node_updates, node_commands, node_sends = _process_result(result)

        step_updates[node_name] = node_updates if node_updates && !node_updates.empty?

        node_commands.each { |cmd| _collect_command(cmd, next_nodes_from_commands, sends) }

        sends.concat(node_sends)
      end

      # Routes a Command's goto targets. Send objects are queued; anything else
      # is scheduled as a node name.
      #
      # The Command's update is deliberately NOT recorded here: _process_result
      # has already merged it into the node's updates. Recording it a second
      # time made reducer fields (e.g. list appends) apply the same write twice.
      def _collect_command(cmd, next_nodes_from_commands, sends)
        cmd.goto.each do |target|
          if target.is_a?(GraphAgent::Send)
            sends << target
          else
            next_nodes_from_commands << target.to_s
          end
        end
      end

      # Runs each Send's target node on a private copy of +state+.
      #
      # The Send's Hash arg, if any, is applied to that copy. The node's updates
      # are then merged into +state+. Unknown targets are skipped. Commands and
      # Sends returned by a Send target are ignored; only its updates are kept.
      def _execute_sends(sends, state, config)
        sends.each do |send_obj|
          target = send_obj.node.to_s
          next unless @builder.nodes.key?(target)

          send_state = _deep_dup(state)
          send_state = _apply_updates(send_state, send_obj.arg) if send_obj.arg.is_a?(Hash)
          node_updates, = _process_result(_execute_node(target, send_state, config))
          state = _apply_updates(state, node_updates)
        end
        state
      end

      def _check_recursion_limit!(step, limit)
        return unless step >= limit

        raise GraphRecursionError.new("Recursion limit of #{limit} reached without hitting END node")
      end

      # Saves an :interrupt checkpoint (next_nodes = [node]) and raises before
      # the first runnable node configured for interrupt_before.
      def _check_interrupts_before!(runnable_nodes, config, state, step)
        runnable_nodes.each do |node_name|
          next unless _should_interrupt_before?(node_name)

          _save_checkpoint(config, state, step, :interrupt, [node_name])
          raise GraphInterrupt.new([Interrupt.new("Interrupted before '#{node_name}'")])
        end
      end

      # Saves an :interrupt checkpoint and raises after the first executed node
      # configured for interrupt_after.
      # NOTE(review): next_nodes is saved as [], so the pending targets are not recorded.
      def _check_interrupts_after!(runnable_nodes, config, state, step)
        runnable_nodes.each do |node_name|
          next unless _should_interrupt_after?(node_name)

          _save_checkpoint(config, state, step, :interrupt, [])
          raise GraphInterrupt.new([Interrupt.new("Interrupted after '#{node_name}'")])
        end
      end

      def _emit_stream_event(event, stream_mode)
        case stream_mode
        when :values
          yield event[:state].dup if event[:type] == :values
        when :updates
          yield event[:updates] if event[:type] == :updates && event[:updates]
        when :debug
          yield event
        end
      end

      # Normalizes a node result into [updates_hash, commands, sends].
      # A Command's Hash update is merged into updates_hash.
      def _process_result(result)
        updates = {}
        commands = []
        sends = []

        case result
        when Hash then updates = result
        when Command then commands << result
                          updates = result.update if result.update.is_a?(Hash)
        when Send then sends << result
        when Array then _process_array_result(result, updates, commands, sends)
        end

        [updates, commands, sends]
      end

      def _process_array_result(result, updates, commands, sends)
        result.each do |item|
          case item
          when Hash then updates.merge!(item)
          when Command then commands << item
                            updates.merge!(item.update) if item.update.is_a?(Hash)
          when Send then sends << item
          end
        end
      end

      # Starting state. The first choice is the checkpoint for the thread (with
      # +input+ applied). Otherwise it is the schema's initial state with +input+
      # applied, or +input+ with symbolized keys when there is no State::Schema.
      def _initialize_state(input, config)
        state = _restore_from_checkpoint(input, config)
        return state if state

        if @builder.schema.is_a?(State::Schema)
          state = @builder.schema.initial_state
          _apply_updates(state, input) if input.is_a?(Hash)
          state
        elsif input.is_a?(Hash)
          input.transform_keys(&:to_sym)
        else
          {}
        end
      end

      # Returns a deep copy of the checkpointed channel values, with +input+
      # applied when it is a Hash. Returns nil when there is nothing to restore.
      # Deep-copying keeps reducers from touching data owned by the checkpointer.
      def _restore_from_checkpoint(input, config)
        return nil unless @checkpointer && config.dig(:configurable, :thread_id)

        tuple = @checkpointer.get_tuple(config)
        return nil unless tuple

        state = _deep_dup(tuple.checkpoint[:channel_values] || {})
        input.is_a?(Hash) ? _apply_updates(state, input) : state
      end

      # Merges +updates+ into +state+ IN PLACE and returns +state+.
      # Keys are symbolized. A schema field with a reducer combines old and new
      # values; every other key is overwritten.
      def _apply_updates(state, updates)
        return state unless updates.is_a?(Hash) && state.is_a?(Hash)

        updates.each do |key, value|
          key = key.to_sym
          field = @builder.schema.is_a?(State::Schema) ? @builder.schema.fields[key] : nil

          if field&.reducer
            state[key] = field.reducer.call(state[key], value)
          else
            state[key] = value
          end
        end

        state
      end

      # NOTE(review): branch results are converted with #to_s, so a Send returned
      # by a START branch becomes a node name string.
      def _resolve_entry_nodes(state, config)
        nodes = @builder.edges.select { |e| e.source == START.to_s }.map(&:target)

        if @builder.branches.key?(START.to_s)
          @builder.branches[START.to_s].each_value do |branch|
            nodes.concat(Array(branch.resolve(state, config)).map(&:to_s))
          end
        end

        nodes.uniq
      end

      def _resolve_next_nodes_for_step(executed_nodes, state, config)
        next_nodes = []

        executed_nodes.each do |node_name|
          @builder.edges.each { |edge| next_nodes << edge.target if edge.source == node_name }
          _resolve_branches(node_name, state, config, next_nodes)
        end

        _resolve_waiting_edges(executed_nodes, next_nodes)
        next_nodes.uniq
      end

      def _resolve_branches(node_name, state, config, next_nodes)
        return unless @builder.branches.key?(node_name)

        @builder.branches[node_name].each_value do |branch|
          next_nodes.concat(Array(branch.resolve(state, config)).map(&:to_s))
        end
      end

      # Schedules a waiting (join) edge's target when ALL of its sources ran in
      # this same superstep. Sources that finish in different steps never
      # trigger the target.
      def _resolve_waiting_edges(executed_nodes, next_nodes)
        @builder.waiting_edges.each do |sources, target|
          next_nodes << target if sources.all? { |s| executed_nodes.include?(s) }
        end
      end

      def _is_terminal?(node_name)
        node_name == END_NODE.to_s
      end

      def _should_interrupt_before?(node_name)
        @interrupt_before.include?("*") || @interrupt_before.include?(node_name)
      end

      def _should_interrupt_after?(node_name)
        @interrupt_after.include?("*") || @interrupt_after.include?(node_name)
      end

      def _normalize_interrupt(value)
        case value
        when Array then value.map(&:to_s)
        when String, Symbol then [value.to_s]
        when nil then []
        else Array(value).map(&:to_s)
        end
      end

      def _save_checkpoint(config, state, step, source, next_nodes)
        return unless @checkpointer && config.dig(:configurable, :thread_id)

        checkpoint = _build_checkpoint(_deep_dup(state), next_nodes)
        metadata = { source: source, step: step, parents: {} }
        @checkpointer.put(config, checkpoint, metadata, {})
      end

      def _build_checkpoint(channel_values, next_nodes)
        { id: SecureRandom.uuid, channel_values: channel_values, channel_versions: {},
          versions_seen: {}, next_nodes: next_nodes }
      end

      # Recursively copies Hashes, Arrays and Strings; other objects are #dup'ed.
      def _deep_dup(obj)
        case obj
        when Hash then obj.transform_values { |v| _deep_dup(v) }
        when Array then obj.map { |v| _deep_dup(v) }
        when String then obj.dup
        # Immutable values and classes/modules are shared as-is. Module#dup would
        # create a distinct anonymous copy and break identity and is_a? checks.
        when Numeric, Symbol, true, false, nil, Module then obj
        else obj.respond_to?(:dup) ? obj.dup : obj
        end
      end
    end
  end
end
@@ -0,0 +1,57 @@
1
+ # frozen_string_literal: true
2
+
3
module GraphAgent
  module Graph
    # A branch leaving +source+ whose targets are computed at run time.
    #
    # +path+ is a callable invoked with (), (state) or (state, config),
    # depending on its arity. Without a +path_map+, its result is returned
    # as-is. With a +path_map+, each label is translated through the map,
    # falling back to the map's :default entry.
    #
    # A single Send is wrapped in an Array. An Array made up only of Send
    # objects is returned unmapped.
    class ConditionalEdge
      attr_reader :source, :path, :path_map

      def initialize(source, path, path_map: nil)
        @source = source.to_s
        @path = path
        @path_map = path_map
      end

      # Evaluates the path for +state+ and returns the resolved target(s).
      #
      # @raise [GraphError] if the path is not callable
      # @raise [InvalidGraphError] if a label is missing from path_map and there is no :default
      def resolve(state, config = {})
        result = invoke_path(state, config)
        _resolve_result(result)
      end

      private

      def invoke_path(state, config)
        raise GraphError.new("Conditional edge path is not callable") unless @path.respond_to?(:call)

        case _path_arity
        when 0 then @path.call
        when 1, -1 then @path.call(state)
        else @path.call(state, config)
        end
      end

      # Procs and Methods report #arity themselves. Other callables (objects
      # that merely define #call) usually have no #arity, so measure their
      # #call method instead.
      def _path_arity
        return @path.arity if @path.respond_to?(:arity)

        @path.method(:call).arity
      end

      def _resolve_result(result)
        case result
        when GraphAgent::Send then [result]
        when Array then _resolve_array(result)
        else @path_map ? _map_result(result) : result
        end
      end

      def _resolve_array(result)
        return result if result.all? { |r| r.is_a?(GraphAgent::Send) }
        return result.map { |r| _map_result(r) } if @path_map

        result
      end

      def _map_result(result)
        mapped = @path_map[result]
        mapped = @path_map[:default] if mapped.nil? && @path_map.key?(:default)
        raise InvalidGraphError.new("Unknown path result: #{result.inspect}") if mapped.nil?

        mapped
      end
    end
  end
end
@@ -0,0 +1,23 @@
1
+ # frozen_string_literal: true
2
+
3
module GraphAgent
  module Graph
    # A directed, unconditional edge between two nodes.
    #
    # Both endpoints are stored as Strings. Two edges are equal, and hash
    # alike, when their endpoints match, so edges de-duplicate correctly in
    # Sets and can serve as Hash keys.
    class Edge
      attr_reader :source, :target

      def initialize(source, target)
        @source = source.to_s
        @target = target.to_s
      end

      def ==(other)
        return false unless other.is_a?(Edge)

        [source, target] == [other.source, other.target]
      end
      alias_method :eql?, :==

      def hash
        [source, target].hash
      end
    end
  end
end
@@ -0,0 +1,154 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "set"
4
+
5
module GraphAgent
  module Graph
    # Renders a StateGraph as Mermaid flowchart ("graph TD") source text.
    #
    # Reads #nodes, #edges, #branches and #waiting_edges from the graph. Each
    # conditional branch is drawn as a rhombus node with dotted edges to its
    # path_map targets. When a branch has no path_map, the edge points to a
    # "?" placeholder instead.
    class MermaidVisualizer
      START_LABEL = "START"
      END_LABEL = "END"
      CONDITION_PREFIX = "cond_"

      # Generate Mermaid diagram for a StateGraph.
      #
      # @param state_graph [Object] graph exposing nodes/edges/branches/waiting_edges
      # @param options [Hash] `show_node_names: true` labels nodes by their names
      # @return [String] Mermaid source
      def self.render(state_graph, options = {})
        new(state_graph, options).render
      end

      def initialize(state_graph, options = {})
        @graph = state_graph
        @options = options
      end

      def render
        lines = ["graph TD"]
        lines << _style_definitions
        lines << ""

        # Render entry point (START)
        lines << _node_definition(START.to_s, START_LABEL, :start)

        # Render all graph nodes
        @graph.nodes.each do |name, node|
          lines << _node_definition(name, _node_label(name, node), :node)
        end

        # Render exit point (END)
        lines << _node_definition(END_NODE.to_s, END_LABEL, :end)

        lines << ""

        # Render regular edges
        @graph.edges.each do |edge|
          lines << _edge_definition(edge)
        end

        # Render conditional edges (branches); condition nodes are numbered per source
        @graph.branches.each do |source, branches|
          branches.each_with_index do |(_name, branch), idx|
            cond_id = "#{source}_#{CONDITION_PREFIX}#{idx}"
            lines << _node_definition(cond_id, _condition_label(branch), :condition)

            # Edge from source to condition node
            lines << " #{_safe_id(source)} --> #{_safe_id(cond_id)}"

            # Edges from condition to targets
            if branch.path_map
              _render_path_map_edges(branch, cond_id, lines)
            else
              # Targets are unknown without a path_map: point at a "?" placeholder
              lines << " #{_safe_id(cond_id)} -.->|condition| #{_safe_id(source)}_next"
              lines << " #{_safe_id(source)}_next[\"?\"]"
            end
          end
        end

        # Render waiting edges (multi-source edges)
        @graph.waiting_edges.each do |(sources, target)|
          sources.each do |source|
            lines << " #{_safe_id(source)} --> #{_safe_id(target)}"
          end
        end

        lines.join("\n")
      end

      private

      # Class definitions, plus class assignments for the START/END markers.
      #
      # The end marker's class is named "terminal". A lowercase `end` token is
      # a Mermaid flowchart keyword (it closes subgraphs), so `classDef end ...`
      # and `class <id> end` break diagram parsing.
      def _style_definitions
        <<~STYLES
          classDef start fill:#e1f5e1,stroke:#4caf50,stroke-width:2px
          classDef terminal fill:#ffebee,stroke:#f44336,stroke-width:2px
          classDef node fill:#e3f2fd,stroke:#2196f3,stroke-width:2px,rx:5px
          classDef condition fill:#fff9c4,stroke:#ffc107,stroke-width:2px

          class #{_safe_id(START.to_s)} start
          class #{_safe_id(END_NODE.to_s)} terminal
        STYLES
      end

      # Condition nodes use rhombus syntax {..}; every other type uses [..].
      # The label is stringified first, so Symbols render as text rather than ":sym".
      def _node_definition(id, label, type)
        safe_id = _safe_id(id)
        case type
        when :condition
          " #{safe_id}{#{label.to_s.inspect}}"
        else
          " #{safe_id}[#{label.to_s.inspect}]"
        end
      end

      def _edge_definition(edge)
        " #{_safe_id(edge.source)} --> #{_safe_id(edge.target)}"
      end

      # One dotted edge per path_map entry. The :default entry is emitted last.
      def _render_path_map_edges(branch, cond_id, lines)
        branch.path_map.each do |condition, target|
          next if condition == :default

          lines << " #{_safe_id(cond_id)} -.->|#{condition}| #{_safe_id(target)}"
        end
        # Handle default path
        if branch.path_map.key?(:default)
          lines << " #{_safe_id(cond_id)} -.->|default| #{_safe_id(branch.path_map[:default])}"
        end
      end

      def _safe_id(id)
        # Replace every character outside [a-zA-Z0-9_] so the id is a valid Mermaid identifier
        id.to_s.gsub(/[^a-zA-Z0-9_]/, "_")
      end

      def _node_label(name, node)
        # Use node name or try to extract a readable name from action
        if @options[:show_node_names]
          name
        else
          _extract_readable_name(name, node)
        end
      end

      def _extract_readable_name(name, node)
        # Try to get a readable name from the action
        action = node.action
        if action.respond_to?(:name) && !action.name.nil? && !action.name.to_s.empty?
          action.name.to_s.split("::").last
        else
          name
        end
      end

      # Label for a condition node: the path's #name when it has a non-empty
      # one (stringified, since Method#name returns a Symbol), otherwise the
      # non-default path_map keys, otherwise "?".
      def _condition_label(branch)
        path = branch.path
        path_name = path.respond_to?(:name) ? path.name.to_s : ""
        if !path_name.empty?
          path_name
        elsif branch.path_map
          # Show the possible conditions
          conditions = branch.path_map.keys - [:default]
          conditions.join(" / ")
        else
          "?"
        end
      end
    end
  end
end
@@ -0,0 +1,18 @@
1
+ # frozen_string_literal: true
2
+
3
module GraphAgent
  module Graph
    # State schema with a single `messages` field. The field is an Array
    # combined by GraphAgent::Reducers.add_messages and declared with default [].
    class MessagesState < State::Schema
      def initialize
        super()
        add_messages = GraphAgent::Reducers.method(:add_messages)
        field(:messages, type: Array, reducer: add_messages, default: [])
      end
    end

    # A StateGraph whose schema is a fresh MessagesState instance.
    class MessageGraph < StateGraph
      def initialize
        schema = MessagesState.new
        super(schema)
      end
    end
  end
end
@@ -0,0 +1,61 @@
1
+ # frozen_string_literal: true
2
+
3
module GraphAgent
  module Graph
    # A named graph node wrapping a callable action, with an optional retry policy.
    class Node
      attr_reader :name, :action, :metadata, :retry_policy, :cache_policy

      # @param name [#to_s] node name, stored as a String
      # @param action [#call] node body. It is called with (), (state) or
      #   (state, config), depending on its arity.
      # @param metadata [Hash, nil] arbitrary metadata (stored as {} when nil)
      # @param retry_policy [Object, nil] responds to #max_attempts,
      #   #should_retry?(error) and #interval_for(attempt)
      # @param cache_policy [Object, nil] stored only; not read by this class
      def initialize(name, action, metadata: nil, retry_policy: nil, cache_policy: nil)
        @name = name.to_s
        @action = action
        @metadata = metadata || {}
        @retry_policy = retry_policy
        @cache_policy = cache_policy
      end

      # Runs the action, retrying per the retry policy when one is set.
      # @return [Object] the action's result, unchanged
      def call(state, config = {})
        result = execute_with_retry(state, config)
        normalize_result(result)
      end

      private

      # With a retry policy, a failed attempt is retried after sleeping
      # interval_for(attempt) seconds, while attempt < max_attempts and
      # should_retry?(error) holds. Otherwise the error is re-raised.
      def execute_with_retry(state, config)
        return invoke_action(state, config) unless @retry_policy

        attempt = 0
        begin
          attempt += 1
          invoke_action(state, config)
        rescue => e
          raise unless attempt < @retry_policy.max_attempts && @retry_policy.should_retry?(e)

          sleep(@retry_policy.interval_for(attempt))
          retry
        end
      end

      def invoke_action(state, config)
        raise GraphError.new("Node '#{name}' action is not callable") unless @action.respond_to?(:call)

        case action_arity
        when 0
          @action.call
        when 1, -1
          @action.call(state)
        else
          @action.call(state, config)
        end
      end

      # Procs and Methods report #arity themselves. Other callables (objects
      # that merely define #call) usually have no #arity, so measure their
      # #call method instead of raising NoMethodError.
      def action_arity
        return @action.arity if @action.respond_to?(:arity)

        @action.method(:call).arity
      end

      def normalize_result(result)
        result
      end
    end
  end
end