graphai-lib 0.0.8__tar.gz → 0.0.9rc2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {graphai_lib-0.0.8 → graphai_lib-0.0.9rc2}/PKG-INFO +1 -1
- {graphai_lib-0.0.8 → graphai_lib-0.0.9rc2}/graphai/graph.py +203 -38
- graphai_lib-0.0.9rc2/graphai/py.typed +0 -0
- {graphai_lib-0.0.8 → graphai_lib-0.0.9rc2}/graphai/utils.py +127 -21
- {graphai_lib-0.0.8 → graphai_lib-0.0.9rc2}/graphai_lib.egg-info/PKG-INFO +1 -1
- {graphai_lib-0.0.8 → graphai_lib-0.0.9rc2}/graphai_lib.egg-info/SOURCES.txt +1 -0
- {graphai_lib-0.0.8 → graphai_lib-0.0.9rc2}/pyproject.toml +4 -1
- {graphai_lib-0.0.8 → graphai_lib-0.0.9rc2}/README.md +0 -0
- {graphai_lib-0.0.8 → graphai_lib-0.0.9rc2}/graphai/__init__.py +0 -0
- {graphai_lib-0.0.8 → graphai_lib-0.0.9rc2}/graphai/callback.py +0 -0
- {graphai_lib-0.0.8 → graphai_lib-0.0.9rc2}/graphai/nodes/__init__.py +0 -0
- {graphai_lib-0.0.8 → graphai_lib-0.0.9rc2}/graphai/nodes/base.py +0 -0
- {graphai_lib-0.0.8 → graphai_lib-0.0.9rc2}/graphai_lib.egg-info/dependency_links.txt +0 -0
- {graphai_lib-0.0.8 → graphai_lib-0.0.9rc2}/graphai_lib.egg-info/requires.txt +0 -0
- {graphai_lib-0.0.8 → graphai_lib-0.0.9rc2}/graphai_lib.egg-info/top_level.txt +0 -0
- {graphai_lib-0.0.8 → graphai_lib-0.0.9rc2}/setup.cfg +0 -0
@@ -1,8 +1,22 @@
|
|
1
|
-
from
|
1
|
+
from __future__ import annotations
|
2
|
+
from typing import Any, Protocol
|
2
3
|
from graphai.callback import Callback
|
3
4
|
from graphai.utils import logger
|
4
5
|
|
5
6
|
|
7
|
+
# to fix mypy error
|
8
|
+
class _HasName(Protocol):
|
9
|
+
name: str
|
10
|
+
|
11
|
+
|
12
|
+
class GraphError(Exception):
    """Root of this module's graph exception hierarchy."""
|
14
|
+
|
15
|
+
|
16
|
+
class GraphCompileError(GraphError):
    """Raised by ``Graph.compile()`` when graph validation fails."""
|
18
|
+
|
19
|
+
|
6
20
|
class NodeProtocol(Protocol):
|
7
21
|
"""Protocol defining the interface of a decorated node."""
|
8
22
|
|
@@ -20,6 +34,26 @@ class NodeProtocol(Protocol):
|
|
20
34
|
) -> dict[str, Any]: ...
|
21
35
|
|
22
36
|
|
37
|
+
def _name_of(x: Any) -> str | None:
|
38
|
+
"""Return the node name if x is a str or has .name, else None."""
|
39
|
+
if x is None:
|
40
|
+
return None
|
41
|
+
if isinstance(x, str):
|
42
|
+
return x
|
43
|
+
name = getattr(x, "name", None)
|
44
|
+
return name if isinstance(name, str) else None
|
45
|
+
|
46
|
+
|
47
|
+
def _require_name(x: Any, kind: str) -> str:
    """Resolve ``x`` to a node name, failing compilation if that is impossible.

    Uses the same resolution rules as ``_name_of``.

    :param x: Edge endpoint (node name or object with a string ``.name``).
    :param kind: Which endpoint is being resolved (e.g. ``"source"``); only
        used in the error message.
    :return: The resolved node name.
    :raises GraphCompileError: if ``x`` cannot be resolved to a name.
    """
    resolved = _name_of(x)
    if resolved is not None:
        return resolved
    raise GraphCompileError(
        f"Edge {kind} must be a node name (str) or object with .name"
    )
|
55
|
+
|
56
|
+
|
23
57
|
class Graph:
|
24
58
|
def __init__(
|
25
59
|
self, max_steps: int = 10, initial_state: dict[str, Any] | None = None
|
@@ -28,7 +62,7 @@ class Graph:
|
|
28
62
|
self.edges: list[Any] = []
|
29
63
|
self.start_node: NodeProtocol | None = None
|
30
64
|
self.end_nodes: list[NodeProtocol] = []
|
31
|
-
self.Callback:
|
65
|
+
self.Callback: type[Callback] = Callback
|
32
66
|
self.max_steps = max_steps
|
33
67
|
self.state = initial_state or {}
|
34
68
|
|
@@ -37,22 +71,22 @@ class Graph:
|
|
37
71
|
"""Get the current graph state."""
|
38
72
|
return self.state
|
39
73
|
|
40
|
-
def set_state(self, state: dict[str, Any]) ->
|
74
|
+
def set_state(self, state: dict[str, Any]) -> Graph:
    """Replace the whole graph state with ``state``.

    The dict is stored by reference, not copied, so later mutations of the
    caller's dict are visible through ``self.state``.

    :param state: The new state mapping.
    :return: ``self``, to allow chaining.
    """
    self.state = state
    return self
|
44
78
|
|
45
|
-
def update_state(self, values: dict[str, Any]) ->
|
79
|
+
def update_state(self, values: dict[str, Any]) -> Graph:
    """Merge ``values`` into the current state in place.

    Keys already present are overwritten by the incoming values.

    :param values: Key/value pairs to merge.
    :return: ``self``, to allow chaining.
    """
    current = self.state
    current.update(values)
    return self
|
49
83
|
|
50
|
-
def reset_state(self) ->
|
84
|
+
def reset_state(self) -> Graph:
    """Discard the graph state by rebinding it to a fresh empty dict.

    The previous dict object is not cleared, so outside references to it
    keep their contents.

    :return: ``self``, to allow chaining.
    """
    fresh = {}
    self.state = fresh
    return self
|
54
88
|
|
55
|
-
def add_node(self, node: NodeProtocol) ->
|
89
|
+
def add_node(self, node: NodeProtocol) -> Graph:
|
56
90
|
if node.name in self.nodes:
|
57
91
|
raise Exception(f"Node with name '{node.name}' already exists.")
|
58
92
|
self.nodes[node.name] = node
|
@@ -68,7 +102,9 @@ class Graph:
|
|
68
102
|
self.end_nodes.append(node)
|
69
103
|
return self
|
70
104
|
|
71
|
-
def add_edge(
|
105
|
+
def add_edge(
|
106
|
+
self, source: NodeProtocol | str, destination: NodeProtocol | str
|
107
|
+
) -> Graph:
|
72
108
|
"""Adds an edge between two nodes that already exist in the graph.
|
73
109
|
|
74
110
|
Args:
|
@@ -89,9 +125,7 @@ class Graph:
|
|
89
125
|
else:
|
90
126
|
source_name = str(source)
|
91
127
|
if source_node is None:
|
92
|
-
raise ValueError(
|
93
|
-
f"Node with name '{source_name}' not found."
|
94
|
-
)
|
128
|
+
raise ValueError(f"Node with name '{source_name}' not found.")
|
95
129
|
# get destination node from graph
|
96
130
|
destination_name: str
|
97
131
|
if isinstance(destination, str):
|
@@ -105,9 +139,7 @@ class Graph:
|
|
105
139
|
else:
|
106
140
|
destination_name = str(destination)
|
107
141
|
if destination_node is None:
|
108
|
-
raise ValueError(
|
109
|
-
f"Node with name '{destination_name}' not found."
|
110
|
-
)
|
142
|
+
raise ValueError(f"Node with name '{destination_name}' not found.")
|
111
143
|
edge = Edge(source_node, destination_node)
|
112
144
|
self.edges.append(edge)
|
113
145
|
return self
|
@@ -117,7 +149,7 @@ class Graph:
|
|
117
149
|
sources: list[NodeProtocol],
|
118
150
|
router: NodeProtocol,
|
119
151
|
destinations: list[NodeProtocol],
|
120
|
-
) ->
|
152
|
+
) -> Graph:
|
121
153
|
if not router.is_router:
|
122
154
|
raise TypeError("A router object must be passed to the router parameter.")
|
123
155
|
[self.add_edge(source, router) for source in sources]
|
@@ -125,26 +157,151 @@ class Graph:
|
|
125
157
|
self.add_edge(router, destination)
|
126
158
|
return self
|
127
159
|
|
128
|
-
def set_start_node(self, node: NodeProtocol) ->
|
160
|
+
def set_start_node(self, node: NodeProtocol) -> Graph:
    """Record ``node`` as the graph's entry point.

    Only the reference is stored; the node is not added to ``self.nodes``.

    :param node: The node execution should begin from.
    :return: ``self``, to allow chaining.
    """
    self.start_node = node
    return self
|
131
163
|
|
132
|
-
def set_end_node(self, node: NodeProtocol) ->
|
164
|
+
def set_end_node(self, node: NodeProtocol) -> Graph:
    """Register ``node`` as an end node.

    ``self.end_node`` is still assigned for backward compatibility. The node
    is also recorded in ``self.end_nodes`` (the list initialised in
    ``__init__`` and appended to by ``add_node``), so both views agree. It
    is appended at most once, compared by identity.

    :param node: The node to mark as an end node.
    :return: ``self``, to allow chaining.
    """
    self.end_node = node
    # Identity check: nodes may define their own __eq__.
    if not any(existing is node for existing in self.end_nodes):
        self.end_nodes.append(node)
    return self
|
135
167
|
|
136
168
|
def compile(self) -> "Graph":
    """Validate the graph and return it.

    Checks, in order:

    - at least one node has been added
    - exactly one start node is flagged (or ``Graph.start_node`` is set),
      and that start node is registered in ``self.nodes``
    - at least one node is flagged as an end node
    - every edge references known nodes
    - every node is reachable from the start node

    Cycles are not checked here.

    Accepted edge storage (``self.edges``):

    - ``None`` (no edges);
    - a dict ``{src: dst | Iterable[dst]}``;
    - an iterable of edge records. Each record is one of:
      a ``(src, dst)`` or ``(src, Iterable[dst])`` pair; a mapping with
      ``source``/``destination`` (or ``src``/``dst``) keys; or an object
      with ``.source``/``.destination`` (or ``.src``/``.dst``) attributes.

    Every endpoint may be a node name or an object with a string ``.name``.

    :return: ``self`` on success, so calls can be chained.
    :raises GraphCompileError: if any check fails.
    """
    nodes = getattr(self, "nodes", None)
    if not isinstance(nodes, dict) or not nodes:
        raise GraphCompileError("No nodes have been added to the graph")

    # --- start node -------------------------------------------------------
    start_name: str | None = None
    # Bind and narrow the attribute for mypy
    start_node: _HasName | None = getattr(self, "start_node", None)
    if start_node is not None:
        start_name = start_node.name
    else:
        starts = [
            name
            for name, n in nodes.items()
            if getattr(n, "is_start", False) or getattr(n, "start", False)
        ]
        if len(starts) > 1:
            raise GraphCompileError(f"Multiple start nodes defined: {starts}")
        if starts:
            start_name = starts[0]

    if not start_name:
        raise GraphCompileError("No start node defined")
    if start_name not in nodes:
        # set_start_node() only stores a reference. Without this check an
        # unregistered start node would surface as "every node unreachable".
        raise GraphCompileError(
            f"Start node '{start_name}' has not been added to the graph"
        )

    # --- end node(s) ------------------------------------------------------
    if not any(
        getattr(n, "is_end", False) or getattr(n, "end", False)
        for n in nodes.values()
    ):
        raise GraphCompileError("No end node defined")

    # --- normalize edges into adjacency {src: {dst, ...}} -----------------
    adj: dict[str, set[str]] = {name: set() for name in nodes}

    def _add_edge(src: str, dst: str) -> None:
        # Both endpoints must be registered node *names*.
        if src not in nodes:
            raise GraphCompileError(f"Edge references unknown source node: {src}")
        if dst not in nodes:
            raise GraphCompileError(
                f"Edge from {src} references unknown node(s): ['{dst}']"
            )
        adj[src].add(dst)

    def _add_edges(src: str, rhs: Any) -> None:
        # rhs is either one destination (name or node-like) or an iterable of
        # them. A str is treated as a single name, never iterated per char.
        if isinstance(rhs, str) or getattr(rhs, "name", None):
            _add_edge(src, _require_name(rhs, "destination"))
            return
        try:
            dst_iter = iter(rhs)
        except TypeError:
            raise GraphCompileError(
                "Edge destination must be a node name, an object with .name, "
                "or an iterable of those"
            ) from None
        for d in dst_iter:
            # Always pass the resolved *name*, never the raw endpoint object.
            _add_edge(src, _require_name(d, "destination"))

    raw_edges = getattr(self, "edges", None)
    if isinstance(raw_edges, dict):
        for raw_src, dsts in raw_edges.items():
            _add_edges(_require_name(raw_src, "source"), dsts)
    elif raw_edges is not None:
        # generic iterable of "edge records"
        try:
            records = iter(raw_edges)
        except TypeError:
            raise GraphCompileError(
                "Internal edge map has unsupported type"
            ) from None

        for item in records:
            if isinstance(item, (tuple, list)) and len(item) == 2:
                # (src, dst) OR (src, Iterable[dst])
                raw_src, rhs = item
                _add_edges(_require_name(raw_src, "source"), rhs)
            elif isinstance(item, dict):
                # {"source": ..., "destination": ...} or {"src": ..., "dst": ...}
                src = _require_name(item.get("source", item.get("src")), "source")
                dst = _require_name(
                    item.get("destination", item.get("dst")), "destination"
                )
                _add_edge(src, dst)
            elif hasattr(item, "source") or hasattr(item, "src"):
                # Edge-like object (presumably including the Edge records
                # appended by add_edge; confirm against Edge's attributes).
                src = _require_name(
                    getattr(item, "source", getattr(item, "src", None)), "source"
                )
                dst = _require_name(
                    getattr(item, "destination", getattr(item, "dst", None)),
                    "destination",
                )
                _add_edge(src, dst)
            else:
                raise GraphCompileError(
                    "Edges must be dict[str, Iterable[str]] or an iterable of (src, dst), "
                    "(src, Iterable[dst]), mapping{'source'/'destination'}, or objects with .source/.destination"
                )

    # --- reachability from start (iterative depth-first walk) -------------
    seen: set[str] = set()
    stack = [start_name]
    while stack:
        cur = stack.pop()
        if cur in seen:
            continue
        seen.add(cur)
        # cur is always a registered name: start was checked above and every
        # destination was validated by _add_edge.
        stack.extend(adj[cur])

    unreachable = sorted(set(nodes) - seen)
    if unreachable:
        raise GraphCompileError(f"Unreachable nodes: {unreachable}")

    return self
|
148
305
|
|
149
306
|
def _validate_output(self, output: dict[str, Any], node_name: str):
|
150
307
|
if not isinstance(output, dict):
|
@@ -219,7 +376,7 @@ class Graph:
|
|
219
376
|
as the default callback when no callback is passed to the `execute` method.
|
220
377
|
|
221
378
|
:param callback_class: The callback class to use as the default callback.
|
222
|
-
:type callback_class:
|
379
|
+
:type callback_class: type[Callback]
|
223
380
|
"""
|
224
381
|
self.Callback = callback_class
|
225
382
|
return self
|
@@ -249,20 +406,24 @@ class Graph:
|
|
249
406
|
f"No outgoing edge found for current node '{current_node.name}'."
|
250
407
|
)
|
251
408
|
|
252
|
-
def visualize(self):
|
409
|
+
def visualize(self, *, save_path: str | None = None):
|
410
|
+
"""Render the current graph. If matplotlib is not installed,
|
411
|
+
raise a helpful error telling users to install the viz extra.
|
412
|
+
Optionally save to a file via `save_path`.
|
413
|
+
"""
|
253
414
|
try:
|
254
|
-
import
|
255
|
-
except ImportError:
|
415
|
+
import matplotlib.pyplot as plt
|
416
|
+
except ImportError as e:
|
256
417
|
raise ImportError(
|
257
|
-
"
|
258
|
-
)
|
418
|
+
"Graph visualization requires matplotlib. Install it with: `pip install matplotlib`"
|
419
|
+
) from e
|
259
420
|
|
260
421
|
try:
|
261
|
-
import
|
262
|
-
except ImportError:
|
422
|
+
import networkx as nx
|
423
|
+
except ImportError as e:
|
263
424
|
raise ImportError(
|
264
|
-
"
|
265
|
-
)
|
425
|
+
"NetworkX is required for visualization. Please install it with `pip install networkx`."
|
426
|
+
) from e
|
266
427
|
|
267
428
|
G: Any = nx.DiGraph()
|
268
429
|
|
@@ -328,8 +489,12 @@ class Graph:
|
|
328
489
|
arrowsize=20,
|
329
490
|
)
|
330
491
|
|
331
|
-
|
332
|
-
|
492
|
+
if save_path:
|
493
|
+
plt.savefig(save_path, bbox_inches="tight")
|
494
|
+
else:
|
495
|
+
plt.axis("off")
|
496
|
+
plt.show()
|
497
|
+
plt.close()
|
333
498
|
|
334
499
|
|
335
500
|
class Edge:
|
File without changes
|
@@ -1,9 +1,18 @@
|
|
1
|
+
from enum import Enum
|
1
2
|
import inspect
|
2
3
|
import os
|
3
|
-
|
4
|
+
import sys
|
5
|
+
from typing import Any, Callable, Union, get_args, get_origin
|
4
6
|
from pydantic import BaseModel, Field
|
7
|
+
from pydantic_core import PydanticUndefined
|
5
8
|
import logging
|
6
|
-
|
9
|
+
|
10
|
+
|
11
|
+
# enum.StrEnum only exists from Python 3.11, and this package still supports
# 3.10, so a minimal local equivalent is defined here.
class StrEnum(str, Enum):
    """Minimal backport of ``enum.StrEnum`` for Python < 3.11.

    Members are ``str`` instances, and ``str(member)`` returns the member's
    value.
    """

    def __str__(self) -> str:
        return self.value
|
7
16
|
|
8
17
|
|
9
18
|
class ColoredFormatter(logging.Formatter):
|
@@ -119,6 +128,9 @@ class Parameter(BaseModel):
|
|
119
128
|
}
|
120
129
|
}
|
121
130
|
|
131
|
+
class OpenAIAPI(StrEnum):
    """OpenAI API flavours understood by ``FunctionSchema.to_openai``."""

    COMPLETIONS = "completions"
    RESPONSES = "responses"
|
122
134
|
|
123
135
|
class FunctionSchema(BaseModel):
|
124
136
|
"""Class that consumes a function and can return a schema required by
|
@@ -167,29 +179,95 @@ class FunctionSchema(BaseModel):
|
|
167
179
|
)
|
168
180
|
|
169
181
|
@classmethod
def from_pydantic(cls, model: type[BaseModel]) -> "FunctionSchema":
    """Create a FunctionSchema from a Pydantic model class.

    Each model field becomes a ``Parameter``. The generated signature lists
    the fields as parameters and always returns ``dict``.

    :param model: The Pydantic model class to convert
    :type model: type[BaseModel]
    :return: FunctionSchema instance
    :rtype: FunctionSchema
    """
    # Local import: only needed to recognise PEP 604 ``X | None`` unions.
    import types

    # Extract model metadata
    name = model.__name__
    description = model.__doc__ or ""

    # Build parameters list
    parameters = []
    signature_parts = []

    for field_name, field_info in model.model_fields.items():
        # Use pydantic's resolved annotation. Unlike model.__annotations__, it
        # also covers inherited fields and is not a plain string when the
        # model's module uses postponed evaluation of annotations.
        field_type = field_info.annotation

        # Determine the type name - unwrap Optional[T] / T | None to T
        type_name = str(field_type)
        origin = get_origin(field_type)
        if origin is Union or origin is types.UnionType:
            non_none_types = [
                arg for arg in get_args(field_type) if arg is not type(None)
            ]
            if non_none_types:
                actual_type = non_none_types[0]
                if hasattr(actual_type, "__name__"):
                    type_name = actual_type.__name__
                else:
                    type_name = str(actual_type)
        elif field_type and hasattr(field_type, "__name__"):
            type_name = field_type.__name__

        # In Pydantic v2, PydanticUndefined means "no default"
        is_required = (
            field_info.default is PydanticUndefined
            and field_info.default_factory is None
        )

        # Resolve the default to display (an explicit None default is shown
        # as "no default", as before)
        if (
            field_info.default is not PydanticUndefined
            and field_info.default is not None
        ):
            default_value = field_info.default
        elif field_info.default_factory is not None:
            try:
                # Try calling with no arguments (common case)
                default_value = field_info.default_factory()  # type: ignore[call-arg]
            except TypeError:
                # Factory needs arguments; just indicate there is one
                default_value = "<factory>"
        else:
            default_value = inspect.Parameter.empty

        parameters.append(
            Parameter(
                name=field_name,
                description=field_info.description,
                type=type_name,
                default=default_value,
                required=is_required,
            )
        )

        # Identity check: `!=` would invoke the default object's own __eq__,
        # which may be expensive, return a non-bool, or raise.
        if default_value is not inspect.Parameter.empty:
            signature_parts.append(f"{field_name}: {type_name} = {repr(default_value)}")
        else:
            signature_parts.append(f"{field_name}: {type_name}")

    signature = f"({', '.join(signature_parts)}) -> dict"

    return cls.model_construct(
        name=name,
        description=description,
        signature=signature,
        output="dict",
        parameters=parameters,
    )
|
191
269
|
|
192
|
-
def to_dict(self) -> dict:
|
270
|
+
def to_dict(self) -> dict[str, Any]:
|
193
271
|
schema_dict = {
|
194
272
|
"type": "function",
|
195
273
|
"function": {
|
@@ -210,14 +288,42 @@ class FunctionSchema(BaseModel):
|
|
210
288
|
}
|
211
289
|
return schema_dict
|
212
290
|
|
213
|
-
def to_openai(self) -> dict:
|
214
|
-
|
291
|
+
def to_openai(self, api: OpenAIAPI=OpenAIAPI.COMPLETIONS) -> dict[str, Any]:
|
292
|
+
"""Convert the function schema into OpenAI-compatible formats. Supports
|
293
|
+
both completions and responses APIs.
|
294
|
+
|
295
|
+
:param api: The API to convert to.
|
296
|
+
:type api: OpenAIAPI
|
297
|
+
:return: The function schema in OpenAI-compatible format.
|
298
|
+
:rtype: dict
|
299
|
+
"""
|
300
|
+
if api == "completions":
|
301
|
+
return self.to_dict()
|
302
|
+
elif api == "responses":
|
303
|
+
return {
|
304
|
+
"type": "function",
|
305
|
+
"name": self.name,
|
306
|
+
"description": self.description,
|
307
|
+
"parameters": {
|
308
|
+
"type": "object",
|
309
|
+
"properties": {
|
310
|
+
k: v
|
311
|
+
for param in self.parameters
|
312
|
+
for k, v in param.to_dict().items()
|
313
|
+
},
|
314
|
+
"required": [
|
315
|
+
param.name for param in self.parameters if param.required
|
316
|
+
],
|
317
|
+
},
|
318
|
+
}
|
319
|
+
else:
|
320
|
+
raise ValueError(f"Unrecognized OpenAI API: {api}")
|
215
321
|
|
216
322
|
|
217
323
|
# Format names for which get_schemas() returns FunctionSchema.to_dict() output.
DEFAULT = {"default", "openai", "ollama", "litellm"}
|
218
324
|
|
219
325
|
|
220
|
-
def get_schemas(callables: list[Callable], format: str = "default") -> list[dict]:
|
326
|
+
def get_schemas(callables: list[Callable], format: str = "default") -> list[dict[str, Any]]:
|
221
327
|
if format in DEFAULT:
|
222
328
|
return [
|
223
329
|
FunctionSchema.from_callable(callable).to_dict() for callable in callables
|
@@ -1,6 +1,6 @@
|
|
1
1
|
[project]
|
2
2
|
name = "graphai-lib"
|
3
|
-
version = "0.0.
|
3
|
+
version = "0.0.9rc2"
|
4
4
|
description = "Not an AI framework"
|
5
5
|
readme = "README.md"
|
6
6
|
requires-python = ">=3.10,<3.14"
|
@@ -29,6 +29,9 @@ build-backend = "setuptools.build_meta"
|
|
29
29
|
[tool.setuptools]
|
30
30
|
packages = ["graphai", "graphai.nodes"]
|
31
31
|
|
32
|
+
[tool.setuptools.package-data]
|
33
|
+
graphai = ["py.typed"]
|
34
|
+
|
32
35
|
[tool.mypy]
|
33
36
|
python_version = "3.10"
|
34
37
|
warn_return_any = true
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|