shortgraph 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
shortgraph/__init__.py ADDED
@@ -0,0 +1,38 @@
+ # ShortGraph Core Package
+
+ # Core graph primitives (LangGraph-aligned)
+ from shortgraph.graph.state import StateGraph, START, END
+ from shortgraph.pregel.main import Pregel
+ from shortgraph.channels.manager import AgentState
+ from shortgraph.checkpoint.base import BaseCheckpoint, MemoryCheckpoint
+
+ # Paradigms
+ from shortgraph.paradigms.smart_agent import SmartAgent
+
+ # Models
+ from shortgraph.models.base import BaseLLM
+ from shortgraph.models.factory import LLMFactory
+
+ # Tools
+ from shortgraph.tools.base import Tool, tool
+ from shortgraph.prebuilt.tool_node import ToolNode
+
+ # Memory
+ from shortgraph.memory.buffer import WindowBufferMemory
+
+ __all__ = [
+     "StateGraph",
+     "Pregel",
+     "START",
+     "END",
+     "AgentState",
+     "BaseCheckpoint",
+     "MemoryCheckpoint",
+     "SmartAgent",
+     "BaseLLM",
+     "LLMFactory",
+     "Tool",
+     "tool",
+     "ToolNode",
+     "WindowBufferMemory",
+ ]
shortgraph/channels/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .manager import AgentState
+
+ __all__ = ["AgentState"]
shortgraph/channels/manager.py ADDED
@@ -0,0 +1,91 @@
+ from typing import Dict, Any, List, Optional, Callable
+ from dataclasses import dataclass
+ import datetime
+ import copy
+
+ def default_reducer(current: Any, new: Any) -> Any:
+     """Default: overwrite the current value."""
+     return new
+
+ def append_reducer(current: list, new: list) -> list:
+     """Append new items to the current list."""
+     if current is None:
+         current = []
+     if new is None:
+         return current
+     return current + new
+
+ @dataclass
+ class Channel:
+     """
+     Defines how a specific field in the state should be updated.
+     """
+     name: str
+     reducer: Callable[[Any, Any], Any] = default_reducer
+     default: Any = None
+
+ class AgentState:
+     """
+     Enhanced AgentState with reducer support.
+     """
+     def __init__(self, data: Optional[Dict[str, Any]] = None, channels: Optional[Dict[str, Channel]] = None):
+         self._data = data or {}
+         self._channels = channels or {}
+
+         # Ensure default channels exist
+         if "messages" not in self._channels:
+             self._channels["messages"] = Channel("messages", reducer=append_reducer, default=[])
+         if "history" not in self._channels:
+             self._channels["history"] = Channel("history", reducer=append_reducer, default=[])
+
+         # Initialize defaults (copied, so a shared Channel default is never mutated in place)
+         for name, channel in self._channels.items():
+             if name not in self._data:
+                 self._data[name] = copy.deepcopy(channel.default)
+
+     def update(self, updates: Dict[str, Any]):
+         """
+         Apply updates using the defined reducers.
+         """
+         for key, value in updates.items():
+             if key in self._channels:
+                 current_val = self._data.get(key, self._channels[key].default)
+                 self._data[key] = self._channels[key].reducer(current_val, value)
+             else:
+                 # Default behavior for unknown keys: overwrite
+                 self._data[key] = value
+
+     def get(self, key: str, default: Any = None) -> Any:
+         return self._data.get(key, default)
+
+     def __getitem__(self, key: str) -> Any:
+         return self._data[key]
+
+     def __setitem__(self, key: str, value: Any):
+         self.update({key: value})
+
+     def to_dict(self) -> Dict[str, Any]:
+         return copy.deepcopy(self._data)
+
+     @classmethod
+     def from_dict(cls, data: Dict[str, Any]) -> 'AgentState':
+         return cls(data=data)
+
+     # Helper methods for backward compatibility
+     @property
+     def messages(self) -> List[Dict[str, Any]]:
+         return self.get("messages", [])
+
+     @property
+     def history(self) -> List[str]:
+         return self.get("history", [])
+
+     def add_message(self, role: str, content: str):
+         msg = {"role": role, "content": content, "timestamp": datetime.datetime.now().isoformat()}
+         self.update({"messages": [msg]})  # wrapped in a list for append_reducer
+
+     def set_variable(self, key: str, value: Any):
+         self.update({key: value})
+
+     def get_variable(self, key: str, default: Any = None) -> Any:
+         return self.get(key, default)
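
A minimal usage sketch of the AgentState above (illustrative, not part of the packaged files): the messages channel merges updates via append_reducer, while keys without a channel simply overwrite.

    from shortgraph.channels.manager import AgentState

    state = AgentState()
    state.add_message("user", "hello")                                    # appended via append_reducer
    state.update({"messages": [{"role": "assistant", "content": "hi"}]})  # also appended
    state.update({"temperature": 0.2})                                    # unknown key: plain overwrite
    assert len(state.messages) == 2
    assert state.get_variable("temperature") == 0.2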
shortgraph/checkpoint/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .base import BaseCheckpoint, MemoryCheckpoint
+
+ __all__ = ["BaseCheckpoint", "MemoryCheckpoint"]
shortgraph/checkpoint/base.py ADDED
@@ -0,0 +1,115 @@
+ from abc import ABC, abstractmethod
+ from typing import Dict, Any, List, Optional
+ from dataclasses import dataclass, field
+ import json
+ import copy
+
+ @dataclass
+ class Snapshot:
+     state: Dict[str, Any]
+     next_node: Optional[str] = None
+     metadata: Dict[str, Any] = field(default_factory=dict)
+
+ class BaseCheckpoint(ABC):
+     """
+     Abstract base class for saving and loading graph state.
+     """
+
+     @abstractmethod
+     def put(self, thread_id: str, checkpoint_id: str, snapshot: Dict[str, Any]) -> None:
+         """Save a snapshot (state + execution metadata)."""
+         pass
+
+     @abstractmethod
+     def get(self, thread_id: str, checkpoint_id: Optional[str] = None) -> Optional[Dict[str, Any]]:
+         """
+         Retrieve a snapshot (the latest one when checkpoint_id is None).
+         """
+         pass
+
+     @abstractmethod
+     def list_history(self, thread_id: str) -> List[str]:
+         """List all checkpoint ids for a thread, in insertion order."""
+         pass
+
+ class MemoryCheckpoint(BaseCheckpoint):
+     """
+     In-memory implementation.
+     """
+     def __init__(self):
+         # Structure: {thread_id: {checkpoint_id: snapshot_dict}}
+         self.storage: Dict[str, Dict[str, Any]] = {}
+         # Track order: {thread_id: [checkpoint_id_1, checkpoint_id_2]}
+         self.order: Dict[str, list] = {}
+
+     def put(self, thread_id: str, checkpoint_id: str, snapshot: Dict[str, Any]) -> None:
+         if thread_id not in self.storage:
+             self.storage[thread_id] = {}
+             self.order[thread_id] = []
+
+         self.storage[thread_id][checkpoint_id] = copy.deepcopy(snapshot)
+         self.order[thread_id].append(checkpoint_id)
+
+     def get(self, thread_id: str, checkpoint_id: Optional[str] = None) -> Optional[Dict[str, Any]]:
+         if thread_id not in self.storage:
+             return None
+
+         if checkpoint_id is None:
+             # Get latest
+             if not self.order[thread_id]:
+                 return None
+             latest_id = self.order[thread_id][-1]
+             return copy.deepcopy(self.storage[thread_id][latest_id])
+
+         return copy.deepcopy(self.storage[thread_id].get(checkpoint_id))
+
+     def list_history(self, thread_id: str) -> List[str]:
+         if thread_id not in self.order:
+             return []
+         return list(self.order[thread_id])
+
+ class FileCheckpoint(BaseCheckpoint):
+     """
+     Simple file-based checkpointing.
+     """
+     def __init__(self, filepath: str = "checkpoints.json"):
+         self.filepath = filepath
+         self._load()
+
+     def _load(self):
+         try:
+             with open(self.filepath, 'r', encoding='utf-8') as f:
+                 data = json.load(f)
+                 self.storage = data.get("storage", {})
+                 self.order = data.get("order", {})
+         except FileNotFoundError:
+             self.storage = {}
+             self.order = {}
+
+     def _save(self):
+         with open(self.filepath, 'w', encoding='utf-8') as f:
+             json.dump({"storage": self.storage, "order": self.order}, f, indent=2, default=str)
+
+     def put(self, thread_id: str, checkpoint_id: str, snapshot: Dict[str, Any]) -> None:
+         if thread_id not in self.storage:
+             self.storage[thread_id] = {}
+             self.order[thread_id] = []
+
+         self.storage[thread_id][checkpoint_id] = snapshot
+         self.order[thread_id].append(checkpoint_id)
+         self._save()
+
+     def get(self, thread_id: str, checkpoint_id: Optional[str] = None) -> Optional[Dict[str, Any]]:
+         if thread_id not in self.storage:
+             return None
+
+         if checkpoint_id is None:
+             if not self.order[thread_id]:
+                 return None
+             latest_id = self.order[thread_id][-1]
+             return self.storage[thread_id][latest_id]
+
+         return self.storage[thread_id].get(checkpoint_id)
+
+     def list_history(self, thread_id: str) -> List[str]:
+         return self.order.get(thread_id, [])
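
A short illustrative sketch of the checkpointer contract above (not part of the packaged files). The snapshot payloads here are made-up dicts; the real snapshot schema is produced by the Pregel runner in shortgraph/pregel/main.py, which this diff does not include.

    from shortgraph.checkpoint.base import MemoryCheckpoint

    cp = MemoryCheckpoint()
    cp.put("thread-1", "ckpt-1", {"state": {"messages": []}, "next_node": "agent"})
    cp.put("thread-1", "ckpt-2", {"state": {"messages": [{"role": "user", "content": "hi"}]}, "next_node": None})
    assert cp.get("thread-1")["next_node"] is None              # no id given: latest snapshot
    assert cp.get("thread-1", "ckpt-1")["next_node"] == "agent"
    assert cp.list_history("thread-1") == ["ckpt-1", "ckpt-2"]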
shortgraph/graph/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .state import StateGraph, START, END
+
+ __all__ = ["StateGraph", "START", "END"]
shortgraph/graph/state.py ADDED
@@ -0,0 +1,74 @@
+ from typing import Dict, Any, List, Callable, Union, Optional
+ from shortgraph.channels.manager import AgentState
+ from shortgraph.checkpoint.base import BaseCheckpoint
+ from shortgraph.pregel.main import Pregel
+
+ # Special constants
+ START = "__start__"
+ END = "__end__"
+
+ class StateGraph:
+     """
+     A lightweight graph engine for defining agent workflows.
+     This class only defines structure; execution happens in Pregel.
+     """
+     def __init__(self):
+         self.nodes: Dict[str, Callable] = {}
+         self.edges: Dict[str, str] = {}
+         self.conditional_edges: Dict[str, Callable] = {}
+         self.conditional_edge_metadata: Dict[str, Dict[str, str]] = {}  # Store path_map for visualization
+         self.entry_point: str = ""
+         self.node_metadata: Dict[str, Dict[str, Any]] = {}
+
+     def add_node(self, name: str, func: Union[Callable, 'Pregel'], retry: int = 0):
+         """
+         Add a node. Can be a function or another Pregel graph (subgraph).
+         :param retry: Number of times to retry if the node raises an exception.
+         """
+         self.nodes[name] = func
+         self.node_metadata[name] = {"retry": retry}
+
+     def add_edge(self, start_key: str, end_key: str):
+         if start_key == END:
+             raise ValueError("END node cannot have outgoing edges")
+         self.edges[start_key] = end_key
+
+     def add_conditional_edges(
+         self,
+         source_key: str,
+         condition_func: Callable[[AgentState], str],
+         path_map: Optional[Dict[str, str]] = None
+     ):
+         if source_key == END:
+             raise ValueError("END node cannot have outgoing edges")
+
+         def router(state: AgentState) -> str:
+             result = condition_func(state)
+             if path_map:
+                 return path_map.get(result, result)
+             return result
+
+         self.conditional_edges[source_key] = router
+         if path_map:
+             self.conditional_edge_metadata[source_key] = path_map
+
+     def set_entry_point(self, key: str):
+         self.entry_point = key
+
+     def compile(
+         self,
+         checkpointer: Optional[BaseCheckpoint] = None,
+         interrupt_before: Optional[List[str]] = None,
+         interrupt_after: Optional[List[str]] = None
+     ) -> Pregel:
+         if not self.entry_point:
+             raise ValueError("Graph must have an entry point.")
+
+         pregel_app = Pregel(
+             self,
+             checkpointer,
+             interrupt_before=interrupt_before,
+             interrupt_after=interrupt_after
+         )
+
+         return pregel_app
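
A minimal graph-building sketch against the API above (illustrative, not part of the packaged files). Only construction and compilation are shown, since the run side lives in shortgraph/pregel/main.py, which this diff does not include; the node's dict-update return convention is an assumption based on the reducer logic in channels/manager.py.

    from shortgraph.graph.state import StateGraph, END
    from shortgraph.checkpoint.base import MemoryCheckpoint

    def agent_node(state):
        # Assumed convention: nodes return partial state updates merged by channel reducers
        return {"messages": [{"role": "assistant", "content": "done"}]}

    graph = StateGraph()
    graph.add_node("agent", agent_node, retry=1)
    graph.set_entry_point("agent")
    graph.add_edge("agent", END)
    app = graph.compile(checkpointer=MemoryCheckpoint())  # returns a Pregel instance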
shortgraph/memory/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .buffer import WindowBufferMemory
+
+ __all__ = ["WindowBufferMemory"]
shortgraph/memory/buffer.py ADDED
@@ -0,0 +1,58 @@
+ from typing import List, Dict, Any, Optional
+ from abc import ABC, abstractmethod
+
+ class BaseMemory(ABC):
+     """
+     Abstract base class for memory management.
+     """
+     @abstractmethod
+     def load(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+         """Prepare messages for the LLM (e.g., trimming)."""
+         pass
+
+     @abstractmethod
+     def save(self, current_messages: List[Dict[str, Any]], new_messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+         """Return the messages that should be persisted to state."""
+         pass
+
+ class WindowBufferMemory(BaseMemory):
+     """
+     Keeps a sliding window of the most recent k messages.
+     Preserves the system message if present.
+     """
+     def __init__(self, window_size: int = 10):
+         self.window_size = window_size
+
+     def load(self, state_messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+         """
+         Returns the most recent k messages.
+         Always keeps the first message if it is a 'system' message.
+         """
+         if not state_messages:
+             return []
+
+         if len(state_messages) <= self.window_size:
+             return state_messages
+
+         # Keep the leading system message, if any
+         result = []
+         if state_messages[0].get("role") == "system":
+             result.append(state_messages[0])
+
+         # Fill the remaining slots with the most recent messages
+         remaining_slots = self.window_size - len(result)
+         if remaining_slots > 0:
+             # No duplication is possible here: len(state_messages) > window_size,
+             # so this slice never reaches back to the leading system message
+             recent = state_messages[-remaining_slots:]
+             result.extend(recent)
+
+         return result
+
+     def save(self, current_messages: List[Dict[str, Any]], new_messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+         """
+         State updates in ShortGraph are usually handled by reducers (append),
+         so this method simply returns the messages to be appended rather than
+         mutating state itself.
+         """
+         return new_messages
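
An illustrative example of the windowing behavior above (not part of the packaged files): with window_size=3 and a leading system message, load() keeps the system message plus the two most recent messages.

    from shortgraph.memory.buffer import WindowBufferMemory

    mem = WindowBufferMemory(window_size=3)
    msgs = [{"role": "system", "content": "be brief"}] + [
        {"role": "user", "content": f"msg {i}"} for i in range(5)
    ]
    trimmed = mem.load(msgs)
    assert [m["content"] for m in trimmed] == ["be brief", "msg 3", "msg 4"]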
shortgraph/models/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .base import BaseLLM
+ from .factory import LLMFactory
+ from .openai import OpenAILLM
+
+ __all__ = ["BaseLLM", "LLMFactory", "OpenAILLM"]
shortgraph/models/base.py ADDED
@@ -0,0 +1,33 @@
+ from abc import ABC, abstractmethod
+ from typing import List, Dict, Optional
+
+ class BaseLLM(ABC):
+     """
+     Abstract base class for LLMs, keeping the framework provider-neutral.
+     """
+
+     @abstractmethod
+     def generate(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+         """
+         Generate text from a prompt.
+         """
+         pass
+
+     @abstractmethod
+     def chat(self, messages: List[Dict[str, str]], stop: Optional[List[str]] = None) -> str:
+         """
+         Chat completion interface.
+         messages format: [{"role": "user", "content": "hello"}, ...]
+         """
+         pass
+
+ class MockLLM(BaseLLM):
+     """
+     A mock LLM for testing purposes.
+     """
+     def generate(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+         return f"Mock response to: {prompt[:20]}..."
+
+     def chat(self, messages: List[Dict[str, str]], stop: Optional[List[str]] = None) -> str:
+         last_msg = messages[-1]['content']
+         return f"Mock chat response to: {last_msg[:20]}..."
shortgraph/models/factory.py ADDED
@@ -0,0 +1,102 @@
+ from typing import Optional
+ import os
+ from shortgraph.models.base import BaseLLM
+ from shortgraph.models.openai import OpenAILLM
+
+ class LLMFactory:
+     """
+     A simple factory to create LLM instances for various providers.
+     """
+
+     @staticmethod
+     def create_openai(
+         api_key: Optional[str] = None,
+         model: str = "gpt-3.5-turbo"
+     ) -> BaseLLM:
+         """Create a standard OpenAI LLM instance."""
+         return OpenAILLM(api_key=api_key, model=model)
+
+     @staticmethod
+     def create_deepseek(
+         api_key: Optional[str] = None,
+         model: str = "deepseek-chat"
+     ) -> BaseLLM:
+         """
+         Create a DeepSeek LLM instance.
+         Default Base URL: https://api.deepseek.com
+         """
+         return OpenAILLM(
+             api_key=api_key or os.environ.get("DEEPSEEK_API_KEY"),
+             base_url="https://api.deepseek.com",
+             model=model
+         )
+
+     @staticmethod
+     def create_moonshot(
+         api_key: Optional[str] = None,
+         model: str = "moonshot-v1-8k"
+     ) -> BaseLLM:
+         """
+         Create a Moonshot (Kimi) LLM instance.
+         Default Base URL: https://api.moonshot.cn/v1
+         """
+         return OpenAILLM(
+             api_key=api_key or os.environ.get("MOONSHOT_API_KEY"),
+             base_url="https://api.moonshot.cn/v1",
+             model=model
+         )
+
+     @staticmethod
+     def create_zhipu(
+         api_key: Optional[str] = None,
+         model: str = "glm-4"
+     ) -> BaseLLM:
+         """
+         Create a ZhipuAI (ChatGLM) LLM instance.
+         Default Base URL: https://open.bigmodel.cn/api/paas/v4/
+         """
+         return OpenAILLM(
+             api_key=api_key or os.environ.get("ZHIPU_API_KEY"),
+             base_url="https://open.bigmodel.cn/api/paas/v4/",
+             model=model
+         )
+
+     @staticmethod
+     def create_ollama(
+         base_url: str = "http://localhost:11434/v1",
+         model: str = "llama3"
+     ) -> BaseLLM:
+         """
+         Create a local Ollama LLM instance (via OpenAI compatibility).
+         """
+         return OpenAILLM(
+             api_key="ollama",  # an API key is required by the client but ignored by Ollama
+             base_url=base_url,
+             model=model
+         )
+
+     @staticmethod
+     def create_custom(
+         api_key: str,
+         base_url: str,
+         model: str
+     ) -> BaseLLM:
+         """Create a custom OpenAI-compatible LLM instance."""
+         return OpenAILLM(api_key=api_key, base_url=base_url, model=model)
+
+     @staticmethod
+     def create_mock() -> BaseLLM:
+         """Create a mock LLM for testing/demo purposes."""
+         class MockSmartLLM(BaseLLM):
+             def generate(self, prompt: str, stop: Optional[list] = None) -> str:
+                 # Simple logic to simulate weather tool usage
+                 if "Observation:" in prompt:
+                     return "Final Answer: According to the forecast, the weather in Beijing is good today, sunny turning cloudy, a fine day to go out!"
+                 return 'Thought: The user wants the weather in Beijing, so I need to call the get_weather tool.\nAction: get_weather\nAction Input: "Beijing"'
+
+             def chat(self, messages: list, stop: Optional[list] = None) -> str:
+                 # Fall back to generate() for the chat interface
+                 prompt = "\n".join([f"{m['role']}: {m['content']}" for m in messages])
+                 return self.generate(prompt)
+
+         return MockSmartLLM()
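
Every factory method above returns the same OpenAILLM adapter pointed at a different OpenAI-compatible endpoint. A configuration sketch (illustrative, not part of the packaged files; the environment variables are the ones read by the factory code above):

    from shortgraph.models.factory import LLMFactory

    remote = LLMFactory.create_deepseek()             # key taken from DEEPSEEK_API_KEY if not passed
    local = LLMFactory.create_ollama(model="llama3")  # local server; placeholder key "ollama"
    reply = local.chat([{"role": "user", "content": "hello"}])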
shortgraph/models/openai.py ADDED
@@ -0,0 +1,49 @@
+ from typing import List, Dict, Optional
+ import os
+ try:
+     from openai import OpenAI
+ except ImportError:
+     OpenAI = None
+
+ from shortgraph.models.base import BaseLLM
+
+ class OpenAILLM(BaseLLM):
+     """
+     Adapter for OpenAI's GPT models and OpenAI-compatible APIs (DeepSeek, Moonshot, etc.).
+     """
+     def __init__(
+         self,
+         api_key: Optional[str] = None,
+         model: str = "gpt-3.5-turbo",
+         base_url: Optional[str] = None
+     ):
+         if OpenAI is None:
+             raise ImportError("Please install the openai package: `pip install openai`")
+
+         self.client = OpenAI(
+             api_key=api_key or os.environ.get("OPENAI_API_KEY"),
+             base_url=base_url or os.environ.get("OPENAI_BASE_URL")
+         )
+         self.model = model
+
+     def generate(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+         try:
+             response = self.client.chat.completions.create(
+                 model=self.model,
+                 messages=[{"role": "user", "content": prompt}],
+                 stop=stop
+             )
+             return response.choices[0].message.content
+         except Exception as e:
+             return f"Error generating response: {str(e)}"
+
+     def chat(self, messages: List[Dict[str, str]], stop: Optional[List[str]] = None) -> str:
+         try:
+             response = self.client.chat.completions.create(
+                 model=self.model,
+                 messages=messages,
+                 stop=stop
+             )
+             return response.choices[0].message.content
+         except Exception as e:
+             return f"Error chatting: {str(e)}"
shortgraph/paradigms/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from .smart_agent import SmartAgent
+ from .react import ReActAgent
+ from .plan_execute import PlanExecuteAgent
+ from .graph_react import GraphReActAgent
+
+ __all__ = ["SmartAgent", "ReActAgent", "PlanExecuteAgent", "GraphReActAgent"]