ace-framework 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ace/__init__.py ADDED
@@ -0,0 +1,53 @@
1
+ """Agentic Context Engineering (ACE) reproduction framework."""
2
+
3
+ from .playbook import Bullet, Playbook
4
+ from .delta import DeltaOperation, DeltaBatch
5
+ from .llm import LLMClient, DummyLLMClient, TransformersLLMClient
6
+ from .roles import (
7
+ Generator,
8
+ Reflector,
9
+ Curator,
10
+ GeneratorOutput,
11
+ ReflectorOutput,
12
+ CuratorOutput,
13
+ )
14
+ from .adaptation import (
15
+ OfflineAdapter,
16
+ OnlineAdapter,
17
+ Sample,
18
+ TaskEnvironment,
19
+ EnvironmentResult,
20
+ AdapterStepResult,
21
+ )
22
+
23
+ # Import production LLM clients if available
24
+ try:
25
+ from .llm_providers import LiteLLMClient
26
+ LITELLM_AVAILABLE = True
27
+ except ImportError:
28
+ LiteLLMClient = None
29
+ LITELLM_AVAILABLE = False
30
+
31
+ __all__ = [
32
+ "Bullet",
33
+ "Playbook",
34
+ "DeltaOperation",
35
+ "DeltaBatch",
36
+ "LLMClient",
37
+ "DummyLLMClient",
38
+ "TransformersLLMClient",
39
+ "LiteLLMClient",
40
+ "Generator",
41
+ "Reflector",
42
+ "Curator",
43
+ "GeneratorOutput",
44
+ "ReflectorOutput",
45
+ "CuratorOutput",
46
+ "OfflineAdapter",
47
+ "OnlineAdapter",
48
+ "Sample",
49
+ "TaskEnvironment",
50
+ "EnvironmentResult",
51
+ "AdapterStepResult",
52
+ "LITELLM_AVAILABLE",
53
+ ]
ace/adaptation.py ADDED
@@ -0,0 +1,193 @@
1
+ """Adaptation loops for offline and online ACE training."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from abc import ABC, abstractmethod
7
+ from dataclasses import dataclass, field
8
+ from typing import Dict, Iterable, List, Optional, Sequence
9
+
10
+ from .playbook import Playbook
11
+ from .roles import Curator, CuratorOutput, Generator, GeneratorOutput, Reflector, ReflectorOutput
12
+
13
+
14
@dataclass
class Sample:
    """One task instance handed to the ACE adaptation loop.

    Only ``question`` is required; every other field has a default so a
    sample can be built from a bare prompt.
    """

    # The task prompt itself.
    question: str
    # Optional supporting text that accompanies the question.
    context: str = ""
    # Reference answer, when one is known.
    ground_truth: Optional[str] = None
    # Free-form extra information; each instance gets its own dict.
    metadata: Dict[str, object] = field(default_factory=dict)
22
+
23
+
24
@dataclass
class EnvironmentResult:
    """Outcome reported by a task environment for one generator output.

    ``feedback`` and ``ground_truth`` are both required positional fields;
    ``metrics`` defaults to a fresh empty dict per instance.
    """

    # Textual feedback describing how the generator output fared.
    feedback: str
    # Reference answer if the environment knows one, otherwise None.
    ground_truth: Optional[str]
    # Named numeric scores for this evaluation.
    metrics: Dict[str, float] = field(default_factory=dict)
31
+
32
+
33
class TaskEnvironment(ABC):
    """Abstract scorer that judges a generator output for a given sample.

    Concrete environments must implement :meth:`evaluate`; the class cannot
    be instantiated until they do.
    """

    @abstractmethod
    def evaluate(
        self,
        sample: Sample,
        generator_output: GeneratorOutput,
    ) -> EnvironmentResult:
        """Score ``generator_output`` for ``sample``.

        Returns an ``EnvironmentResult`` carrying feedback and, when
        available, the ground truth.
        """
41
+
42
+
43
@dataclass
class AdapterStepResult:
    """Record of everything produced while processing a single sample.

    All six fields are required and are stored as given.
    """

    # The sample that was processed.
    sample: Sample
    # What the generator produced for the sample.
    generator_output: GeneratorOutput
    # The environment's evaluation of that output.
    environment_result: EnvironmentResult
    # The reflector's output for this step.
    reflection: ReflectorOutput
    # The curator's output for this step.
    curator_output: CuratorOutput
    # Prompt-form rendering of the playbook captured for this step.
    playbook_snapshot: str
51
+
52
+
53
class AdapterBase:
    """Shared orchestration logic for offline and online ACE adaptation.

    Holds the playbook and the three role objects (generator, reflector,
    curator) and pushes one sample at a time through them in
    ``_process_sample``. It also keeps a short rolling history of serialized
    reflections that is handed to the generator on later samples.
    """

    def __init__(
        self,
        *,
        playbook: Optional[Playbook] = None,
        generator: Generator,
        reflector: Reflector,
        curator: Curator,
        max_refinement_rounds: int = 1,
        reflection_window: int = 3,
    ) -> None:
        """Store the collaborators and loop settings.

        Args:
            playbook: Playbook to mutate in place; a fresh ``Playbook()`` is
                created only when this is ``None``.
            generator: Role object whose ``generate`` produces answers.
            reflector: Role object whose ``reflect`` analyses an answer.
            curator: Role object whose ``curate`` proposes playbook deltas.
            max_refinement_rounds: Forwarded unchanged to ``reflector.reflect``.
            reflection_window: Maximum number of recent reflections kept for
                the generator; a value of zero or less keeps none.
        """
        # Test identity against None rather than truthiness: a caller-supplied
        # playbook must never be swapped for a new one merely because it
        # evaluates as falsy (e.g. an empty container-like object).
        self.playbook = playbook if playbook is not None else Playbook()
        self.generator = generator
        self.reflector = reflector
        self.curator = curator
        self.max_refinement_rounds = max_refinement_rounds
        self.reflection_window = reflection_window
        self._recent_reflections: List[str] = []

    # ------------------------------------------------------------------ #
    def _reflection_context(self) -> str:
        """Join the retained reflections into one string, oldest first."""
        return "\n---\n".join(self._recent_reflections)

    def _update_recent_reflections(self, reflection: ReflectorOutput) -> None:
        """Serialize ``reflection.raw`` to JSON and add it to the rolling history.

        The history is trimmed to the newest ``reflection_window`` entries.
        """
        serialized = json.dumps(reflection.raw, ensure_ascii=False)
        if self.reflection_window <= 0:
            # ``lst[-0:]`` is the whole list, so slicing cannot express
            # "keep nothing"; handle the non-positive window explicitly.
            self._recent_reflections = []
            return
        self._recent_reflections.append(serialized)
        if len(self._recent_reflections) > self.reflection_window:
            self._recent_reflections = self._recent_reflections[-self.reflection_window:]

    def _apply_bullet_tags(self, reflection: ReflectorOutput) -> None:
        """Apply each tag in ``reflection.bullet_tags`` to the playbook.

        Tagging is best-effort: a ``ValueError`` from ``playbook.tag_bullet``
        skips that tag and processing continues with the next one.
        """
        for tag in reflection.bullet_tags:
            try:
                self.playbook.tag_bullet(tag.id, tag.tag)
            except ValueError:
                continue

    def _question_context(self, sample: Sample, environment_result: EnvironmentResult) -> str:
        """Build the newline-separated ``key: value`` summary given to the curator."""
        parts = [
            f"question: {sample.question}",
            f"context: {sample.context}",
            f"metadata: {json.dumps(sample.metadata)}",
            f"feedback: {environment_result.feedback}",
            f"ground_truth: {environment_result.ground_truth}",
        ]
        return "\n".join(parts)

    def _progress_string(self, epoch: int, total_epochs: int, step: int, total_steps: int) -> str:
        """Format the epoch/sample position as a short progress label."""
        return f"epoch {epoch}/{total_epochs} · sample {step}/{total_steps}"

    def _process_sample(
        self,
        sample: Sample,
        environment: TaskEnvironment,
        *,
        epoch: int,
        total_epochs: int,
        step_index: int,
        total_steps: int,
    ) -> AdapterStepResult:
        """Run one full generate → evaluate → reflect → curate cycle.

        The playbook is mutated in place: bullet tags from the reflection are
        applied first, then the curator's delta. The returned result includes
        the playbook rendered with ``as_prompt()`` after those mutations.
        """
        # The generator sees reflections from earlier samples only; the
        # reflection for this sample is recorded further down.
        generator_output = self.generator.generate(
            question=sample.question,
            context=sample.context,
            playbook=self.playbook,
            reflection=self._reflection_context(),
        )
        env_result = environment.evaluate(sample, generator_output)
        reflection = self.reflector.reflect(
            question=sample.question,
            generator_output=generator_output,
            playbook=self.playbook,
            ground_truth=env_result.ground_truth,
            feedback=env_result.feedback,
            max_refinement_rounds=self.max_refinement_rounds,
        )
        self._apply_bullet_tags(reflection)
        self._update_recent_reflections(reflection)
        curator_output = self.curator.curate(
            reflection=reflection,
            playbook=self.playbook,
            question_context=self._question_context(sample, env_result),
            progress=self._progress_string(epoch, total_epochs, step_index, total_steps),
        )
        self.playbook.apply_delta(curator_output.delta)
        return AdapterStepResult(
            sample=sample,
            generator_output=generator_output,
            environment_result=env_result,
            reflection=reflection,
            curator_output=curator_output,
            playbook_snapshot=self.playbook.as_prompt(),
        )
146
+
147
+
148
class OfflineAdapter(AdapterBase):
    """Multi-epoch adaptation over a fixed training split."""

    def run(
        self,
        samples: Sequence[Sample],
        environment: TaskEnvironment,
        epochs: int = 1,
    ) -> List[AdapterStepResult]:
        """Process every sample once per epoch, in order.

        Returns the step results for all epochs concatenated, epoch by epoch.
        Epoch and step numbers passed downstream start at 1.
        """
        sample_count = len(samples)
        history: List[AdapterStepResult] = []
        for epoch_number in range(1, epochs + 1):
            # The generator is consumed lazily by extend(), so samples are
            # still processed strictly one after another.
            history.extend(
                self._process_sample(
                    current,
                    environment,
                    epoch=epoch_number,
                    total_epochs=epochs,
                    step_index=position,
                    total_steps=sample_count,
                )
                for position, current in enumerate(samples, start=1)
            )
        return history
171
+
172
+
173
class OnlineAdapter(AdapterBase):
    """Sequential adaptation over a stream, mutating the playbook as it goes."""

    def run(
        self,
        samples: Iterable[Sample],
        environment: TaskEnvironment,
    ) -> List[AdapterStepResult]:
        """Process each streamed sample once and return the step results.

        The stream length is not known up front, so the running 1-based
        count is reported as both the step index and the total.
        """
        return [
            self._process_sample(
                current,
                environment,
                epoch=1,
                total_epochs=1,
                step_index=seen,
                total_steps=seen,
            )
            for seen, current in enumerate(samples, start=1)
        ]
ace/delta.py ADDED
@@ -0,0 +1,67 @@
1
+ """Delta operations produced by the ACE Curator."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Dict, Iterable, List, Literal, Optional
7
+
8
+
9
OperationType = Literal["ADD", "UPDATE", "TAG", "REMOVE"]


@dataclass
class DeltaOperation:
    """Single mutation to apply to the playbook."""

    # Kind of mutation; see OperationType for the expected values.
    type: OperationType
    # Playbook section the operation targets.
    section: str
    # Text payload of the operation, if any.
    content: Optional[str] = None
    # Identifier of the bullet the operation refers to, if any.
    bullet_id: Optional[str] = None
    # Integer-valued extras keyed by name.
    metadata: Dict[str, int] = field(default_factory=dict)

    @classmethod
    def from_json(cls, payload: Dict[str, object]) -> "DeltaOperation":
        """Build an operation from a decoded JSON object.

        ``type`` is required; ``section`` defaults to ``""``. ``content`` and
        ``bullet_id`` are ``None`` when missing or null and are otherwise
        converted with ``str``. Metadata keys are stringified and values
        converted with ``int``.

        Raises:
            KeyError: if ``payload`` has no ``"type"`` entry.
            ValueError, TypeError: if a metadata value cannot be converted
                with ``int``.
        """
        raw_section = payload.get("section")
        raw_content = payload.get("content")
        raw_bullet_id = payload.get("bullet_id")
        raw_metadata = payload.get("metadata") or {}
        return cls(
            type=str(payload["type"]),  # type: ignore[arg-type]
            # A JSON null must become an empty section, not the text "None".
            section="" if raw_section is None else str(raw_section),
            # Compare against None explicitly: ``x and str(x)`` would pass
            # falsy non-strings such as 0 through without stringifying them.
            content=None if raw_content is None else str(raw_content),
            bullet_id=None if raw_bullet_id is None else str(raw_bullet_id),
            metadata={str(key): int(value) for key, value in raw_metadata.items()},  # type: ignore[union-attr]
        )

    def to_json(self) -> Dict[str, object]:
        """Return a JSON-ready dict, omitting unset optional fields.

        ``content`` and ``bullet_id`` are left out when ``None``; ``metadata``
        is left out when empty.
        """
        data: Dict[str, object] = {"type": self.type, "section": self.section}
        if self.content is not None:
            data["content"] = self.content
        if self.bullet_id is not None:
            data["bullet_id"] = self.bullet_id
        if self.metadata:
            data["metadata"] = self.metadata
        return data
44
+
45
+
46
@dataclass
class DeltaBatch:
    """Bundle of curator reasoning and operations."""

    # Free-text explanation accompanying the operations.
    reasoning: str
    # Operations in the order they were supplied.
    operations: List[DeltaOperation] = field(default_factory=list)

    @classmethod
    def from_json(cls, payload: Dict[str, object]) -> "DeltaBatch":
        """Build a batch from a decoded JSON object.

        ``reasoning`` defaults to ``""`` when missing or null. Entries under
        ``"operations"`` that are not dicts are skipped, and a missing or
        non-iterable ``"operations"`` value yields an empty list.
        """
        raw_reasoning = payload.get("reasoning")
        ops_payload = payload.get("operations")
        operations: List[DeltaOperation] = []
        if isinstance(ops_payload, Iterable):
            operations = [
                DeltaOperation.from_json(item)
                for item in ops_payload
                if isinstance(item, dict)
            ]
        # A JSON null must become empty reasoning, not the text "None".
        reasoning = "" if raw_reasoning is None else str(raw_reasoning)
        return cls(reasoning=reasoning, operations=operations)

    def to_json(self) -> Dict[str, object]:
        """Return a JSON-ready dict with the reasoning and serialized operations."""
        return {
            "reasoning": self.reasoning,
            "operations": [op.to_json() for op in self.operations],
        }
ace/llm.py ADDED
@@ -0,0 +1,169 @@
1
+ """LLM client abstractions used by ACE components."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from abc import ABC, abstractmethod
6
+ import json
7
+ from collections import deque
8
+ from dataclasses import dataclass
9
+ from typing import Any, Deque, Dict, Optional, Union
10
+
11
+
12
@dataclass
class LLMResponse:
    """Value object holding what an LLM client returned."""

    # The response text.
    text: str
    # Optional provider-specific payload kept alongside the text.
    raw: Optional[Dict[str, Any]] = None
18
+
19
+
20
class LLMClient(ABC):
    """Base interface every ACE language-model backend implements.

    Subclasses supply :meth:`complete`; the base class only records the
    model identifier.
    """

    def __init__(self, model: Optional[str] = None) -> None:
        """Remember the model identifier (``None`` when unspecified)."""
        self.model = model

    @abstractmethod
    def complete(self, prompt: str, **kwargs: Any) -> LLMResponse:
        """Produce the model's response for ``prompt``."""
29
+
30
+
31
class DummyLLMClient(LLMClient):
    """Deterministic LLM stub for testing and dry runs.

    Responses are served first-in, first-out from an internal queue.
    """

    def __init__(self, responses: Optional[Deque[str]] = None) -> None:
        """Create the stub, optionally seeded with queued responses.

        Args:
            responses: Initial responses. A ``deque`` is used as-is (so the
                caller keeps a live handle on the queue, even when it starts
                empty); any other iterable is copied into a new deque;
                ``None`` starts with an empty queue.
        """
        super().__init__(model="dummy")
        if responses is None:
            self._responses: Deque[str] = deque()
        elif isinstance(responses, deque):
            # Do not use ``responses or deque()``: an empty deque is falsy and
            # would be silently replaced, detaching it from the caller.
            self._responses = responses
        else:
            # popleft() below needs a deque, so normalize lists and the like.
            self._responses = deque(responses)

    def queue(self, text: str) -> None:
        """Enqueue a response to be used on a later completion call."""
        self._responses.append(text)

    def complete(self, prompt: str, **kwargs: Any) -> LLMResponse:
        """Return the oldest queued response; ``prompt`` and ``kwargs`` are ignored.

        Raises:
            RuntimeError: if no responses are queued.
        """
        if not self._responses:
            raise RuntimeError("DummyLLMClient ran out of queued responses.")
        return LLMResponse(text=self._responses.popleft())
46
+
47
+
48
class TransformersLLMClient(LLMClient):
    """LLM client powered by `transformers` pipelines for chat-style models."""

    def __init__(
        self,
        model_path: str,
        *,
        max_new_tokens: int = 512,
        temperature: float = 0.0,
        top_p: float = 0.9,
        device_map: Union[str, Dict[str, int]] = "auto",
        torch_dtype: Union[str, "torch.dtype"] = "auto",
        trust_remote_code: bool = True,
        system_prompt: Optional[str] = None,
        generation_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Load the tokenizer and text-generation pipeline for ``model_path``.

        Args:
            model_path: Model identifier or path passed to ``transformers``.
            max_new_tokens: Default generation length limit.
            temperature: Default sampling temperature; sampling is enabled
                by default only when this is greater than zero.
            top_p: Default nucleus-sampling value.
            device_map: Forwarded to ``pipeline``.
            torch_dtype: Forwarded to ``pipeline``.
            trust_remote_code: Forwarded to both the tokenizer and pipeline.
            system_prompt: System message for every call; a JSON-only
                instruction is used when this is empty or ``None``.
            generation_kwargs: Extra defaults merged over the ones above.
        """
        super().__init__(model=model_path)

        # Import transformers lazily to avoid mandatory dependency for all users.
        from transformers import AutoTokenizer, pipeline  # type: ignore[import-untyped]

        # NOTE(review): trust_remote_code defaults to True, which lets the
        # model repository execute its own Python code at load time. The
        # default is kept for backward compatibility; pass False for models
        # from sources you do not trust.
        self._tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=trust_remote_code
        )
        self._pipeline = pipeline(
            "text-generation",
            model=model_path,
            tokenizer=self._tokenizer,
            torch_dtype=torch_dtype,
            device_map=device_map,
            trust_remote_code=trust_remote_code,
        )
        self._system_prompt = system_prompt or (
            "You are a JSON-only assistant that MUST reply with a single valid JSON object without extra text.\n"
            "Reasoning: low\n"
            "Do not expose analysis or chain-of-thought. Respond using the final JSON only."
        )
        self._defaults: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "do_sample": temperature > 0.0,
            "return_full_text": False,
        }
        if generation_kwargs:
            self._defaults.update(generation_kwargs)

    def complete(self, prompt: str, **kwargs: Any) -> LLMResponse:
        """Run the pipeline on ``prompt`` and return the cleaned-up text.

        Per-call ``kwargs`` override the instance defaults. The
        ``refinement_round`` key, if present, is dropped before the pipeline
        is invoked. The untouched pipeline output is returned under
        ``raw["outputs"]``.
        """
        call_kwargs = dict(self._defaults)
        overrides = dict(kwargs)
        overrides.pop("refinement_round", None)
        call_kwargs.update(overrides)

        # Build chat-formatted messages so the pipeline applies the model's
        # chat template.
        messages = [
            {"role": "system", "content": self._system_prompt},
            {"role": "user", "content": prompt},
        ]

        outputs = self._pipeline(messages, **call_kwargs)
        text = self._postprocess_text(self._extract_text(outputs))
        return LLMResponse(text=text, raw={"outputs": outputs})

    def _extract_text(self, outputs: Any) -> str:
        """Normalize pipeline outputs into a single string response.

        Only the first element of ``outputs`` is inspected. Returns ``""``
        for empty outputs, an empty message list, or a null payload.
        """
        if not outputs:
            return ""
        candidate = outputs[0]

        if not (isinstance(candidate, dict) and "generated_text" in candidate):
            # Ultimate fallback: string representation.
            return str(candidate).strip()

        generated = candidate["generated_text"]
        if generated is None:
            return ""

        # Chat-style payload: a list of {"role": ..., "content": ...} messages.
        if isinstance(generated, list):
            if not generated:
                # Guard the generated[-1] lookup below.
                return ""
            # Prefer the first assistant message that has string content.
            for message in generated:
                if isinstance(message, dict) and message.get("role") == "assistant":
                    content = message.get("content")
                    if isinstance(content, str):
                        return content.strip()
            # Otherwise fall back to the last item's content/text.
            last = generated[-1]
            if isinstance(last, dict):
                return str(last.get("content") or last.get("text") or "")
            return str(last)

        if isinstance(generated, dict):
            return str(generated.get("content") or generated.get("text") or "")

        # Plain-string payload ({"generated_text": "..."}) and anything else.
        return str(generated).strip()

    def _postprocess_text(self, text: str) -> str:
        """Trim analyzer prefixes and isolate JSON payloads when present.

        Everything up to and including the first ``assistantfinal`` marker is
        discarded. If the remainder has text around a ``{...}`` span that
        parses as JSON once line breaks are replaced by spaces, only that
        span is returned. Otherwise the remainder is returned with line
        breaks replaced by spaces.
        """
        trimmed = text.strip()
        if not trimmed:
            return trimmed

        marker = "assistantfinal"
        if marker in trimmed:
            trimmed = trimmed.split(marker, 1)[1].strip()

        # Covers a marker that is immediately repeated.
        if trimmed.startswith(marker):
            trimmed = trimmed[len(marker):].strip()

        # Try to isolate the outermost {...} span whenever there is text
        # before OR after it; trailing commentary after the object would
        # otherwise make the result unparseable.
        start = trimmed.find("{")
        end = trimmed.rfind("}")
        if start != -1 and end > start:
            candidate = trimmed[start:end + 1]
            if candidate != trimmed:
                candidate_clean = candidate.strip().replace("\r", " ").replace("\n", " ")
                try:
                    json.loads(candidate_clean)
                except json.JSONDecodeError:
                    pass
                else:
                    return candidate_clean

        return trimmed.replace("\r", " ").replace("\n", " ")
@@ -0,0 +1,13 @@
1
+ """Production LLM client implementations for ACE."""
2
+
3
+ from .litellm_client import LiteLLMClient
4
+
5
+ try:
6
+ from .langchain_client import LangChainLiteLLMClient
7
+ except ImportError:
8
+ LangChainLiteLLMClient = None # Optional dependency
9
+
10
+ __all__ = [
11
+ "LiteLLMClient",
12
+ "LangChainLiteLLMClient",
13
+ ]