graphai-lib 0.0.8__tar.gz → 0.0.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: graphai-lib
3
- Version: 0.0.8
3
+ Version: 0.0.9
4
4
  Summary: Not an AI framework
5
5
  Requires-Python: <3.14,>=3.10
6
6
  Description-Content-Type: text/markdown
@@ -1,13 +1,53 @@
1
1
  import asyncio
2
+ from dataclasses import dataclass
3
+ from enum import Enum
2
4
  from pydantic import Field
3
5
  from typing import Any
4
6
  from collections.abc import AsyncIterator
7
+ import warnings
5
8
 
6
9
 
7
10
  log_stream = True
8
11
 
9
12
 
13
+ class StrEnum(Enum):
14
+ def __str__(self) -> str:
15
+ return str(self.value)
16
+
17
+ class GraphEventType(StrEnum):
18
+ START = "start"
19
+ END = "end"
20
+ START_NODE = "start_node"
21
+ END_NODE = "end_node"
22
+ CALLBACK = "callback"
23
+
24
+ @dataclass
25
+ class GraphEvent:
26
+ """A graph event emitted for specific graph events such as start node or end node,
27
+ and used by the callback to emit user-defined events.
28
+
29
+ :param type: The type of event, can be start, end, start_node, end_node, or callback.
30
+ :type type: GraphEventType
31
+ :param identifier: The identifier of the event, this is set typically by a callback
32
+ handler and can be used to distinguish between different events. For example, a
33
+ conversation/session ID could be used.
34
+ :type identifier: str
35
+ :param token: The token associated with the event, such as LLM streamed output.
36
+ :type token: str | None
37
+ :param params: The parameters associated with the event, such as tool call parameters
38
+ or event metadata.
39
+ :type params: dict[str, Any] | None
40
+ """
41
+ type: GraphEventType
42
+ identifier: str
43
+ token: str | None = None
44
+ params: dict[str, Any] | None = None
45
+
46
+
10
47
  class Callback:
48
+ """The original callback handler class. Outputs a stream of structured text
49
+ tokens. It is recommended to use the newer `EventCallback` handler instead.
50
+ """
11
51
  identifier: str = Field(
12
52
  default="graphai",
13
53
  description=(
@@ -67,6 +107,11 @@ class Callback:
67
107
  special_token_format: str = "<{identifier}:{token}:{params}>",
68
108
  token_format: str = "{token}",
69
109
  ):
110
+ warnings.warn(
111
+ "The `Callback` class is deprecated and will be removed in " +
112
+ "v0.1.0. Use the `EventCallback` class instead.",
113
+ DeprecationWarning
114
+ )
70
115
  self.identifier = identifier
71
116
  self.special_token_format = special_token_format
72
117
  self.token_format = token_format
@@ -198,3 +243,100 @@ class Callback:
198
243
  return self.special_token_format.format(
199
244
  identifier=identifier, token=name, params=params_str
200
245
  )
246
+
247
+
248
+ class EventCallback(Callback):
249
+ """The event callback handler class. Outputs a stream of structured text
250
+ tokens wrapped in `GraphEvent` objects; use it in place of the older `Callback`.
251
+ """
252
+ def __init__(
253
+ self,
254
+ identifier: str = "graphai",
255
+ special_token_format: str | None = None,
256
+ token_format: str | None = None,
257
+ ):
258
+ warnings.warn(
259
+ "The `special_token_format` and `token_format` parameters are " +
260
+ "deprecated and will be removed in v0.1.0.",
261
+ DeprecationWarning
262
+ )
263
+ if special_token_format is None:
264
+ special_token_format = "<{identifier}:{token}:{params}>"
265
+ if token_format is None:
266
+ token_format = "{token}"
267
+ super().__init__(identifier, special_token_format, token_format)
268
+ self.events: list[GraphEvent] = []
269
+
270
+ def __call__(self, token: str, node_name: str | None = None):
271
+ if self._done:
272
+ raise RuntimeError("Cannot add tokens to a closed stream")
273
+ self._check_node_name(node_name=node_name)
274
+ event = GraphEvent(type=GraphEventType.CALLBACK, identifier=self.identifier, token=token, params=None)
275
+ # otherwise we just assume node is correct and send token
276
+ self.queue.put_nowait(event)
277
+
278
+ async def acall(self, token: str, node_name: str | None = None):
279
+ # TODO JB: do we need to have `node_name` param?
280
+ if self._done:
281
+ raise RuntimeError("Cannot add tokens to a closed stream")
282
+ self._check_node_name(node_name=node_name)
283
+ event = GraphEvent(type=GraphEventType.CALLBACK, identifier=self.identifier, token=token, params=None)
284
+ # otherwise we just assume node is correct and send token
285
+ self.queue.put_nowait(event)
286
+
287
+ async def aiter(self) -> AsyncIterator[GraphEvent]: # type: ignore[override]
288
+ """Used by receiver to get the tokens from the stream queue. Creates
289
+ a generator that yields tokens from the queue until the END token is
290
+ received.
291
+ """
292
+ while True: # Keep going until we see the END token
293
+ try:
294
+ if self._done and self.queue.empty():
295
+ break
296
+ token = await self.queue.get()
297
+ yield token
298
+ self.queue.task_done()
299
+ if token.type == GraphEventType.END:
300
+ break
301
+ except asyncio.CancelledError:
302
+ break
303
+ self._done = True # Mark as done after processing all tokens
304
+
305
+ async def start_node(self, node_name: str, active: bool = True):
306
+ """Starts a new node and emits the start token."""
307
+ if self._done:
308
+ raise RuntimeError("Cannot start node on a closed stream")
309
+ self.current_node_name = node_name
310
+ if self.first_token:
311
+ self.first_token = False
312
+ self.active = active
313
+ if self.active:
314
+ token = GraphEvent(type=GraphEventType.START_NODE, identifier=self.identifier, token=self.current_node_name, params=None)
315
+ self.queue.put_nowait(token)
316
+
317
+ async def end_node(self, node_name: str):
318
+ """Emits the end token for the current node."""
319
+ if self._done:
320
+ raise RuntimeError("Cannot end node on a closed stream")
321
+ # self.current_node_name = node_name
322
+ if self.active:
323
+ token = GraphEvent(type=GraphEventType.END_NODE, identifier=self.identifier, token=self.current_node_name, params=None)
324
+ self.queue.put_nowait(token)
325
+
326
+ async def close(self):
327
+ """Close the stream and prevent further tokens from being added.
328
+ This will send an END token and set the done flag to True.
329
+ """
330
+ if self._done:
331
+ return
332
+ end_token = GraphEvent(type=GraphEventType.END, identifier=self.identifier)
333
+ self._done = True # Set done before putting the end token
334
+ self.queue.put_nowait(end_token)
335
+ # Don't wait for queue.join() as it can cause deadlock
336
+ # The stream will close when aiter processes the END token
337
+
338
+ async def _build_special_token(
339
+ self, name: str, params: dict[str, Any] | None = None
340
+ ):
341
+ raise NotImplementedError("This method is not implemented for the `EventCallback` class.")
342
+
@@ -1,8 +1,24 @@
1
- from typing import Any, Protocol, Type
1
+ from __future__ import annotations
2
+ import asyncio
3
+ from typing import Any, Iterable, Protocol
4
+ from graphlib import TopologicalSorter, CycleError
2
5
  from graphai.callback import Callback
3
6
  from graphai.utils import logger
4
7
 
5
8
 
9
+ # to fix mypy error
10
+ class _HasName(Protocol):
11
+ name: str
12
+
13
+
14
+ class GraphError(Exception):
15
+ pass
16
+
17
+
18
+ class GraphCompileError(GraphError):
19
+ pass
20
+
21
+
6
22
  class NodeProtocol(Protocol):
7
23
  """Protocol defining the interface of a decorated node."""
8
24
 
@@ -20,6 +36,26 @@ class NodeProtocol(Protocol):
20
36
  ) -> dict[str, Any]: ...
21
37
 
22
38
 
39
+ def _name_of(x: Any) -> str | None:
40
+ """Return the node name if x is a str or has .name, else None."""
41
+ if x is None:
42
+ return None
43
+ if isinstance(x, str):
44
+ return x
45
+ name = getattr(x, "name", None)
46
+ return name if isinstance(name, str) else None
47
+
48
+
49
+ def _require_name(x: Any, kind: str) -> str:
50
+ """Like _name_of, but raises a helpful compile error when missing."""
51
+ s = _name_of(x)
52
+ if s is None:
53
+ raise GraphCompileError(
54
+ f"Edge {kind} must be a node name (str) or object with .name"
55
+ )
56
+ return s
57
+
58
+
23
59
  class Graph:
24
60
  def __init__(
25
61
  self, max_steps: int = 10, initial_state: dict[str, Any] | None = None
@@ -28,7 +64,7 @@ class Graph:
28
64
  self.edges: list[Any] = []
29
65
  self.start_node: NodeProtocol | None = None
30
66
  self.end_nodes: list[NodeProtocol] = []
31
- self.Callback: Type[Callback] = Callback
67
+ self.Callback: type[Callback] = Callback
32
68
  self.max_steps = max_steps
33
69
  self.state = initial_state or {}
34
70
 
@@ -37,22 +73,22 @@ class Graph:
37
73
  """Get the current graph state."""
38
74
  return self.state
39
75
 
40
- def set_state(self, state: dict[str, Any]) -> "Graph":
76
+ def set_state(self, state: dict[str, Any]) -> Graph:
41
77
  """Set the graph state."""
42
78
  self.state = state
43
79
  return self
44
80
 
45
- def update_state(self, values: dict[str, Any]) -> "Graph":
81
+ def update_state(self, values: dict[str, Any]) -> Graph:
46
82
  """Update the graph state with new values."""
47
83
  self.state.update(values)
48
84
  return self
49
85
 
50
- def reset_state(self) -> "Graph":
86
+ def reset_state(self) -> Graph:
51
87
  """Reset the graph state to an empty dict."""
52
88
  self.state = {}
53
89
  return self
54
90
 
55
- def add_node(self, node: NodeProtocol) -> "Graph":
91
+ def add_node(self, node: NodeProtocol) -> Graph:
56
92
  if node.name in self.nodes:
57
93
  raise Exception(f"Node with name '{node.name}' already exists.")
58
94
  self.nodes[node.name] = node
@@ -68,7 +104,9 @@ class Graph:
68
104
  self.end_nodes.append(node)
69
105
  return self
70
106
 
71
- def add_edge(self, source: NodeProtocol | str, destination: NodeProtocol | str) -> "Graph":
107
+ def add_edge(
108
+ self, source: NodeProtocol | str, destination: NodeProtocol | str
109
+ ) -> Graph:
72
110
  """Adds an edge between two nodes that already exist in the graph.
73
111
 
74
112
  Args:
@@ -89,9 +127,7 @@ class Graph:
89
127
  else:
90
128
  source_name = str(source)
91
129
  if source_node is None:
92
- raise ValueError(
93
- f"Node with name '{source_name}' not found."
94
- )
130
+ raise ValueError(f"Node with name '{source_name}' not found.")
95
131
  # get destination node from graph
96
132
  destination_name: str
97
133
  if isinstance(destination, str):
@@ -105,9 +141,7 @@ class Graph:
105
141
  else:
106
142
  destination_name = str(destination)
107
143
  if destination_node is None:
108
- raise ValueError(
109
- f"Node with name '{destination_name}' not found."
110
- )
144
+ raise ValueError(f"Node with name '{destination_name}' not found.")
111
145
  edge = Edge(source_node, destination_node)
112
146
  self.edges.append(edge)
113
147
  return self
@@ -117,7 +151,7 @@ class Graph:
117
151
  sources: list[NodeProtocol],
118
152
  router: NodeProtocol,
119
153
  destinations: list[NodeProtocol],
120
- ) -> "Graph":
154
+ ) -> Graph:
121
155
  if not router.is_router:
122
156
  raise TypeError("A router object must be passed to the router parameter.")
123
157
  [self.add_edge(source, router) for source in sources]
@@ -125,26 +159,162 @@ class Graph:
125
159
  self.add_edge(router, destination)
126
160
  return self
127
161
 
128
- def set_start_node(self, node: NodeProtocol) -> "Graph":
162
+ def set_start_node(self, node: NodeProtocol) -> Graph:
129
163
  self.start_node = node
130
164
  return self
131
165
 
132
- def set_end_node(self, node: NodeProtocol) -> "Graph":
166
+ def set_end_node(self, node: NodeProtocol) -> Graph:
133
167
  self.end_node = node
134
168
  return self
135
169
 
136
- def compile(self) -> "Graph":
137
- if not self.start_node:
138
- raise Exception("Start node not defined.")
139
- if not self.end_nodes:
140
- raise Exception("No end nodes defined.")
141
- if not self._is_valid():
142
- raise Exception("Graph is not valid.")
143
- return self
170
+ def compile(self, *, strict: bool = False) -> Graph:
171
+ """
172
+ Validate the graph:
173
+ - exactly one start node present (or Graph.start_node set)
174
+ - at least one end node present
175
+ - all edges reference known nodes
176
+ - all nodes reachable from the start
177
+ (optional) **no cycles** when strict=True
178
+ Returns self on success; raises GraphCompileError otherwise.
179
+ """
180
+ # nodes map
181
+ nodes = getattr(self, "nodes", None)
182
+ if not isinstance(nodes, dict) or not nodes:
183
+ raise GraphCompileError("No nodes have been added to the graph")
184
+
185
+ start_name: str | None = None
186
+ # Bind and narrow the attribute for mypy
187
+ start_node: _HasName | None = getattr(self, "start_node", None)
188
+ if start_node is not None:
189
+ start_name = start_node.name
190
+ else:
191
+ starts = [
192
+ name
193
+ for name, n in nodes.items()
194
+ if getattr(n, "is_start", False) or getattr(n, "start", False)
195
+ ]
196
+ if len(starts) > 1:
197
+ raise GraphCompileError(f"Multiple start nodes defined: {starts}")
198
+ if len(starts) == 1:
199
+ start_name = starts[0]
200
+
201
+ if not start_name:
202
+ raise GraphCompileError("No start node defined")
203
+
204
+ # at least one end node
205
+ if not any(
206
+ getattr(n, "is_end", False) or getattr(n, "end", False)
207
+ for n in nodes.values()
208
+ ):
209
+ raise GraphCompileError("No end node defined")
210
+
211
+ # normalize edges into adjacency {src: set(dst)}
212
+ raw_edges = getattr(self, "edges", None)
213
+ adj: dict[str, set[str]] = {name: set() for name in nodes.keys()}
214
+
215
+ def _add_edge(src: str, dst: str) -> None:
216
+ if src not in nodes:
217
+ raise GraphCompileError(f"Edge references unknown source node: {src}")
218
+ if dst not in nodes:
219
+ raise GraphCompileError(
220
+ f"Edge from {src} references unknown node(s): ['{dst}']"
221
+ )
222
+ adj[src].add(dst)
223
+
224
+ if raw_edges is None:
225
+ pass
226
+ elif isinstance(raw_edges, dict):
227
+ for raw_src, dsts in raw_edges.items():
228
+ src = _require_name(raw_src, "source")
229
+ dst_iter = (
230
+ [dsts]
231
+ if isinstance(dsts, (str,)) or getattr(dsts, "name", None)
232
+ else list(dsts)
233
+ )
234
+ for d in dst_iter:
235
+ dst = _require_name(d, "destination")
236
+ _add_edge(src, dst)
237
+ else:
238
+ # generic iterable of “edge records”
239
+ try:
240
+ iterator = iter(raw_edges)
241
+ except TypeError:
242
+ raise GraphCompileError("Internal edge map has unsupported type")
243
+
244
+ for item in iterator:
245
+ # (src, dst) OR (src, Iterable[dst])
246
+ if isinstance(item, (tuple, list)) and len(item) == 2:
247
+ raw_src, rhs = item
248
+ src = _require_name(raw_src, "source")
249
+
250
+ if isinstance(rhs, str) or getattr(rhs, "name", None):
251
+ dst = _require_name(rhs, "destination")
252
+ _add_edge(src, rhs)
253
+ else:
254
+ # assume iterable of dsts (strings or node-like)
255
+ try:
256
+ for d in rhs:
257
+ dst = _require_name(d, "destination")
258
+ _add_edge(src, d)
259
+ except TypeError:
260
+ raise GraphCompileError(
261
+ "Edge tuple second item must be a destination or an iterable of destinations"
262
+ )
263
+ continue
264
+
265
+ # Mapping-style: {"source": "...", "destination": "..."} or {"src": "...", "dst": "..."}
266
+ if isinstance(item, dict):
267
+ src = _require_name(item.get("source", item.get("src")), "source")
268
+ dst = _require_name(
269
+ item.get("destination", item.get("dst")), "destination"
270
+ )
271
+ _add_edge(src, dst)
272
+ continue
273
+
274
+ # Object with attributes .source/.destination (or .src/.dst)
275
+ if hasattr(item, "source") or hasattr(item, "src"):
276
+ src = _require_name(
277
+ getattr(item, "source", getattr(item, "src", None)), "source"
278
+ )
279
+ dst = _require_name(
280
+ getattr(item, "destination", getattr(item, "dst", None)),
281
+ "destination",
282
+ )
283
+ _add_edge(src, dst)
284
+ continue
285
+
286
+ # If none matched, this is an unsupported edge record
287
+ raise GraphCompileError(
288
+ "Edges must be dict[str, Iterable[str]] or an iterable of (src, dst), "
289
+ "(src, Iterable[dst]), mapping{'source'/'destination'}, or objects with .source/.destination"
290
+ )
291
+
292
+ # reachability from start
293
+ seen: set[str] = set()
294
+ stack = [start_name]
295
+ while stack:
296
+ cur = stack.pop()
297
+ if cur in seen:
298
+ continue
299
+ seen.add(cur)
300
+ stack.extend(adj.get(cur, ()))
301
+
302
+ unreachable = sorted(set(nodes.keys()) - seen)
303
+ if unreachable:
304
+ raise GraphCompileError(f"Unreachable nodes: {unreachable}")
305
+
306
+ # optional cycle detection (strict mode)
307
+ if strict:
308
+ preds: dict[str, set[str]] = {n: set() for n in nodes.keys()}
309
+ for s, ds in adj.items():
310
+ for d in ds:
311
+ preds[d].add(s)
312
+ try:
313
+ list(TopologicalSorter(preds).static_order())
314
+ except CycleError as e:
315
+ raise GraphCompileError("cycle detected in graph (strict mode)") from e
144
316
 
145
- def _is_valid(self):
146
- # Implement validation logic, e.g., checking for cycles, disconnected components, etc.
147
- return True
317
+ return self
148
318
 
149
319
  def _validate_output(self, output: dict[str, Any], node_name: str):
150
320
  if not isinstance(output, dict):
@@ -205,6 +375,32 @@ class Graph:
205
375
  del state["callback"]
206
376
  return state
207
377
 
378
+ async def execute_many(
379
+ self, inputs: Iterable[dict[str, Any]], *, concurrency: int = 5
380
+ ) -> list[Any]:
381
+ """Execute the graph on many inputs concurrently.
382
+
383
+ :param inputs: An iterable of input dicts to feed into the graph.
384
+ :type inputs: Iterable[dict]
385
+ :param concurrency: Maximum number of graph executions to run at once.
386
+ :type concurrency: int
387
+ Note: there is no ``state`` parameter in the signature. Every input is
388
+ run through ``self.execute(input=...)`` on this same graph instance, so
389
+ the executions are not given separate graph objects; at most
390
+ ``concurrency`` of them run at the same time.
391
+ :returns: The list of results in the same order as ``inputs``.
392
+ :rtype: list[Any]
393
+ """
394
+
395
+ sem = asyncio.Semaphore(concurrency)
396
+
397
+ async def _run_one(inp: dict[str, Any]) -> Any:
398
+ async with sem:
399
+ return await self.execute(input=inp)
400
+
401
+ tasks = [asyncio.create_task(_run_one(i)) for i in inputs]
402
+ return await asyncio.gather(*tasks)
403
+
208
404
  def get_callback(self):
209
405
  """Get a new instance of the callback class.
210
406
 
@@ -219,7 +415,7 @@ class Graph:
219
415
  as the default callback when no callback is passed to the `execute` method.
220
416
 
221
417
  :param callback_class: The callback class to use as the default callback.
222
- :type callback_class: Type[Callback]
418
+ :type callback_class: type[Callback]
223
419
  """
224
420
  self.Callback = callback_class
225
421
  return self
@@ -249,20 +445,24 @@ class Graph:
249
445
  f"No outgoing edge found for current node '{current_node.name}'."
250
446
  )
251
447
 
252
- def visualize(self):
448
+ def visualize(self, *, save_path: str | None = None):
449
+ """Render the current graph. If matplotlib is not installed,
450
+ raise an ImportError with the pip command to install it (same for networkx).
451
+ Optionally save to a file via `save_path`.
452
+ """
253
453
  try:
254
- import networkx as nx
255
- except ImportError:
454
+ import matplotlib.pyplot as plt
455
+ except ImportError as e:
256
456
  raise ImportError(
257
- "NetworkX is required for visualization. Please install it with 'pip install networkx'."
258
- )
457
+ "Graph visualization requires matplotlib. Install it with: `pip install matplotlib`"
458
+ ) from e
259
459
 
260
460
  try:
261
- import matplotlib.pyplot as plt
262
- except ImportError:
461
+ import networkx as nx
462
+ except ImportError as e:
263
463
  raise ImportError(
264
- "Matplotlib is required for visualization. Please install it with 'pip install matplotlib'."
265
- )
464
+ "NetworkX is required for visualization. Please install it with `pip install networkx`."
465
+ ) from e
266
466
 
267
467
  G: Any = nx.DiGraph()
268
468
 
@@ -328,8 +528,12 @@ class Graph:
328
528
  arrowsize=20,
329
529
  )
330
530
 
331
- plt.axis("off")
332
- plt.show()
531
+ if save_path:
532
+ plt.savefig(save_path, bbox_inches="tight")
533
+ else:
534
+ plt.axis("off")
535
+ plt.show()
536
+ plt.close()
333
537
 
334
538
 
335
539
  class Edge:
File without changes
@@ -1,9 +1,18 @@
1
+ from enum import Enum
1
2
  import inspect
2
3
  import os
3
- from typing import Any, Callable
4
+ import sys
5
+ from typing import Any, Callable, Union, get_args, get_origin
4
6
  from pydantic import BaseModel, Field
7
+ from pydantic_core import PydanticUndefined
5
8
  import logging
6
- import sys
9
+
10
+
11
+ # we support python 3.10 so we define our own StrEnum (introduced in 3.11)
12
+ class StrEnum(str, Enum):
13
+ """Backport of StrEnum for Python < 3.11"""
14
+ def __str__(self):
15
+ return self.value
7
16
 
8
17
 
9
18
  class ColoredFormatter(logging.Formatter):
@@ -119,6 +128,9 @@ class Parameter(BaseModel):
119
128
  }
120
129
  }
121
130
 
131
+ class OpenAIAPI(StrEnum):
132
+ COMPLETIONS = "completions"
133
+ RESPONSES = "responses"
122
134
 
123
135
  class FunctionSchema(BaseModel):
124
136
  """Class that consumes a function and can return a schema required by
@@ -167,29 +179,95 @@ class FunctionSchema(BaseModel):
167
179
  )
168
180
 
169
181
  @classmethod
170
- def from_pydantic(cls, model: BaseModel) -> "FunctionSchema":
182
+ def from_pydantic(cls, model: type[BaseModel]) -> "FunctionSchema":
183
+ """Create a FunctionSchema from a Pydantic model class.
184
+
185
+ :param model: The Pydantic model class to convert
186
+ :type model: type[BaseModel]
187
+ :return: FunctionSchema instance
188
+ :rtype: FunctionSchema
189
+ """
190
+ # Extract model metadata
191
+ name = model.__name__
192
+ description = model.__doc__ or ""
193
+
194
+ # Build parameters list
195
+ parameters = []
171
196
  signature_parts = []
172
- for field_name, field_model in model.__annotations__.items():
173
- field_info = model.model_fields[field_name]
174
- default_value = field_info.default
175
- if default_value:
176
- default_repr = repr(default_value)
177
- signature_part = (
178
- f"{field_name}: {field_model.__name__} = {default_repr}"
197
+
198
+ for field_name, field_info in model.model_fields.items():
199
+ # Get the field type
200
+ field_type = model.__annotations__.get(field_name)
201
+
202
+ # Determine the type name - handle Optional and other generic types
203
+ type_name = str(field_type)
204
+
205
+ # Try to extract the actual type from Optional[T] -> T
206
+ origin = get_origin(field_type)
207
+ args = get_args(field_type)
208
+
209
+ if origin is Union:
210
+ # This is likely Optional[T] which is Union[T, None]
211
+ non_none_types = [arg for arg in args if arg is not type(None)]
212
+ if non_none_types:
213
+ actual_type = non_none_types[0]
214
+ if hasattr(actual_type, '__name__'):
215
+ type_name = actual_type.__name__
216
+ else:
217
+ type_name = str(actual_type)
218
+ elif field_type and hasattr(field_type, '__name__'):
219
+ type_name = field_type.__name__
220
+
221
+ # Check if field is required (no default value)
222
+ # In Pydantic v2, PydanticUndefined means no default
223
+ is_required = (
224
+ field_info.default is PydanticUndefined
225
+ and field_info.default_factory is None
226
+ )
227
+
228
+ # Get the actual default value
229
+ if field_info.default is not PydanticUndefined and field_info.default is not None:
230
+ default_value = field_info.default
231
+ elif field_info.default_factory is not None:
232
+ # For default_factory, we can't always call it without arguments
233
+ # Just use a placeholder to indicate there's a factory
234
+ try:
235
+ # Try calling with no arguments (common case)
236
+ default_value = field_info.default_factory() # type: ignore[call-arg]
237
+ except TypeError:
238
+ # If it needs arguments, just indicate it has a factory default
239
+ default_value = "<factory>"
240
+ else:
241
+ default_value = inspect.Parameter.empty
242
+
243
+ # Add parameter
244
+ parameters.append(
245
+ Parameter(
246
+ name=field_name,
247
+ description=field_info.description,
248
+ type=type_name,
249
+ default=default_value,
250
+ required=is_required,
179
251
  )
252
+ )
253
+
254
+ # Build signature part
255
+ if default_value != inspect.Parameter.empty:
256
+ signature_parts.append(f"{field_name}: {type_name} = {repr(default_value)}")
180
257
  else:
181
- signature_part = f"{field_name}: {field_model.__name__}"
182
- signature_parts.append(signature_part)
183
- signature = f"({', '.join(signature_parts)}) -> str"
258
+ signature_parts.append(f"{field_name}: {type_name}")
259
+
260
+ signature = f"({', '.join(signature_parts)}) -> dict"
261
+
184
262
  return cls.model_construct(
185
- name=model.__class__.__name__,
186
- description=model.__doc__ or "",
263
+ name=name,
264
+ description=description,
187
265
  signature=signature,
188
- output="", # TODO: Implement output
189
- parameters=[],
266
+ output="dict",
267
+ parameters=parameters,
190
268
  )
191
269
 
192
- def to_dict(self) -> dict:
270
+ def to_dict(self) -> dict[str, Any]:
193
271
  schema_dict = {
194
272
  "type": "function",
195
273
  "function": {
@@ -210,14 +288,42 @@ class FunctionSchema(BaseModel):
210
288
  }
211
289
  return schema_dict
212
290
 
213
- def to_openai(self) -> dict:
214
- return self.to_dict()
291
+ def to_openai(self, api: OpenAIAPI=OpenAIAPI.COMPLETIONS) -> dict[str, Any]:
292
+ """Convert the function schema into OpenAI-compatible formats. Supports
293
+ both completions and responses APIs.
294
+
295
+ :param api: The API to convert to.
296
+ :type api: OpenAIAPI
297
+ :return: The function schema in OpenAI-compatible format.
298
+ :rtype: dict
299
+ """
300
+ if api == "completions":
301
+ return self.to_dict()
302
+ elif api == "responses":
303
+ return {
304
+ "type": "function",
305
+ "name": self.name,
306
+ "description": self.description,
307
+ "parameters": {
308
+ "type": "object",
309
+ "properties": {
310
+ k: v
311
+ for param in self.parameters
312
+ for k, v in param.to_dict().items()
313
+ },
314
+ "required": [
315
+ param.name for param in self.parameters if param.required
316
+ ],
317
+ },
318
+ }
319
+ else:
320
+ raise ValueError(f"Unrecognized OpenAI API: {api}")
215
321
 
216
322
 
217
323
  DEFAULT = set(["default", "openai", "ollama", "litellm"])
218
324
 
219
325
 
220
- def get_schemas(callables: list[Callable], format: str = "default") -> list[dict]:
326
+ def get_schemas(callables: list[Callable], format: str = "default") -> list[dict[str, Any]]:
221
327
  if format in DEFAULT:
222
328
  return [
223
329
  FunctionSchema.from_callable(callable).to_dict() for callable in callables
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: graphai-lib
3
- Version: 0.0.8
3
+ Version: 0.0.9
4
4
  Summary: Not an AI framework
5
5
  Requires-Python: <3.14,>=3.10
6
6
  Description-Content-Type: text/markdown
@@ -3,6 +3,7 @@ pyproject.toml
3
3
  graphai/__init__.py
4
4
  graphai/callback.py
5
5
  graphai/graph.py
6
+ graphai/py.typed
6
7
  graphai/utils.py
7
8
  graphai/nodes/__init__.py
8
9
  graphai/nodes/base.py
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "graphai-lib"
3
- version = "0.0.8"
3
+ version = "0.0.9"
4
4
  description = "Not an AI framework"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.10,<3.14"
@@ -29,6 +29,9 @@ build-backend = "setuptools.build_meta"
29
29
  [tool.setuptools]
30
30
  packages = ["graphai", "graphai.nodes"]
31
31
 
32
+ [tool.setuptools.package-data]
33
+ graphai = ["py.typed"]
34
+
32
35
  [tool.mypy]
33
36
  python_version = "3.10"
34
37
  warn_return_any = true
File without changes
File without changes