shortgraph 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- shortgraph/__init__.py +38 -0
- shortgraph/channels/__init__.py +3 -0
- shortgraph/channels/manager.py +92 -0
- shortgraph/checkpoint/__init__.py +3 -0
- shortgraph/checkpoint/base.py +115 -0
- shortgraph/graph/__init__.py +3 -0
- shortgraph/graph/state.py +75 -0
- shortgraph/memory/__init__.py +3 -0
- shortgraph/memory/buffer.py +60 -0
- shortgraph/models/__init__.py +5 -0
- shortgraph/models/base.py +33 -0
- shortgraph/models/factory.py +102 -0
- shortgraph/models/openai.py +49 -0
- shortgraph/paradigms/__init__.py +6 -0
- shortgraph/paradigms/graph_react.py +151 -0
- shortgraph/paradigms/plan_execute.py +146 -0
- shortgraph/paradigms/react.py +90 -0
- shortgraph/paradigms/smart_agent.py +151 -0
- shortgraph/permissions/__init__.py +3 -0
- shortgraph/permissions/guard.py +36 -0
- shortgraph/prebuilt/__init__.py +3 -0
- shortgraph/prebuilt/tool_node.py +58 -0
- shortgraph/pregel/__init__.py +3 -0
- shortgraph/pregel/algo.py +34 -0
- shortgraph/pregel/loop.py +135 -0
- shortgraph/pregel/main.py +53 -0
- shortgraph/tools/__init__.py +3 -0
- shortgraph/tools/base.py +44 -0
- shortgraph/utils/__init__.py +1 -0
- shortgraph-0.1.0.dist-info/METADATA +242 -0
- shortgraph-0.1.0.dist-info/RECORD +34 -0
- shortgraph-0.1.0.dist-info/WHEEL +5 -0
- shortgraph-0.1.0.dist-info/licenses/LICENSE +21 -0
- shortgraph-0.1.0.dist-info/top_level.txt +1 -0
shortgraph/__init__.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
# ShortGraph Core Package
|
|
2
|
+
|
|
3
|
+
# Core Graph Dimensions (LangGraph-aligned)
|
|
4
|
+
from shortgraph.graph.state import StateGraph, START, END
|
|
5
|
+
from shortgraph.pregel.main import Pregel
|
|
6
|
+
from shortgraph.channels.manager import AgentState
|
|
7
|
+
from shortgraph.checkpoint.base import BaseCheckpoint, MemoryCheckpoint
|
|
8
|
+
|
|
9
|
+
# Paradigms
|
|
10
|
+
from shortgraph.paradigms.smart_agent import SmartAgent
|
|
11
|
+
|
|
12
|
+
# Models
|
|
13
|
+
from shortgraph.models.base import BaseLLM
|
|
14
|
+
from shortgraph.models.factory import LLMFactory
|
|
15
|
+
|
|
16
|
+
# Tools
|
|
17
|
+
from shortgraph.tools.base import Tool, tool
|
|
18
|
+
from shortgraph.prebuilt.tool_node import ToolNode
|
|
19
|
+
|
|
20
|
+
# Memory
|
|
21
|
+
from shortgraph.memory.buffer import WindowBufferMemory
|
|
22
|
+
|
|
23
|
+
__all__ = [
|
|
24
|
+
"StateGraph",
|
|
25
|
+
"Pregel",
|
|
26
|
+
"START",
|
|
27
|
+
"END",
|
|
28
|
+
"AgentState",
|
|
29
|
+
"BaseCheckpoint",
|
|
30
|
+
"MemoryCheckpoint",
|
|
31
|
+
"SmartAgent",
|
|
32
|
+
"BaseLLM",
|
|
33
|
+
"LLMFactory",
|
|
34
|
+
"Tool",
|
|
35
|
+
"tool",
|
|
36
|
+
"ToolNode",
|
|
37
|
+
"WindowBufferMemory",
|
|
38
|
+
]
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
from typing import Dict, Any, List, Optional, Callable, Type
|
|
2
|
+
from dataclasses import dataclass, field, asdict
|
|
3
|
+
import datetime
|
|
4
|
+
import operator
|
|
5
|
+
import copy
|
|
6
|
+
|
|
7
|
+
def default_reducer(current: Any, new: Any) -> Any:
    """Overwrite reducer: ignore ``current`` and hand back ``new`` untouched."""
    replacement = new
    return replacement
|
|
10
|
+
|
|
11
|
+
def append_reducer(current: list, new: list) -> list:
    """
    Concatenate ``new`` onto ``current``.

    A ``None`` ``current`` is treated as an empty list. A ``None`` ``new``
    means "nothing to add": the base list is returned as-is (no copy).
    Otherwise a new list is built with ``+``; the inputs are not modified.
    """
    base = [] if current is None else current
    return base if new is None else base + new
|
|
18
|
+
|
|
19
|
+
@dataclass
class Channel:
    """
    Describes one state field and the rule used to merge updates into it.
    """

    # Key of the field inside the state.
    name: str
    # Merge function, invoked as ``reducer(current_value, incoming_value)``;
    # the default simply overwrites.
    reducer: Callable[[Any, Any], Any] = default_reducer
    # Starting value for the field when the state has no entry for it yet.
    default: Any = None
|
|
27
|
+
|
|
28
|
+
class AgentState:
    """
    Key/value agent state whose fields are merged through per-field reducers.

    A key that has a channel is updated with that channel's reducer; any other
    key is overwritten. ``messages`` and ``history`` channels always exist and
    use ``append_reducer`` unless the caller supplies its own definitions.
    """

    def __init__(self, data: Optional[Dict[str, Any]] = None, channels: Optional[Dict[str, "Channel"]] = None):
        """
        :param data: Initial field values. Shallow-copied so that filling in
            channel defaults never adds keys to the caller's dict.
        :param channels: Channel definitions keyed by field name. Copied so the
            built-in channels are not injected into the caller's mapping.
        """
        self._data = dict(data) if data else {}
        self._channels = dict(channels) if channels else {}

        # Ensure default channels exist
        for builtin in ("messages", "history"):
            if builtin not in self._channels:
                self._channels[builtin] = Channel(builtin, reducer=append_reducer, default=[])

        # Initialize defaults
        for name, channel in self._channels.items():
            if name not in self._data:
                # Deep-copy: a mutable default (e.g. ``[]``) must not be shared
                # between states built from the same channel definitions.
                self._data[name] = copy.deepcopy(channel.default)

    def update(self, updates: Dict[str, Any]):
        """
        Apply updates using the defined reducers.

        :param updates: Mapping of field name to incoming value. Fields with a
            channel are merged via its reducer; other fields are overwritten.
        """
        for key, value in updates.items():
            channel = self._channels.get(key)
            if channel is None:
                # Default behavior for unknown keys: overwrite
                self._data[key] = value
                continue
            if key in self._data:
                current_val = self._data[key]
            else:
                # Copy so an in-place reducer cannot alter the channel default.
                current_val = copy.deepcopy(channel.default)
            self._data[key] = channel.reducer(current_val, value)

    def get(self, key: str, default: Any = None) -> Any:
        """Return the stored value for ``key``, or ``default`` when absent."""
        return self._data.get(key, default)

    def __getitem__(self, key: str) -> Any:
        """Return the stored value for ``key``; raises ``KeyError`` if absent."""
        return self._data[key]

    def __setitem__(self, key: str, value: Any):
        """Route item assignment through :meth:`update`, so reducers apply."""
        self.update({key: value})

    def to_dict(self) -> Dict[str, Any]:
        """Return a deep copy of the underlying data."""
        return copy.deepcopy(self._data)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'AgentState':
        """Build a state from ``data`` using only the built-in channels."""
        return cls(data=data)

    # Helper methods for backward compatibility
    @property
    def messages(self) -> List[Dict[str, Any]]:
        """Current value of the ``messages`` field (``[]`` if missing)."""
        return self.get("messages", [])

    @property
    def history(self) -> List[str]:
        """Current value of the ``history`` field (``[]`` if missing)."""
        return self.get("history", [])

    def add_message(self, role: str, content: str):
        """Add one message dict stamped with the local time in ISO format."""
        msg = {"role": role, "content": content, "timestamp": datetime.datetime.now().isoformat()}
        self.update({"messages": [msg]})  # List for append_reducer

    def set_variable(self, key: str, value: Any):
        """Alias for updating a single field through :meth:`update`."""
        self.update({key: value})

    def get_variable(self, key: str, default: Any = None) -> Any:
        """Alias for :meth:`get`."""
        return self.get(key, default)
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Dict, Any, Optional
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
import json
|
|
5
|
+
import copy
|
|
6
|
+
|
|
7
|
+
@dataclass
class Snapshot:
    """Record bundling a state dict with optional execution bookkeeping."""

    # State payload (required).
    state: Dict[str, Any]
    # presumably the node to resume at; None when not set -- TODO confirm against the executor
    next_node: Optional[str] = None
    # Extra key/value information; a fresh dict is created per instance.
    metadata: Dict[str, Any] = field(default_factory=dict)
|
|
12
|
+
|
|
13
|
+
class BaseCheckpoint(ABC):
    """
    Abstract base class for saving and loading graph state.

    Snapshots are grouped by ``thread_id`` and addressed by ``checkpoint_id``.
    """

    @abstractmethod
    def put(self, thread_id: str, checkpoint_id: str, snapshot: Dict[str, Any]) -> None:
        """
        Save a snapshot (state + execution metadata).

        :param thread_id: Identifier of the thread the snapshot belongs to.
        :param checkpoint_id: Identifier of this snapshot within the thread.
        :param snapshot: Snapshot payload to store.
        """
        pass

    @abstractmethod
    def get(self, thread_id: str, checkpoint_id: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """
        Retrieve a snapshot.

        :param thread_id: Thread to read from.
        :param checkpoint_id: Snapshot to fetch; the implementations in this
            module return the most recently stored snapshot when it is None.
        :return: The snapshot, or None when nothing matches.
        """
        pass

    @abstractmethod
    def list_history(self, thread_id: str) -> list:
        """
        List all checkpoints for a thread.

        :return: Checkpoint ids in the order they were stored (a list, matching
            what the concrete implementations in this module return).
        """
        pass
|
|
34
|
+
|
|
35
|
+
class MemoryCheckpoint(BaseCheckpoint):
    """
    Checkpoint store that keeps everything in process memory.

    Snapshots are deep-copied on the way in and on the way out, so callers can
    never alter stored data through a reference they hold.
    """

    def __init__(self):
        # Track order: {thread_id: [checkpoint_id_1, checkpoint_id_2]}
        self.order: Dict[str, list] = {}
        # Structure: {thread_id: {checkpoint_id: snapshot_dict}}
        self.storage: Dict[str, Dict[str, Any]] = {}

    def put(self, thread_id: str, checkpoint_id: str, snapshot: Dict[str, Any]) -> None:
        """Store a private copy of ``snapshot`` and record its id as the newest."""
        if thread_id not in self.storage:
            self.storage[thread_id], self.order[thread_id] = {}, []

        frozen = copy.deepcopy(snapshot)
        self.storage[thread_id][checkpoint_id] = frozen
        self.order[thread_id].append(checkpoint_id)

    def get(self, thread_id: str, checkpoint_id: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """
        Return a copy of the requested snapshot.

        With ``checkpoint_id`` omitted, the most recently stored snapshot of the
        thread is returned. Unknown threads or ids yield None.
        """
        if thread_id not in self.storage:
            return None
        snapshots = self.storage[thread_id]

        if checkpoint_id is not None:
            return copy.deepcopy(snapshots.get(checkpoint_id))

        # Get latest
        recorded_ids = self.order[thread_id]
        return copy.deepcopy(snapshots[recorded_ids[-1]]) if recorded_ids else None

    def list_history(self, thread_id: str) -> list:
        """Return a new list of the thread's checkpoint ids, oldest first."""
        return list(self.order[thread_id]) if thread_id in self.order else []
|
|
70
|
+
|
|
71
|
+
class FileCheckpoint(BaseCheckpoint):
    """
    Simple file-based checkpointing.

    All threads are kept in one JSON document at ``filepath``. The document is
    read once at construction and rewritten in full on every :meth:`put`.
    Like ``MemoryCheckpoint``, snapshots are copied on the way in and out so
    callers cannot change stored data through a reference they still hold.
    """

    def __init__(self, filepath: str = "checkpoints.json"):
        """
        :param filepath: Location of the JSON file; it does not need to exist.
        """
        self.filepath = filepath
        self._load()

    def _load(self):
        """Populate ``storage``/``order`` from disk, or start empty if no file."""
        try:
            with open(self.filepath, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except FileNotFoundError:
            self.storage = {}
            self.order = {}
        else:
            # A malformed file is deliberately not swallowed: treating it as
            # empty would overwrite the existing data on the next put().
            self.storage = data.get("storage", {})
            self.order = data.get("order", {})

    def _save(self):
        """Write the complete store to ``filepath``."""
        with open(self.filepath, 'w', encoding='utf-8') as f:
            # default=str: values json cannot encode are written as str(value),
            # so they come back as plain strings after a reload.
            json.dump({"storage": self.storage, "order": self.order}, f, indent=2, default=str)

    def put(self, thread_id: str, checkpoint_id: str, snapshot: Dict[str, Any]) -> None:
        """Store a private copy of ``snapshot`` and persist the whole store."""
        if thread_id not in self.storage:
            self.storage[thread_id] = {}
            self.order[thread_id] = []

        # Copy so later mutation of the caller's dict cannot leak into the
        # stored (and subsequently saved) checkpoint.
        self.storage[thread_id][checkpoint_id] = copy.deepcopy(snapshot)
        self.order[thread_id].append(checkpoint_id)
        self._save()

    def get(self, thread_id: str, checkpoint_id: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """
        Return a copy of the requested snapshot.

        With ``checkpoint_id`` omitted, the most recently stored snapshot of the
        thread is returned. Unknown threads or ids yield None.
        """
        if thread_id not in self.storage:
            return None

        if checkpoint_id is None:
            if not self.order[thread_id]:
                return None
            latest_id = self.order[thread_id][-1]
            return copy.deepcopy(self.storage[thread_id][latest_id])

        return copy.deepcopy(self.storage[thread_id].get(checkpoint_id))

    def list_history(self, thread_id: str) -> list:
        """Return a new list of the thread's checkpoint ids, oldest first."""
        return list(self.order.get(thread_id, []))
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
from typing import Dict, Any, List, Callable, Union, Optional
|
|
2
|
+
import time
|
|
3
|
+
from shortgraph.channels.manager import AgentState
|
|
4
|
+
from shortgraph.checkpoint.base import BaseCheckpoint
|
|
5
|
+
from shortgraph.pregel.main import Pregel
|
|
6
|
+
|
|
7
|
+
# Reserved node-name sentinels
START, END = "__start__", "__end__"
|
|
10
|
+
|
|
11
|
+
class StateGraph:
    """
    Builder that records the structure of an agent workflow.

    It only collects nodes, edges and the entry point; :meth:`compile` hands
    the finished definition to a ``Pregel`` instance.
    """

    def __init__(self):
        self.entry_point: str = ""
        # node name -> callable (or nested Pregel graph)
        self.nodes: Dict[str, Callable] = {}
        # node name -> per-node settings (only "retry" is written here)
        self.node_metadata: Dict[str, Dict[str, Any]] = {}
        # source node -> its single fixed successor
        self.edges: Dict[str, str] = {}
        # source node -> routing function returning the next node name
        self.conditional_edges: Dict[str, Callable] = {}
        self.conditional_edge_metadata: Dict[str, Dict[str, str]] = {}  # Store path_map for visualization

    @staticmethod
    def _reject_end_as_source(key: str) -> None:
        """Raise ValueError when ``key`` is the END sentinel."""
        if key == END:
            raise ValueError("END node cannot have outgoing edges")

    def add_node(self, name: str, func: Union[Callable, 'Pregel'], retry: int = 0):
        """
        Register a node under ``name``; a later call with the same name replaces it.

        :param func: A function, or another Pregel graph used as a sub-graph.
        :param retry: Number of times to retry if the node raises an exception.
        """
        self.node_metadata[name] = {"retry": retry}
        self.nodes[name] = func

    def add_edge(self, start_key: str, end_key: str):
        """
        Set the fixed successor of ``start_key`` to ``end_key``.

        :raises ValueError: if ``start_key`` is END.
        """
        self._reject_end_as_source(start_key)
        self.edges[start_key] = end_key

    def add_conditional_edges(
        self,
        source_key: str,
        condition_func: Callable[[AgentState], str],
        path_map: Optional[Dict[str, str]] = None
    ):
        """
        Route out of ``source_key`` based on the result of ``condition_func``.

        :param condition_func: Called with the state; returns a routing key.
        :param path_map: Optional translation from routing key to node name.
            Keys missing from the map are used as node names directly.
        :raises ValueError: if ``source_key`` is END.
        """
        self._reject_end_as_source(source_key)

        def router(state: AgentState) -> str:
            outcome = condition_func(state)
            return path_map.get(outcome, outcome) if path_map else outcome

        self.conditional_edges[source_key] = router
        if path_map:
            self.conditional_edge_metadata[source_key] = path_map

    def set_entry_point(self, key: str):
        """Record ``key`` as the entry point of the graph."""
        self.entry_point = key

    def compile(
        self,
        checkpointer: Optional[BaseCheckpoint] = None,
        interrupt_before: Optional[List[str]] = None,
        interrupt_after: Optional[List[str]] = None
    ) -> Pregel:
        """
        Wrap this definition in a ``Pregel`` object.

        :param checkpointer: Passed through to ``Pregel`` unchanged.
        :param interrupt_before: Passed through to ``Pregel`` unchanged.
        :param interrupt_after: Passed through to ``Pregel`` unchanged.
        :raises ValueError: if no entry point has been set.
        """
        if self.entry_point:
            return Pregel(
                self,
                checkpointer,
                interrupt_before=interrupt_before,
                interrupt_after=interrupt_after
            )
        raise ValueError("Graph must have an entry point.")
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
from typing import List, Dict, Any, Optional
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
|
|
4
|
+
class BaseMemory(ABC):
    """
    Interface that every memory strategy has to implement.
    """

    @abstractmethod
    def load(self, state: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Produce the messages to send to the LLM from state (e.g., after trimming)."""

    @abstractmethod
    def save(self, state: Dict[str, Any], new_messages: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Record ``new_messages`` in state."""
|
|
17
|
+
|
|
18
|
+
class WindowBufferMemory(BaseMemory):
    """
    Sliding-window memory: only the newest ``window_size`` messages are kept.
    A leading System message, if present, is always retained.
    """

    def __init__(self, window_size: int = 10):
        self.window_size = window_size

    def load(self, state_messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Trim ``state_messages`` to the configured window.

        An input that already fits is returned unchanged (the same object).
        Otherwise the result is a new list: the first message when its role is
        'system', followed by as many of the newest messages as still fit.
        """
        if not state_messages:
            return []

        limit = self.window_size
        if len(state_messages) <= limit:
            return state_messages

        # A leading system message takes up one slot of the window.
        leading = state_messages[0]
        kept = [leading] if leading.get("role") == "system" else []

        tail_len = limit - len(kept)
        if tail_len > 0:
            kept.extend(state_messages[-tail_len:])
        return kept

    def save(self, current_messages: List[Dict[str, Any]], new_messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Return ``new_messages`` unchanged.

        State updates in ShortGraph normally go through reducers (append), so
        this method only reports what has to be added; ``current_messages`` is
        not consulted.
        """
        return new_messages
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import List, Dict, Any, Optional, Union
|
|
3
|
+
|
|
4
|
+
class BaseLLM(ABC):
    """
    Provider-neutral interface that every LLM adapter implements.
    """

    @abstractmethod
    def generate(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """
        Produce a text completion for a single prompt string.
        """

    @abstractmethod
    def chat(self, messages: List[Dict[str, str]], stop: Optional[List[str]] = None) -> str:
        """
        Produce a reply for a chat transcript.
        messages format: [{"role": "user", "content": "hello"}, ...]
        """
|
|
23
|
+
|
|
24
|
+
class MockLLM(BaseLLM):
    """
    A mock LLM for testing purposes.

    Both methods return a canned string that embeds the first 20 characters of
    the input; ``stop`` is accepted for interface compatibility and ignored.
    """

    def generate(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Return a canned reply quoting the start of ``prompt``."""
        return f"Mock response to: {prompt[:20]}..."

    def chat(self, messages: List[Dict[str, str]], stop: Optional[List[str]] = None) -> str:
        """
        Return a canned reply quoting the start of the last message's content.

        An empty ``messages`` list, or a last message whose content is missing
        or None, is treated as empty content instead of raising.
        """
        last_msg = (messages[-1].get("content") or "") if messages else ""
        return f"Mock chat response to: {last_msg[:20]}..."
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
from typing import Optional, Any
|
|
2
|
+
import os
|
|
3
|
+
from shortgraph.models.base import BaseLLM
|
|
4
|
+
from shortgraph.models.openai import OpenAILLM
|
|
5
|
+
|
|
6
|
+
class LLMFactory:
    """
    Static constructors that build LLM clients for several providers.

    Every real provider is served by ``OpenAILLM``; the constructors differ
    only in base URL, default model and where the API key is looked up.
    """

    @staticmethod
    def _openai_compatible(api_key: Optional[str], env_var: str, base_url: str, model: str) -> BaseLLM:
        """Build an OpenAILLM, taking the key from ``env_var`` when none is given."""
        # NOTE(review): if neither source yields a key, None is forwarded and
        # OpenAILLM then tries OPENAI_API_KEY -- confirm that is intended.
        resolved_key = api_key or os.environ.get(env_var)
        return OpenAILLM(api_key=resolved_key, base_url=base_url, model=model)

    @staticmethod
    def create_openai(
        api_key: Optional[str] = None,
        model: str = "gpt-3.5-turbo"
    ) -> BaseLLM:
        """Build a client for the standard OpenAI service."""
        return OpenAILLM(api_key=api_key, model=model)

    @staticmethod
    def create_deepseek(
        api_key: Optional[str] = None,
        model: str = "deepseek-chat"
    ) -> BaseLLM:
        """
        Build a DeepSeek client (key fallback: DEEPSEEK_API_KEY).
        Default Base URL: https://api.deepseek.com
        """
        return LLMFactory._openai_compatible(
            api_key, "DEEPSEEK_API_KEY", "https://api.deepseek.com", model
        )

    @staticmethod
    def create_moonshot(
        api_key: Optional[str] = None,
        model: str = "moonshot-v1-8k"
    ) -> BaseLLM:
        """
        Build a Moonshot (Kimi) client (key fallback: MOONSHOT_API_KEY).
        Default Base URL: https://api.moonshot.cn/v1
        """
        return LLMFactory._openai_compatible(
            api_key, "MOONSHOT_API_KEY", "https://api.moonshot.cn/v1", model
        )

    @staticmethod
    def create_zhipu(
        api_key: Optional[str] = None,
        model: str = "glm-4"
    ) -> BaseLLM:
        """
        Build a ZhipuAI (ChatGLM) client (key fallback: ZHIPU_API_KEY).
        Default Base URL: https://open.bigmodel.cn/api/paas/v4/
        """
        return LLMFactory._openai_compatible(
            api_key, "ZHIPU_API_KEY", "https://open.bigmodel.cn/api/paas/v4/", model
        )

    @staticmethod
    def create_ollama(
        base_url: str = "http://localhost:11434/v1",
        model: str = "llama3"
    ) -> BaseLLM:
        """
        Build a client for a local Ollama server (via OpenAI compatibility).
        """
        # API key is required but ignored by Ollama
        placeholder_key = "ollama"
        return OpenAILLM(api_key=placeholder_key, base_url=base_url, model=model)

    @staticmethod
    def create_custom(
        api_key: str,
        base_url: str,
        model: str
    ) -> BaseLLM:
        """Build a client for any OpenAI-compatible endpoint."""
        return OpenAILLM(api_key=api_key, base_url=base_url, model=model)

    @staticmethod
    def create_mock() -> BaseLLM:
        """Build a scripted stand-in LLM for testing/demo purposes."""

        class MockSmartLLM(BaseLLM):
            def generate(self, prompt: str, stop: Optional[list] = None) -> str:
                # Simple logic to simulate weather tool usage: once an
                # observation is present in the prompt, give the final answer;
                # before that, ask for the tool call.
                final_answer = "Final Answer: 根据天气预报, 北京今天天气不错, 晴转多云, 适合出行!"
                tool_request = 'Thought: 用户想查询北京的天气, 我需要调用 get_weather 工具.\nAction: get_weather\nAction Input: "北京"'
                return final_answer if "Observation:" in prompt else tool_request

            def chat(self, messages: list, stop: Optional[list] = None) -> str:
                # Fallback to generate for chat interface
                transcript = "\n".join(f"{m['role']}: {m['content']}" for m in messages)
                return self.generate(transcript)

        return MockSmartLLM()
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
from typing import List, Dict, Optional, Any
|
|
2
|
+
import os
|
|
3
|
+
try:
|
|
4
|
+
from openai import OpenAI
|
|
5
|
+
except ImportError:
|
|
6
|
+
OpenAI = None
|
|
7
|
+
|
|
8
|
+
from shortgraph.models.base import BaseLLM
|
|
9
|
+
|
|
10
|
+
class OpenAILLM(BaseLLM):
    """
    Adapter for OpenAI's GPT models and OpenAI-compatible APIs (DeepSeek, Moonshot, etc.).

    Failures while calling the API are not raised: both public methods return
    a string starting with an error prefix instead (existing behavior, kept
    for backward compatibility).
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "gpt-3.5-turbo",
        base_url: Optional[str] = None
    ):
        """
        :param api_key: API key; falls back to the OPENAI_API_KEY environment variable.
        :param model: Model name sent with every request.
        :param base_url: Endpoint override; falls back to OPENAI_BASE_URL.
        :raises ImportError: if the ``openai`` package is not installed.
        """
        if OpenAI is None:
            raise ImportError("Please install openai package: `pip install openai`")

        self.client = OpenAI(
            api_key=api_key or os.environ.get("OPENAI_API_KEY"),
            base_url=base_url or os.environ.get("OPENAI_BASE_URL")
        )
        self.model = model

    def _complete(self, messages: List[Dict[str, str]], stop: Optional[List[str]], error_prefix: str) -> str:
        """Send one chat-completion request and return the first choice's text."""
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                stop=stop
            )
            content = response.choices[0].message.content
        except Exception as e:
            # Best-effort contract: report the failure as text rather than raising.
            return f"{error_prefix}: {str(e)}"
        # The API may return a message without text content; honor the
        # ``-> str`` contract instead of leaking None to callers.
        return content if content is not None else ""

    def generate(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """
        Complete a single prompt by sending it as one user message.

        :return: The model's text, or "Error generating response: ..." on failure.
        """
        return self._complete(
            [{"role": "user", "content": prompt}], stop, "Error generating response"
        )

    def chat(self, messages: List[Dict[str, str]], stop: Optional[List[str]] = None) -> str:
        """
        Complete a chat transcript.

        :param messages: Messages in the [{"role": ..., "content": ...}] format.
        :return: The model's text, or "Error chatting: ..." on failure.
        """
        return self._complete(messages, stop, "Error chatting")
|