graphai-lib 0.0.6__tar.gz → 0.0.8__tar.gz

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
@@ -1,13 +1,10 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: graphai-lib
3
- Version: 0.0.6
3
+ Version: 0.0.8
4
4
  Summary: Not an AI framework
5
5
  Requires-Python: <3.14,>=3.10
6
6
  Description-Content-Type: text/markdown
7
- Requires-Dist: networkx>=3.4.2
8
- Requires-Dist: matplotlib>=3.10.0
9
7
  Requires-Dist: pydantic>=2.11.1
10
- Requires-Dist: colorlog>=6.9.0
11
8
  Provides-Extra: dev
12
9
  Requires-Dist: ipykernel>=6.25.0; extra == "dev"
13
10
  Requires-Dist: ruff>=0.1.5; extra == "dev"
@@ -0,0 +1,5 @@
1
+ from graphai.callback import Callback
2
+ from graphai.graph import Graph
3
+ from graphai.nodes import node, router
4
+
5
+ __all__ = ["node", "router", "Callback", "Graph"]
@@ -1,6 +1,6 @@
1
1
  import asyncio
2
2
  from pydantic import Field
3
- from typing import Optional, Any
3
+ from typing import Any
4
4
  from collections.abc import AsyncIterator
5
5
 
6
6
 
@@ -48,7 +48,7 @@ class Callback:
48
48
  description="Whether this is the first token in the stream.",
49
49
  exclude=True,
50
50
  )
51
- _current_node_name: Optional[str] = Field(
51
+ _current_node_name: str | None = Field(
52
52
  default=None, description="The name of the current node.", exclude=True
53
53
  )
54
54
  _active: bool = Field(
@@ -85,11 +85,11 @@ class Callback:
85
85
  self._first_token = value
86
86
 
87
87
  @property
88
- def current_node_name(self) -> Optional[str]:
88
+ def current_node_name(self) -> str | None:
89
89
  return self._current_node_name
90
90
 
91
91
  @current_node_name.setter
92
- def current_node_name(self, value: Optional[str]):
92
+ def current_node_name(self, value: str | None):
93
93
  self._current_node_name = value
94
94
 
95
95
  @property
@@ -100,14 +100,14 @@ class Callback:
100
100
  def active(self, value: bool):
101
101
  self._active = value
102
102
 
103
- def __call__(self, token: str, node_name: Optional[str] = None):
103
+ def __call__(self, token: str, node_name: str | None = None):
104
104
  if self._done:
105
105
  raise RuntimeError("Cannot add tokens to a closed stream")
106
106
  self._check_node_name(node_name=node_name)
107
107
  # otherwise we just assume node is correct and send token
108
108
  self.queue.put_nowait(token)
109
109
 
110
- async def acall(self, token: str, node_name: Optional[str] = None):
110
+ async def acall(self, token: str, node_name: str | None = None):
111
111
  # TODO JB: do we need to have `node_name` param?
112
112
  if self._done:
113
113
  raise RuntimeError("Cannot add tokens to a closed stream")
@@ -176,7 +176,7 @@ class Callback:
176
176
  # Don't wait for queue.join() as it can cause deadlock
177
177
  # The stream will close when aiter processes the END token
178
178
 
179
- def _check_node_name(self, node_name: Optional[str] = None):
179
+ def _check_node_name(self, node_name: str | None = None):
180
180
  if node_name:
181
181
  # we confirm this is the current node
182
182
  if self.current_node_name != node_name:
@@ -1,54 +1,58 @@
1
- from typing import List, Dict, Any, Optional, Protocol, Type
1
+ from typing import Any, Protocol, Type
2
2
  from graphai.callback import Callback
3
3
  from graphai.utils import logger
4
4
 
5
5
 
6
6
  class NodeProtocol(Protocol):
7
7
  """Protocol defining the interface of a decorated node."""
8
+
8
9
  name: str
9
10
  is_start: bool
10
11
  is_end: bool
11
12
  is_router: bool
12
13
  stream: bool
13
-
14
+
14
15
  async def invoke(
15
- self,
16
- input: Dict[str, Any],
17
- callback: Optional[Callback] = None,
18
- state: Optional[Dict[str, Any]] = None
19
- ) -> Dict[str, Any]: ...
16
+ self,
17
+ input: dict[str, Any],
18
+ callback: Callback | None = None,
19
+ state: dict[str, Any] | None = None,
20
+ ) -> dict[str, Any]: ...
20
21
 
21
22
 
22
23
  class Graph:
23
24
  def __init__(
24
- self, max_steps: int = 10, initial_state: Optional[Dict[str, Any]] = None
25
+ self, max_steps: int = 10, initial_state: dict[str, Any] | None = None
25
26
  ):
26
- self.nodes: Dict[str, NodeProtocol] = {}
27
- self.edges: List[Any] = []
28
- self.start_node: Optional[NodeProtocol] = None
29
- self.end_nodes: List[NodeProtocol] = []
27
+ self.nodes: dict[str, NodeProtocol] = {}
28
+ self.edges: list[Any] = []
29
+ self.start_node: NodeProtocol | None = None
30
+ self.end_nodes: list[NodeProtocol] = []
30
31
  self.Callback: Type[Callback] = Callback
31
32
  self.max_steps = max_steps
32
33
  self.state = initial_state or {}
33
34
 
34
35
  # Allow getting and setting the graph's internal state
35
- def get_state(self) -> Dict[str, Any]:
36
+ def get_state(self) -> dict[str, Any]:
36
37
  """Get the current graph state."""
37
38
  return self.state
38
39
 
39
- def set_state(self, state: Dict[str, Any]):
40
+ def set_state(self, state: dict[str, Any]) -> "Graph":
40
41
  """Set the graph state."""
41
42
  self.state = state
43
+ return self
42
44
 
43
- def update_state(self, values: Dict[str, Any]):
45
+ def update_state(self, values: dict[str, Any]) -> "Graph":
44
46
  """Update the graph state with new values."""
45
47
  self.state.update(values)
48
+ return self
46
49
 
47
- def reset_state(self):
50
+ def reset_state(self) -> "Graph":
48
51
  """Reset the graph state to an empty dict."""
49
52
  self.state = {}
53
+ return self
50
54
 
51
- def add_node(self, node: NodeProtocol):
55
+ def add_node(self, node: NodeProtocol) -> "Graph":
52
56
  if node.name in self.nodes:
53
57
  raise Exception(f"Node with name '{node.name}' already exists.")
54
58
  self.nodes[node.name] = node
@@ -62,8 +66,9 @@ class Graph:
62
66
  self.start_node = node
63
67
  if node.is_end:
64
68
  self.end_nodes.append(node)
69
+ return self
65
70
 
66
- def add_edge(self, source: NodeProtocol | str, destination: NodeProtocol | str):
71
+ def add_edge(self, source: NodeProtocol | str, destination: NodeProtocol | str) -> "Graph":
67
72
  """Adds an edge between two nodes that already exist in the graph.
68
73
 
69
74
  Args:
@@ -72,58 +77,76 @@ class Graph:
72
77
  """
73
78
  source_node, destination_node = None, None
74
79
  # get source node from graph
80
+ source_name: str
75
81
  if isinstance(source, str):
76
82
  source_node = self.nodes.get(source)
83
+ source_name = source
77
84
  else:
78
85
  # Check if it's a node-like object by looking for required attributes
79
86
  if hasattr(source, "name"):
80
87
  source_node = self.nodes.get(source.name)
88
+ source_name = source.name
89
+ else:
90
+ source_name = str(source)
81
91
  if source_node is None:
82
92
  raise ValueError(
83
- f"Node with name '{source.name if hasattr(source, 'name') else source}' not found."
93
+ f"Node with name '{source_name}' not found."
84
94
  )
85
95
  # get destination node from graph
96
+ destination_name: str
86
97
  if isinstance(destination, str):
87
98
  destination_node = self.nodes.get(destination)
99
+ destination_name = destination
88
100
  else:
89
101
  # Check if it's a node-like object by looking for required attributes
90
102
  if hasattr(destination, "name"):
91
103
  destination_node = self.nodes.get(destination.name)
104
+ destination_name = destination.name
105
+ else:
106
+ destination_name = str(destination)
92
107
  if destination_node is None:
93
108
  raise ValueError(
94
- f"Node with name '{destination.name if hasattr(destination, 'name') else destination}' not found."
109
+ f"Node with name '{destination_name}' not found."
95
110
  )
96
111
  edge = Edge(source_node, destination_node)
97
112
  self.edges.append(edge)
113
+ return self
98
114
 
99
115
  def add_router(
100
- self, sources: list[NodeProtocol], router: NodeProtocol, destinations: List[NodeProtocol]
101
- ):
116
+ self,
117
+ sources: list[NodeProtocol],
118
+ router: NodeProtocol,
119
+ destinations: list[NodeProtocol],
120
+ ) -> "Graph":
102
121
  if not router.is_router:
103
122
  raise TypeError("A router object must be passed to the router parameter.")
104
123
  [self.add_edge(source, router) for source in sources]
105
124
  for destination in destinations:
106
125
  self.add_edge(router, destination)
126
+ return self
107
127
 
108
- def set_start_node(self, node: NodeProtocol):
128
+ def set_start_node(self, node: NodeProtocol) -> "Graph":
109
129
  self.start_node = node
130
+ return self
110
131
 
111
- def set_end_node(self, node: NodeProtocol):
132
+ def set_end_node(self, node: NodeProtocol) -> "Graph":
112
133
  self.end_node = node
134
+ return self
113
135
 
114
- def compile(self):
136
+ def compile(self) -> "Graph":
115
137
  if not self.start_node:
116
138
  raise Exception("Start node not defined.")
117
139
  if not self.end_nodes:
118
140
  raise Exception("No end nodes defined.")
119
141
  if not self._is_valid():
120
142
  raise Exception("Graph is not valid.")
143
+ return self
121
144
 
122
145
  def _is_valid(self):
123
146
  # Implement validation logic, e.g., checking for cycles, disconnected components, etc.
124
147
  return True
125
148
 
126
- def _validate_output(self, output: Dict[str, Any], node_name: str):
149
+ def _validate_output(self, output: dict[str, Any], node_name: str):
127
150
  if not isinstance(output, dict):
128
151
  raise ValueError(
129
152
  f"Expected dictionary output from node {node_name}. "
@@ -134,11 +157,11 @@ class Graph:
134
157
  # TODO JB: may need to add init callback here to init the queue on every new execution
135
158
  if callback is None:
136
159
  callback = self.get_callback()
137
-
160
+
138
161
  # Type assertion to tell the type checker that start_node is not None after compile()
139
162
  assert self.start_node is not None, "Graph must be compiled before execution"
140
163
  current_node = self.start_node
141
-
164
+
142
165
  state = input
143
166
  # Don't reset the graph state if it was initialized with initial_state
144
167
  steps = 0
@@ -191,7 +214,7 @@ class Graph:
191
214
  callback = self.Callback()
192
215
  return callback
193
216
 
194
- def set_callback(self, callback_class: Type[Callback]):
217
+ def set_callback(self, callback_class: type[Callback]) -> "Graph":
195
218
  """Set the callback class that is returned by the `get_callback` method and used
196
219
  as the default callback when no callback is passed to the `execute` method.
197
220
 
@@ -199,6 +222,7 @@ class Graph:
199
222
  :type callback_class: Type[Callback]
200
223
  """
201
224
  self.Callback = callback_class
225
+ return self
202
226
 
203
227
  def _get_node_by_name(self, node_name: str) -> NodeProtocol:
204
228
  """Get a node by its name.
@@ -240,7 +264,7 @@ class Graph:
240
264
  "Matplotlib is required for visualization. Please install it with 'pip install matplotlib'."
241
265
  )
242
266
 
243
- G = nx.DiGraph()
267
+ G: Any = nx.DiGraph()
244
268
 
245
269
  for node in self.nodes.values():
246
270
  G.add_node(node.name)
@@ -264,11 +288,11 @@ class Graph:
264
288
  y_coord[node] = y_max - i - 1
265
289
 
266
290
  # Set up the layout
267
- pos = {}
291
+ pos: dict[Any, tuple[float, float]] = {}
268
292
  for i, generation in enumerate(generations):
269
293
  x = 0
270
294
  for node in generation:
271
- pos[node] = (x, y_coord[node])
295
+ pos[node] = (float(x), float(y_coord[node]))
272
296
  x += 1
273
297
 
274
298
  # Center each level horizontally
@@ -1,5 +1,5 @@
1
1
  import inspect
2
- from typing import Any, Callable, Dict, Optional
2
+ from typing import Any, Callable
3
3
  from pydantic import Field
4
4
 
5
5
  from graphai.callback import Callback
@@ -8,7 +8,7 @@ from graphai.utils import FunctionSchema
8
8
 
9
9
  class NodeMeta(type):
10
10
  @staticmethod
11
- def positional_to_kwargs(cls_type, args) -> Dict[str, Any]:
11
+ def positional_to_kwargs(cls_type, args) -> dict[str, Any]:
12
12
  init_signature = inspect.signature(cls_type.__init__)
13
13
  init_params = {
14
14
  name: arg
@@ -73,7 +73,7 @@ class _Node:
73
73
  params_dict = await self._parse_params(*args, **kwargs)
74
74
  return await func(**params_dict) # Pass only the necessary arguments
75
75
 
76
- async def _parse_params(self, *args, **kwargs) -> Dict[str, Any]:
76
+ async def _parse_params(self, *args, **kwargs) -> dict[str, Any]:
77
77
  # filter out unexpected keyword args
78
78
  expected_kwargs = {
79
79
  k: v for k, v in kwargs.items() if k in self._expected_params
@@ -133,9 +133,9 @@ class _Node:
133
133
  @classmethod
134
134
  async def invoke(
135
135
  cls,
136
- input: Dict[str, Any],
137
- callback: Optional[Callback] = None,
138
- state: Optional[Dict[str, Any]] = None,
136
+ input: dict[str, Any],
137
+ callback: Callback | None = None,
138
+ state: dict[str, Any] | None = None,
139
139
  ):
140
140
  if callback:
141
141
  if stream:
@@ -154,8 +154,6 @@ class _Node:
154
154
 
155
155
  NodeClass.__name__ = func.__name__
156
156
  node_class_name = name or func.__name__
157
- if node_class_name is None:
158
- raise ValueError("Unexpected error: node name not set.")
159
157
  NodeClass.name = node_class_name
160
158
  NodeClass.__doc__ = func.__doc__
161
159
  NodeClass.is_start = start
@@ -167,7 +165,7 @@ class _Node:
167
165
 
168
166
  def __call__(
169
167
  self,
170
- func: Optional[Callable] = None,
168
+ func: Callable | None = None,
171
169
  start: bool = False,
172
170
  end: bool = False,
173
171
  stream: bool = False,
@@ -1,33 +1,43 @@
1
1
  import inspect
2
- from typing import Any, Callable, List, Optional
2
+ import os
3
+ from typing import Any, Callable
3
4
  from pydantic import BaseModel, Field
4
5
  import logging
6
+ import sys
5
7
 
6
- import colorlog
7
8
 
9
+ class ColoredFormatter(logging.Formatter):
10
+ """Custom colored formatter for the logger using ANSI escape codes."""
8
11
 
9
- class CustomFormatter(colorlog.ColoredFormatter):
10
- """Custom formatter for the logger."""
12
+ # ANSI escape codes for colors
13
+ COLORS = {
14
+ "DEBUG": "\033[36m", # Cyan
15
+ "INFO": "\033[32m", # Green
16
+ "WARNING": "\033[33m", # Yellow
17
+ "ERROR": "\033[31m", # Red
18
+ "CRITICAL": "\033[1;31m", # Bold Red
19
+ }
20
+ RESET = "\033[0m"
11
21
 
12
22
  def __init__(self):
13
23
  super().__init__(
14
- "%(log_color)s%(asctime)s %(levelname)s %(name)s %(message)s",
24
+ "%(asctime)s %(levelname)s %(name)s %(message)s",
15
25
  datefmt="%Y-%m-%d %H:%M:%S",
16
- log_colors={
17
- "DEBUG": "cyan",
18
- "INFO": "green",
19
- "WARNING": "yellow",
20
- "ERROR": "red",
21
- "CRITICAL": "bold_red",
22
- },
23
- reset=True,
24
- style="%",
25
26
  )
26
27
 
28
+ def format(self, record):
29
+ # Check if the output supports color (TTY)
30
+ if hasattr(sys.stderr, "isatty") and sys.stderr.isatty():
31
+ levelname = record.levelname
32
+ if levelname in self.COLORS:
33
+ record.levelname = f"{self.COLORS[levelname]}{levelname}{self.RESET}"
34
+ record.msg = f"{self.COLORS[levelname]}{record.msg}{self.RESET}"
35
+ return super().format(record)
36
+
27
37
 
28
38
  def add_coloured_handler(logger):
29
39
  """Add a coloured handler to the logger."""
30
- formatter = CustomFormatter()
40
+ formatter = ColoredFormatter()
31
41
  console_handler = logging.StreamHandler()
32
42
  console_handler.setFormatter(formatter)
33
43
  logger.addHandler(console_handler)
@@ -40,7 +50,17 @@ def setup_custom_logger(name):
40
50
 
41
51
  if not logger.hasHandlers():
42
52
  add_coloured_handler(logger)
43
- logger.setLevel(logging.INFO)
53
+
54
+ # Set log level from environment variable, default to INFO
55
+ log_level = os.getenv("GRAPHAI_LOG_LEVEL", "INFO").upper()
56
+ level_map = {
57
+ "DEBUG": logging.DEBUG,
58
+ "INFO": logging.INFO,
59
+ "WARNING": logging.WARNING,
60
+ "ERROR": logging.ERROR,
61
+ "CRITICAL": logging.CRITICAL,
62
+ }
63
+ logger.setLevel(level_map.get(log_level, logging.INFO))
44
64
  logger.propagate = False
45
65
 
46
66
  return logger
@@ -68,7 +88,7 @@ class Parameter(BaseModel):
68
88
  :param name: The name of the parameter.
69
89
  :type name: str
70
90
  :param description: The description of the parameter.
71
- :type description: Optional[str]
91
+ :type description: str | None
72
92
  :param type: The type of the parameter.
73
93
  :type type: str
74
94
  :param default: The default value of the parameter.
@@ -78,7 +98,7 @@ class Parameter(BaseModel):
78
98
  """
79
99
 
80
100
  name: str = Field(description="The name of the parameter")
81
- description: Optional[str] = Field(
101
+ description: str | None = Field(
82
102
  default=None, description="The description of the parameter"
83
103
  )
84
104
  type: str = Field(description="The type of the parameter")
@@ -118,34 +138,33 @@ class FunctionSchema(BaseModel):
118
138
  :param function: The function to consume.
119
139
  :type function: Callable
120
140
  """
121
- if callable(function):
122
- name = function.__name__
123
- description = str(inspect.getdoc(function))
124
- if description is None or description == "":
125
- logger.warning(f"Function {name} has no docstring")
126
- signature = str(inspect.signature(function))
127
- output = str(inspect.signature(function).return_annotation)
128
- parameters = []
129
- for param in inspect.signature(function).parameters.values():
130
- parameters.append(
131
- Parameter(
132
- name=param.name,
133
- type=param.annotation.__name__,
134
- default=param.default,
135
- required=param.default is inspect.Parameter.empty,
136
- )
141
+ if not callable(function):
142
+ raise TypeError("Function must be a Callable")
143
+
144
+ name = function.__name__
145
+ doc = inspect.getdoc(function)
146
+ description = str(doc) if doc else ""
147
+ if not description:
148
+ logger.warning(f"Function {name} has no docstring")
149
+ signature = str(inspect.signature(function))
150
+ output = str(inspect.signature(function).return_annotation)
151
+ parameters = []
152
+ for param in inspect.signature(function).parameters.values():
153
+ parameters.append(
154
+ Parameter(
155
+ name=param.name,
156
+ type=param.annotation.__name__,
157
+ default=param.default,
158
+ required=param.default is inspect.Parameter.empty,
137
159
  )
138
- return cls.model_construct(
139
- name=name,
140
- description=description,
141
- signature=signature,
142
- output=output,
143
- parameters=parameters,
144
160
  )
145
- elif isinstance(function, BaseModel):
146
- raise NotImplementedError("Pydantic BaseModel not implemented yet.")
147
- else:
148
- raise TypeError("Function must be a Callable or BaseModel")
161
+ return cls.model_construct(
162
+ name=name,
163
+ description=description,
164
+ signature=signature,
165
+ output=output,
166
+ parameters=parameters,
167
+ )
149
168
 
150
169
  @classmethod
151
170
  def from_pydantic(cls, model: BaseModel) -> "FunctionSchema":
@@ -179,7 +198,9 @@ class FunctionSchema(BaseModel):
179
198
  "parameters": {
180
199
  "type": "object",
181
200
  "properties": {
182
- k: v for param in self.parameters for k, v in param.to_dict().items()
201
+ k: v
202
+ for param in self.parameters
203
+ for k, v in param.to_dict().items()
183
204
  },
184
205
  "required": [
185
206
  param.name for param in self.parameters if param.required
@@ -196,7 +217,7 @@ class FunctionSchema(BaseModel):
196
217
  DEFAULT = set(["default", "openai", "ollama", "litellm"])
197
218
 
198
219
 
199
- def get_schemas(callables: List[Callable], format: str = "default") -> list[dict]:
220
+ def get_schemas(callables: list[Callable], format: str = "default") -> list[dict]:
200
221
  if format in DEFAULT:
201
222
  return [
202
223
  FunctionSchema.from_callable(callable).to_dict() for callable in callables
@@ -1,13 +1,10 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: graphai-lib
3
- Version: 0.0.6
3
+ Version: 0.0.8
4
4
  Summary: Not an AI framework
5
5
  Requires-Python: <3.14,>=3.10
6
6
  Description-Content-Type: text/markdown
7
- Requires-Dist: networkx>=3.4.2
8
- Requires-Dist: matplotlib>=3.10.0
9
7
  Requires-Dist: pydantic>=2.11.1
10
- Requires-Dist: colorlog>=6.9.0
11
8
  Provides-Extra: dev
12
9
  Requires-Dist: ipykernel>=6.25.0; extra == "dev"
13
10
  Requires-Dist: ruff>=0.1.5; extra == "dev"
@@ -1,7 +1,4 @@
1
- networkx>=3.4.2
2
- matplotlib>=3.10.0
3
1
  pydantic>=2.11.1
4
- colorlog>=6.9.0
5
2
 
6
3
  [dev]
7
4
  ipykernel>=6.25.0
@@ -1,14 +1,11 @@
1
1
  [project]
2
2
  name = "graphai-lib"
3
- version = "0.0.6"
3
+ version = "0.0.8"
4
4
  description = "Not an AI framework"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.10,<3.14"
7
7
  dependencies = [
8
- "networkx>=3.4.2",
9
- "matplotlib>=3.10.0",
10
8
  "pydantic>=2.11.1",
11
- "colorlog>=6.9.0",
12
9
  ]
13
10
 
14
11
  [project.optional-dependencies]
@@ -31,3 +28,18 @@ build-backend = "setuptools.build_meta"
31
28
 
32
29
  [tool.setuptools]
33
30
  packages = ["graphai", "graphai.nodes"]
31
+
32
+ [tool.mypy]
33
+ python_version = "3.10"
34
+ warn_return_any = true
35
+ warn_unused_configs = true
36
+ ignore_missing_imports = true
37
+ disallow_untyped_defs = false
38
+ disallow_incomplete_defs = false
39
+ check_untyped_defs = true
40
+ no_implicit_optional = true
41
+ warn_redundant_casts = true
42
+ warn_unused_ignores = true
43
+ warn_no_return = true
44
+ follow_imports = "normal"
45
+ strict_optional = true
@@ -1,4 +0,0 @@
1
- from graphai.graph import Graph
2
- from graphai.nodes import node, router
3
-
4
- __all__ = ["node", "router", "Graph"]
File without changes
File without changes