agentforge-ml 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,55 @@
1
+ """Pretty-printable evaluation report."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import statistics
7
+ from dataclasses import dataclass, field
8
+ from pathlib import Path
9
+
10
+
11
@dataclass
class EvalReport:
    """Aggregated evaluation results with ASCII-table rendering and JSON export.

    Attributes:
        n: Number of evaluated samples.
        means: Mean score per metric name.
        per_sample: Per-sample scores keyed by metric name.
        latencies_ms: Per-sample latencies in milliseconds (may be empty).
        extras: Free-form metadata, written verbatim by :meth:`save`.
    """

    n: int
    means: dict[str, float]
    per_sample: dict[str, list[float]]
    latencies_ms: list[float] = field(default_factory=list)
    extras: dict = field(default_factory=dict)

    @property
    def p50_ms(self) -> float:
        """Median latency in ms (``statistics.median``), ``0.0`` if none recorded."""
        return statistics.median(self.latencies_ms) if self.latencies_ms else 0.0

    @property
    def p95_ms(self) -> float:
        """95th-percentile latency in ms, ``0.0`` if none recorded.

        Returns the sorted latency at index ``floor(0.95 * (len - 1))``; no
        interpolation between neighbouring samples.
        """
        if not self.latencies_ms:
            return 0.0
        s = sorted(self.latencies_ms)
        return s[int(0.95 * (len(s) - 1))]

    def as_table(self) -> str:
        """Render the metric means as an ASCII table plus a one-line summary.

        Both columns are sized to their widest cell, so every border, header
        and data row has the same width. The summary line reports ``n`` and,
        when latencies exist, p50/p95 rounded to whole milliseconds.
        """
        # Name column: widest metric name (14 when there are none) plus a
        # 2-space margin, but never narrower than the "metric" header.
        width = max((len(m) for m in self.means), default=14)
        name_w = max(width + 2, len("metric"))
        cells = [f"{val:.3f}" for val in self.means.values()]
        # Value column widens for negative or >= 10 means.
        val_w = max([len("mean"), *(len(c) for c in cells)])

        bar = "+" + "-" * (name_w + 2) + "+" + "-" * (val_w + 2) + "+"
        lines = [bar, f"| {'metric':<{name_w}} | {'mean':>{val_w}} |", bar]
        for name, cell in zip(self.means, cells):
            lines.append(f"| {name:<{name_w}} | {cell:>{val_w}} |")
        lines.append(bar)
        if self.latencies_ms:
            lines.append(f"n={self.n} · p50={self.p50_ms:.0f}ms · p95={self.p95_ms:.0f}ms")
        else:
            lines.append(f"n={self.n}")
        return "\n".join(lines)

    def save(self, path: str | Path) -> None:
        """Write all report fields to ``path`` as indented JSON (UTF-8)."""
        Path(path).write_text(
            json.dumps(
                {
                    "n": self.n,
                    "means": self.means,
                    "per_sample": self.per_sample,
                    "latencies_ms": self.latencies_ms,
                    "extras": self.extras,
                },
                indent=2,
            ),
            encoding="utf-8",
        )
@@ -0,0 +1,7 @@
1
+ """LLM backends."""
2
+
3
+ from agentforge.llm.base import LLM
4
+ from agentforge.llm.hf import HFLLM
5
+ from agentforge.llm.quantized import QuantizedHFLLM
6
+
7
+ __all__ = ["HFLLM", "LLM", "QuantizedHFLLM"]
agentforge/llm/base.py ADDED
@@ -0,0 +1,16 @@
1
+ """LLM protocol — anything that maps a prompt to text and honors stop strings."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Protocol
6
+
7
+
8
class LLM(Protocol):
    """Structural interface for text-generation backends.

    Any object exposing a compatible ``generate`` method satisfies this
    protocol; subclassing is not required.
    """

    def generate(
        self,
        prompt: str,
        *,
        max_new_tokens: int = 256,
        temperature: float = 0.0,
        stop: list[str] | None = None,
    ) -> str:
        """Produce a completion for ``prompt``.

        Args:
            prompt: Input text.
            max_new_tokens: Upper bound on the number of generated tokens.
            temperature: Sampling temperature.
            stop: Strings at which the returned completion should be cut.

        Returns:
            The generated text.
        """
        ...
agentforge/llm/hf.py ADDED
@@ -0,0 +1,83 @@
1
+ """HuggingFace causal LM backend.
2
+
3
+ ReAct-friendly: applies the model's chat template, accepts stop strings,
4
+ and truncates the output at the first stop marker so the next loop iteration
5
+ gets a clean continuation.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import Any
11
+
12
+ import torch
13
+
14
+
15
class HFLLM:
    """HuggingFace causal-LM backend for the ``LLM`` protocol.

    The prompt is wrapped in the tokenizer's chat template when one exists.
    Decoding is greedy unless ``temperature > 0``, and the decoded completion
    is cut at the earliest stop string.
    """

    def __init__(
        self,
        model_id: str = "Qwen/Qwen2.5-3B-Instruct",
        *,
        dtype: str = "auto",
        device_map: str | dict | None = "auto",
        trust_remote_code: bool = False,
        **kwargs: Any,
    ) -> None:
        """Load the tokenizer and model for ``model_id``.

        Args:
            model_id: Hub id or local path given to both ``from_pretrained`` calls.
            dtype: Forwarded to the model loader as ``torch_dtype``.
            device_map: Forwarded to the model loader.
            trust_remote_code: Forwarded to both tokenizer and model loaders.
            **kwargs: Extra keyword arguments for the model's ``from_pretrained``.
        """
        # Imported here, so importing this module does not require transformers.
        from transformers import AutoModelForCausalLM, AutoTokenizer

        self.model_id = model_id
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id, trust_remote_code=trust_remote_code
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=dtype,
            device_map=device_map,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )
        # generate() passes pad_token_id explicitly; fall back to EOS when the
        # tokenizer defines no pad token.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

    def generate(
        self,
        prompt: str,
        *,
        max_new_tokens: int = 256,
        temperature: float = 0.0,
        stop: list[str] | None = None,
    ) -> str:
        """Generate a completion for ``prompt``.

        Args:
            prompt: User message (chat-formatted if the tokenizer has a template).
            max_new_tokens: Generation budget in tokens.
            temperature: ``> 0`` enables sampling at that temperature; otherwise
                decoding is greedy.
            stop: Optional stop strings; the output is truncated at the first one.

        Returns:
            The decoded new tokens (special tokens skipped), truncated and stripped.
        """
        inputs = self._encode(prompt).to(self.model.device)
        gen_kwargs: dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        if temperature > 0:
            gen_kwargs.update(do_sample=True, temperature=temperature)
        else:
            gen_kwargs.update(do_sample=False)

        with torch.no_grad():
            out = self.model.generate(**inputs, **gen_kwargs)
        # Keep only newly generated tokens: the output echoes the prompt ids.
        new_ids = out[0, inputs.input_ids.shape[-1] :]
        answer = self.tokenizer.decode(new_ids, skip_special_tokens=True)
        return _truncate_at_stop(answer, stop)

    def _has_chat_template(self) -> bool:
        """True when the tokenizer can render, and actually has, a chat template."""
        return bool(
            hasattr(self.tokenizer, "apply_chat_template")
            and getattr(self.tokenizer, "chat_template", None)
        )

    def _encode(self, prompt: str):
        """Tokenize ``prompt`` (chat-formatted when possible) as PyTorch tensors."""
        templated = self._has_chat_template()
        text = self._apply_chat_template(prompt)
        # Text rendered by apply_chat_template(tokenize=False) already contains
        # the template's special tokens (e.g. BOS for Llama-3); letting the
        # tokenizer add them again would duplicate them.
        return self.tokenizer(text, return_tensors="pt", add_special_tokens=not templated)

    def _apply_chat_template(self, prompt: str) -> str:
        """Wrap ``prompt`` as a single user turn with a generation prompt.

        Returns ``prompt`` unchanged when the tokenizer has no chat template.
        """
        if self._has_chat_template():
            messages = [{"role": "user", "content": prompt}]
            return self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        return prompt
73
+
74
+
75
def _truncate_at_stop(text: str, stop: list[str] | None) -> str:
    """Cut ``text`` at the earliest occurrence of any stop string, then strip.

    Args:
        text: Decoded model output.
        stop: Stop markers; ``None`` or empty means "no truncation". Empty
            strings are ignored, because ``"".find`` matches at index 0 and
            would otherwise discard the whole completion.

    Returns:
        ``text`` up to (not including) the first stop marker, whitespace-stripped.
    """
    end = len(text)
    for marker in stop or ():
        if not marker:
            continue
        idx = text.find(marker)
        if idx != -1 and idx < end:
            end = idx
    return text[:end].strip()
@@ -0,0 +1,39 @@
1
+ """Quantized LLM via turboquant-ml — same trick as ragforge-ml."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ from agentforge.llm.hf import HFLLM
8
+
9
+
10
class QuantizedHFLLM(HFLLM):
    """:class:`HFLLM` whose model is run through a TurboQuant method after loading.

    Example
    -------
    >>> llm = QuantizedHFLLM("meta-llama/Llama-3.2-3B-Instruct", method="bnb-nf4")
    """

    def __init__(
        self,
        model_id: str = "meta-llama/Llama-3.2-3B-Instruct",
        *,
        method: str = "bnb-nf4",
        quant_kwargs: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        """Load ``model_id`` via :class:`HFLLM`, then replace ``self.model`` with its quantized form.

        Args:
            model_id: Model to load.
            method: TurboQuant method name, stored on ``self.method``.
            quant_kwargs: Extra keyword arguments for the quantizer.
            **kwargs: Forwarded to :class:`HFLLM` (dtype, device_map, ...).
        """
        # The base model is fully loaded first; quantization then wraps it.
        super().__init__(model_id, **kwargs)
        self.method = method
        options: dict[str, Any] = quant_kwargs if quant_kwargs else {}
        self.model = _quantize(self.model, method=self.method, **options)
29
+
30
+
31
def _quantize(model, *, method: str, **kw):
    """Apply the TurboQuant ``method`` to ``model`` and return the result.

    Raises:
        ImportError: if ``turboquant`` cannot be imported (the optional
            ``quantized`` extra is not installed).
    """
    try:
        from turboquant import quantize as turbo_quantize
    except ImportError as exc:  # pragma: no cover
        hint = (
            "turboquant-ml is required for QuantizedHFLLM. "
            'Install with `pip install "agentforge-ml[quantized]"`.'
        )
        raise ImportError(hint) from exc
    return turbo_quantize(model, method=method, **kw)
@@ -0,0 +1,7 @@
1
+ """Conversation memory backends."""
2
+
3
+ from agentforge.memory.base import Memory, Message
4
+ from agentforge.memory.conversation import ConversationMemory
5
+ from agentforge.memory.persistent import PersistentMemory
6
+
7
+ __all__ = ["ConversationMemory", "Memory", "Message", "PersistentMemory"]
@@ -0,0 +1,23 @@
1
+ """Memory protocol + Message type."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import time
6
+ from dataclasses import dataclass, field
7
+ from typing import Protocol
8
+
9
+
10
@dataclass
class Message:
    """A single stored conversation message.

    ``ts`` defaults to ``time.time()`` at construction. ``extras`` defaults to
    a new empty dict per instance, so instances never share it.
    """

    role: str  # speaker role string; not validated here
    content: str  # message text
    ts: float = field(default_factory=time.time)  # seconds since the epoch
    extras: dict = field(default_factory=dict)  # arbitrary per-message metadata
16
+
17
+
18
class Memory(Protocol):
    """Structural interface shared by conversation-memory backends."""

    def add(self, role: str, content: str, **extras: object) -> None:
        """Store one message; keyword ``extras`` carry optional metadata."""

    def get(self, *, limit: int | None = None) -> list[Message]:
        """Return stored messages, optionally restricted by ``limit``."""

    def clear(self) -> None:
        """Remove every stored message."""
@@ -0,0 +1,30 @@
1
+ """In-memory conversation history with a rolling window."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections import deque
6
+
7
+ from agentforge.memory.base import Message
8
+
9
+
10
class ConversationMemory:
    """In-process FIFO window holding the last ``max_messages`` messages.

    Once the window is full, adding a message evicts the oldest one.
    """

    def __init__(self, max_messages: int = 64) -> None:
        self.max_messages = max_messages
        # deque(maxlen=...) discards from the left when full -> rolling window.
        self._buf: deque[Message] = deque(maxlen=max_messages)

    def add(self, role: str, content: str, **extras: object) -> None:
        """Append a message; keyword ``extras`` are copied into ``Message.extras``."""
        self._buf.append(Message(role=role, content=content, extras=dict(extras)))

    def get(self, *, limit: int | None = None) -> list[Message]:
        """Return stored messages, oldest first.

        Args:
            limit: ``None`` returns every message; a positive value returns the
                newest ``limit`` messages; ``0`` or a negative value returns an
                empty list.
        """
        if limit is None:
            return list(self._buf)
        if limit <= 0:
            return []
        return list(self._buf)[-limit:]

    def clear(self) -> None:
        """Drop every stored message."""
        self._buf.clear()

    def __len__(self) -> int:
        return len(self._buf)
@@ -0,0 +1,80 @@
1
+ """SQLite-backed persistent memory.
2
+
3
+ Sessions are keyed by an arbitrary string id — useful for multi-user agents
4
+ where each user has their own running history. No ORM, no migration tooling:
5
+ a single ``messages`` table is created on first use.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import json
11
+ import sqlite3
12
+ import time
13
+ from pathlib import Path
14
+
15
+ from agentforge.memory.base import Message
16
+
17
+
18
class PersistentMemory:
    """SQLite-backed message history partitioned by a ``session`` id.

    Every operation opens a short-lived connection and closes it when done,
    so no handle is held between calls. Consequently a ``":memory:"`` path is
    not useful here: SQLite gives each new connection its own empty database.
    """

    def __init__(
        self, db_path: str | Path = "agentforge_memory.db", *, session: str = "default"
    ) -> None:
        self.db_path = str(db_path)
        self.session = session
        Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
        self._init()

    def _execute(self, sql: str, params: tuple = ()) -> list[tuple]:
        """Run one statement in its own transaction and return all result rows.

        ``sqlite3.Connection`` used as a context manager only commits or rolls
        back; it does not close the connection, so it is closed explicitly.
        """
        cx = sqlite3.connect(self.db_path)
        try:
            with cx:
                return cx.execute(sql, params).fetchall()
        finally:
            cx.close()

    def _init(self) -> None:
        """Create the ``messages`` table and its (session, ts) index if missing."""
        self._execute(
            """
            CREATE TABLE IF NOT EXISTS messages (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                session TEXT NOT NULL,
                role TEXT NOT NULL,
                content TEXT NOT NULL,
                ts REAL NOT NULL,
                extras_json TEXT
            )
            """
        )
        self._execute(
            "CREATE INDEX IF NOT EXISTS idx_messages_session_ts ON messages(session, ts)"
        )

    def add(self, role: str, content: str, **extras: object) -> None:
        """Insert a message for this session, timestamped with ``time.time()``.

        ``extras`` are stored as JSON (NULL when empty), so they must be
        JSON-serializable.
        """
        self._execute(
            "INSERT INTO messages (session, role, content, ts, extras_json) VALUES (?, ?, ?, ?, ?)",
            (self.session, role, content, time.time(), json.dumps(extras) if extras else None),
        )

    def get(self, *, limit: int | None = None) -> list[Message]:
        """Return this session's messages, oldest first.

        Args:
            limit: ``None`` returns everything; a positive value returns the
                newest ``limit`` messages; ``0`` or negative returns ``[]``.
        """
        if limit is not None and limit <= 0:
            return []
        # Newest first so LIMIT keeps the most recent rows; `id` breaks ties
        # between messages that share the same timestamp.
        sql = (
            "SELECT role, content, ts, extras_json FROM messages "
            "WHERE session = ? ORDER BY ts DESC, id DESC"
        )
        params: tuple = (self.session,)
        if limit is not None:
            sql += " LIMIT ?"
            params += (limit,)
        rows = self._execute(sql, params)
        rows.reverse()  # back to chronological order for callers
        return [
            Message(
                role=role,
                content=content,
                ts=ts,
                extras=json.loads(extras_json) if extras_json else {},
            )
            for role, content, ts, extras_json in rows
        ]

    def clear(self) -> None:
        """Delete every message belonging to this session."""
        self._execute("DELETE FROM messages WHERE session = ?", (self.session,))

    def __len__(self) -> int:
        rows = self._execute(
            "SELECT COUNT(*) FROM messages WHERE session = ?", (self.session,)
        )
        return int(rows[0][0])
@@ -0,0 +1,5 @@
1
+ """FastAPI serve module."""
2
+
3
+ from agentforge.serve.app import build_app
4
+
5
+ __all__ = ["build_app"]
@@ -0,0 +1,83 @@
1
+ """FastAPI app for AgentForge."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ from pydantic import BaseModel, Field
8
+
9
+ from agentforge.core.agent import Agent
10
+
11
+
12
class AskRequest(BaseModel):
    # Request body for POST /ask (field comments kept as `#` so the generated
    # OpenAPI schema is unchanged).
    # Question text passed to agent.run; must be non-empty.
    question: str = Field(..., min_length=1)
    # Optional step budget in [1, 20]; passed to agent.run as max_steps (None included).
    max_steps: int | None = Field(None, ge=1, le=20)
    # Generation budget in [16, 2048], default 256; passed to agent.run.
    max_new_tokens: int = Field(256, ge=16, le=2048)
16
+
17
+
18
class StepModel(BaseModel):
    # One step of the agent trace returned by POST /ask; every field is copied
    # from the corresponding attribute of a step in the agent's run result.
    thought: str
    # Tool named in this step; None when the step used no tool (presumably the
    # final-answer step — confirm against Agent).
    tool: str | None
    # Raw input given to the tool, if any.
    action_input: str | None
    # Tool output for this step, if any.
    observation: str | None
    # Time spent on the step, in milliseconds.
    elapsed_ms: float
24
+
25
+
26
class AskResponse(BaseModel):
    # Response body for POST /ask, built field-by-field from agent.run's result.
    question: str
    final_answer: str
    n_steps: int
    success: bool
    # Total run latency in milliseconds, as reported by the agent.
    latency_ms: float
    # Per-step trace, in the same order as the result's steps.
    steps: list[StepModel]
33
+
34
+
35
class ToolModel(BaseModel):
    # Entry of the GET /tools listing: a registered tool's name and description.
    name: str
    description: str
38
+
39
+
40
def build_app(agent: Agent) -> Any:
    """Build a FastAPI application that serves ``agent`` over HTTP.

    Routes:
        GET  /health -> ``{"status": "ok", "n_tools": len(agent.tools)}``
        GET  /tools  -> list of :class:`ToolModel`
        POST /ask    -> runs ``agent.run`` and returns an :class:`AskResponse`

    Raises:
        ImportError: if FastAPI (the ``serve`` extra) is not installed.
    """
    try:
        from fastapi import FastAPI
    except ImportError as exc:  # pragma: no cover
        raise ImportError(
            'FastAPI is required. Install with `pip install "agentforge-ml[serve]"`.'
        ) from exc

    app = FastAPI(title="AgentForge", version="0.1.0")

    # Endpoint function names are kept as-is: FastAPI derives route names and
    # OpenAPI operation ids from them.
    @app.get("/health")
    def health() -> dict:
        return {"status": "ok", "n_tools": len(agent.tools)}

    @app.get("/tools", response_model=list[ToolModel])
    def list_tools() -> list[ToolModel]:
        catalog = []
        for tool in agent.tools:
            catalog.append(ToolModel(name=tool.name, description=tool.description))
        return catalog

    @app.post("/ask", response_model=AskResponse)
    def ask(req: AskRequest) -> AskResponse:
        run = agent.run(
            req.question,
            max_steps=req.max_steps,
            max_new_tokens=req.max_new_tokens,
        )
        trace = [
            StepModel(
                thought=step.thought,
                tool=step.tool,
                action_input=step.action_input,
                observation=step.observation,
                elapsed_ms=step.elapsed_ms,
            )
            for step in run.steps
        ]
        return AskResponse(
            question=run.question,
            final_answer=run.final_answer,
            n_steps=run.n_steps,
            success=run.success,
            latency_ms=run.latency_ms,
            steps=trace,
        )

    return app
@@ -0,0 +1,18 @@
1
+ """Built-in tool registry."""
2
+
3
+ from agentforge.tools.base import Tool, ToolRegistry
4
+ from agentforge.tools.calculator import Calculator
5
+ from agentforge.tools.python_repl import PythonREPL
6
+ from agentforge.tools.rag import RAGTool
7
+ from agentforge.tools.sql import SQLTool
8
+ from agentforge.tools.web_search import WebSearch
9
+
10
+ __all__ = [
11
+ "Calculator",
12
+ "PythonREPL",
13
+ "RAGTool",
14
+ "SQLTool",
15
+ "Tool",
16
+ "ToolRegistry",
17
+ "WebSearch",
18
+ ]
@@ -0,0 +1,55 @@
1
+ """Tool protocol and registry.
2
+
3
+ A tool is *anything* with three things: a name (so the LLM can request it), a
4
+ description (so the LLM knows what it does), and a ``run(input_str) -> str``
5
+ method (so the orchestrator can call it). We avoid JSON schemas on purpose —
6
+ small open models often emit malformed JSON and the cost/benefit isn't there.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from collections.abc import Iterator
12
+ from typing import Protocol, runtime_checkable
13
+
14
+
15
@runtime_checkable
class Tool(Protocol):
    """Structural interface every agent tool satisfies.

    Because of ``@runtime_checkable``, ``isinstance(obj, Tool)`` works at
    runtime. It only checks that the members exist, not their types or
    signatures.
    """

    # Identifier the LLM uses to request the tool.
    name: str
    # Explanation of what the tool does, for the LLM.
    description: str

    def run(self, input_str: str) -> str:
        """Execute the tool on the raw action input and return a text result."""
        ...
21
+
22
+
23
class ToolRegistry:
    """Name-indexed collection of tools with forgiving lookup.

    Iteration and :attr:`names` follow the order in which names were first
    registered. Registering a name again replaces the earlier tool.
    """

    def __init__(self, tools: list[Tool] | None = None) -> None:
        self._tools: dict[str, Tool] = {}
        for t in tools or []:
            self.register(t)

    def register(self, tool: Tool) -> None:
        """Add ``tool`` under ``tool.name``.

        Raises:
            ValueError: if the tool has no ``name`` attribute or it is empty.
        """
        if not getattr(tool, "name", None):
            raise ValueError(f"Tool has no `.name`: {tool!r}")
        self._tools[tool.name] = tool

    def get(self, name: str) -> Tool | None:
        """Resolve a tool name as written by the LLM.

        Lookup order:
          1. exact match;
          2. case-insensitive match, ignoring surrounding whitespace
             ("Calculator", " calculator ");
          3. case-insensitive prefix that matches exactly one tool
             ("calc" -> "calculator").

        Returns:
            The tool, or ``None`` when nothing matches, the name is blank, or
            the prefix is ambiguous.
        """
        if name in self._tools:
            return self._tools[name]
        wanted = name.strip().lower()
        if not wanted:
            # An empty prefix would match every tool.
            return None
        for key, tool in self._tools.items():
            if key.lower() == wanted:
                return tool
        candidates = [tool for key, tool in self._tools.items() if key.lower().startswith(wanted)]
        return candidates[0] if len(candidates) == 1 else None

    def __iter__(self) -> Iterator[Tool]:
        return iter(self._tools.values())

    def __len__(self) -> int:
        return len(self._tools)

    def __contains__(self, name: str) -> bool:
        """True when :meth:`get` resolves ``name`` (same lenient rules)."""
        return self.get(name) is not None

    @property
    def names(self) -> list[str]:
        """Registered tool names, in registration order."""
        return list(self._tools)
@@ -0,0 +1,115 @@
1
+ """Calculator tool — safe arithmetic via AST whitelist.
2
+
3
+ We deliberately avoid ``eval()`` and ``sympy.sympify`` directly. The whitelist
4
+ covers the operators a calculator actually needs: + - * / // % **, unary +/-,
5
+ parentheses, int/float literals, and a fixed set of math functions.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import ast
11
+ import math
12
+ import operator as op
13
+
14
# Whitelisted binary operators, keyed by AST operator node type.
_BIN_OPS = {
    ast.Add: op.add,  # a + b
    ast.Sub: op.sub,  # a - b
    ast.Mult: op.mul,  # a * b
    ast.Div: op.truediv,  # a / b
    ast.FloorDiv: op.floordiv,  # a // b
    ast.Mod: op.mod,  # a % b
    ast.Pow: op.pow,  # a ** b
}

# Whitelisted unary prefix operators: +a and -a.
_UNARY_OPS = {ast.UAdd: op.pos, ast.USub: op.neg}

# Callables the expression may invoke by name; "ln" is an alias of math.log.
_FUNCS = dict(
    sqrt=math.sqrt,
    log=math.log,
    ln=math.log,
    log2=math.log2,
    log10=math.log10,
    exp=math.exp,
    sin=math.sin,
    cos=math.cos,
    tan=math.tan,
    asin=math.asin,
    acos=math.acos,
    atan=math.atan,
    abs=abs,
    round=round,
    floor=math.floor,
    ceil=math.ceil,
    max=max,
    min=min,
)

# Bare identifiers the expression may reference.
_CONSTS = dict(pi=math.pi, e=math.e, tau=math.tau)
55
+
56
+
57
class Calculator:
    """Tool that evaluates arithmetic through a whitelisted AST walk (no ``eval``)."""

    name = "calculator"
    description = (
        "Safely evaluate an arithmetic expression. "
        "Supports + - * / // % ** parentheses and functions "
        "sqrt, log, exp, sin/cos/tan, abs, round, floor, ceil, max, min. "
        "Constants: pi, e. Example: '47 * 1337'."
    )

    def run(self, input_str: str) -> str:
        """Evaluate ``input_str`` and return the formatted result.

        Never raises for bad input. Parse, evaluation and formatting failures
        come back as ``"Error: <ExceptionType>: <message>"`` strings.
        """
        expr = input_str.strip()
        if not expr:
            return "Error: empty expression"
        try:
            tree = ast.parse(expr, mode="eval")
            # Formatting stays inside the try: on CPython >= 3.11, str() of a
            # very large int raises ValueError (integer string-conversion limit).
            return _fmt(_eval(tree.body))
        except Exception as e:
            return f"Error: {type(e).__name__}: {e}"
76
+
77
+
78
# Largest allowed int ** int result, in bits (~3,000 decimal digits). This is
# far beyond calculator needs and stops inputs like ``9**9**9`` from pinning
# the CPU or exhausting memory.
_MAX_POW_BITS = 10_000


def _eval(node):
    """Recursively evaluate a whitelisted arithmetic AST node.

    Accepts:
      - int/float literals (``True``/``False`` are rejected);
      - operators listed in ``_BIN_OPS`` / ``_UNARY_OPS``;
      - calls to ``_FUNCS`` entries with positional and named arguments;
      - names in ``_CONSTS``.

    Raises:
        ValueError: for any construct outside the whitelist, or an integer
            power whose result would exceed ``_MAX_POW_BITS`` bits.

    Errors raised by the operators or functions themselves (for example
    ZeroDivisionError, OverflowError, TypeError) propagate unchanged.
    """
    if isinstance(node, ast.Constant):
        value = node.value
        # bool is an int subclass; True/False are not arithmetic literals.
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return value
        raise ValueError(f"unsupported literal: {value!r}")
    if isinstance(node, ast.BinOp):
        left = _eval(node.left)
        right = _eval(node.right)
        # Guard exact-int exponentiation, whose cost grows with the result size.
        # Float powers overflow quickly on their own; |base| <= 1 stays tiny.
        if (
            isinstance(node.op, ast.Pow)
            and isinstance(left, int)
            and isinstance(right, int)
            and right > 0
            and abs(left) > 1
            and left.bit_length() * right > _MAX_POW_BITS
        ):
            raise ValueError("exponent too large")
        op_fn = _BIN_OPS.get(type(node.op))
        if op_fn is None:
            raise ValueError(f"unsupported operator: {type(node.op).__name__}")
        return op_fn(left, right)
    if isinstance(node, ast.UnaryOp):
        op_fn = _UNARY_OPS.get(type(node.op))
        if op_fn is None:
            raise ValueError(f"unsupported unary operator: {type(node.op).__name__}")
        return op_fn(_eval(node.operand))
    if isinstance(node, ast.Call):
        if not isinstance(node.func, ast.Name):
            raise ValueError("only simple function calls allowed")
        name = node.func.id
        if name not in _FUNCS:
            raise ValueError(f"unknown function: {name}")
        args = [_eval(a) for a in node.args]
        # Named arguments used to be dropped silently
        # (round(3.14159, ndigits=2) gave 3); evaluate them too.
        kwargs = {}
        for kw in node.keywords:
            if kw.arg is None:  # f(**mapping)
                raise ValueError("keyword unpacking (**) not allowed")
            kwargs[kw.arg] = _eval(kw.value)
        return _FUNCS[name](*args, **kwargs)
    if isinstance(node, ast.Name):
        if node.id not in _CONSTS:
            raise ValueError(f"unknown identifier: {node.id}")
        return _CONSTS[node.id]
    raise ValueError(f"unsupported expression: {ast.dump(node)}")
108
+
109
+
110
def _fmt(v) -> str:
    """Format an evaluation result for display.

    Behaviour by input:
      - integral floats lose their decimal part (``3.0`` -> ``"3"``);
      - other floats use six significant digits (``format(v, ".6g")``);
      - non-floats are simply ``str()``-ed.
    """
    if not isinstance(v, float):
        return str(v)
    if v.is_integer():
        return str(int(v))
    return f"{v:.6g}"
+ return str(v)