continualcode 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- continualcode/__init__.py +17 -0
- continualcode/cli.py +67 -0
- continualcode/session.py +357 -0
- continualcode/tools.py +230 -0
- continualcode/tui.py +994 -0
- continualcode-0.1.0.dist-info/METADATA +115 -0
- continualcode-0.1.0.dist-info/RECORD +10 -0
- continualcode-0.1.0.dist-info/WHEEL +4 -0
- continualcode-0.1.0.dist-info/entry_points.txt +2 -0
- continualcode-0.1.0.dist-info/licenses/LICENSE +21 -0
|
"""
ContinualCode - Human-in-the-loop coding agent with online learning.

A Claude Code-style TUI that learns from your approvals via prompt distillation.
"""

__version__ = "0.1.0"

from continualcode.session import ContextDistillSession
from continualcode.tools import TOOL_SPECS, READONLY_TOOLS, execute_tool

# Public API of the package. `__version__` is part of the public surface
# (the CLI advertises it), so declare it here as well; underscore names are
# otherwise skipped by `from continualcode import *`.
__all__ = [
    "ContextDistillSession",
    "TOOL_SPECS",
    "READONLY_TOOLS",
    "execute_tool",
    "__version__",
]
|
continualcode/cli.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
CLI entry point for ContinualCode.
|
|
4
|
+
|
|
5
|
+
Usage:
|
|
6
|
+
continualcode # Run the TUI
|
|
7
|
+
continualcode --help # Show help
|
|
8
|
+
continualcode --version # Show version
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import argparse
|
|
12
|
+
import sys
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def main() -> int:
    """Entry point: parse CLI flags, export them as env overrides, run the TUI.

    Returns:
        Process exit code (0 on normal TUI exit).
    """
    parser = argparse.ArgumentParser(
        prog="continualcode",
        description="Human-in-the-loop coding agent with online learning via prompt distillation",
    )
    parser.add_argument(
        "--version",
        action="version",
        version="%(prog)s 0.1.0",
    )
    parser.add_argument(
        "--model",
        default=None,
        help="Override MODEL environment variable",
    )
    parser.add_argument(
        "--policy",
        default=None,
        help="Path to policy file (default: ./policy_memory.md)",
    )
    parser.add_argument(
        "--checkpoint",
        default=None,
        help="Load checkpoint on startup",
    )
    parser.add_argument(
        "--no-training",
        action="store_true",
        help="Disable training (inference only)",
    )

    args = parser.parse_args()

    # Propagate CLI overrides through the environment the TUI reads.
    import os

    overrides = {
        "MODEL": args.model,
        "POLICY_PATH": args.policy,
        "LOAD_CHECKPOINT": args.checkpoint,
    }
    for env_key, value in overrides.items():
        if value:
            os.environ[env_key] = value
    if args.no_training:
        os.environ["ENABLE_TRAINING"] = "0"

    # Imported lazily so `--help` / `--version` stay cheap and don't pull in
    # the heavy TUI/training stack.
    from continualcode.tui import TinkerCodeApp

    TinkerCodeApp().run()
    return 0


if __name__ == "__main__":
    sys.exit(main())
|
continualcode/session.py
ADDED
|
@@ -0,0 +1,357 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Context (prompt) distillation session for a coding agent.
|
|
4
|
+
|
|
5
|
+
We maintain two prefixes:
|
|
6
|
+
- Teacher prefix: includes a long policy prompt (e.g. repo/user rules).
|
|
7
|
+
- Student prefix: excludes that policy prompt.
|
|
8
|
+
|
|
9
|
+
Generation uses the teacher prefix (better behavior now).
|
|
10
|
+
|
|
11
|
+
Two training modes:
|
|
12
|
+
- off_policy: cross-entropy on the approved assistant message with the student prefix
|
|
13
|
+
(classic "prompt distillation").
|
|
14
|
+
- on_policy: sample a rollout from the student prefix, then compute teacher logprobs
|
|
15
|
+
(teacher prefix) on the student's exact tokens and update via importance
|
|
16
|
+
sampling / KL advantages (on-policy distillation).
|
|
17
|
+
|
|
18
|
+
This is "prompt distillation" / "context distillation":
|
|
19
|
+
teacher(p + x) -> y
|
|
20
|
+
train student(x) -> y
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from __future__ import annotations
|
|
24
|
+
|
|
25
|
+
import os
|
|
26
|
+
from pathlib import Path
|
|
27
|
+
from typing import Any
|
|
28
|
+
|
|
29
|
+
import tinker
|
|
30
|
+
import torch
|
|
31
|
+
from tinker import types
|
|
32
|
+
from tinker.types.tensor_data import TensorData
|
|
33
|
+
from tinker_cookbook.renderers import get_renderer
|
|
34
|
+
from tinker_cookbook.renderers.base import TrainOnWhat
|
|
35
|
+
from tinker_cookbook.supervised.common import datum_from_model_input_weights
|
|
36
|
+
from tinker_cookbook.tokenizer_utils import get_tokenizer
|
|
37
|
+
|
|
38
|
+
from continualcode.tools import TOOL_SPECS
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class ContextDistillSession:
    """Sampling + online prompt distillation training from approved tool calls."""

    def __init__(
        self,
        model: str,
        tinker_url: str | None = None,
        *,
        enable_training: bool = True,
        lora_rank: int = 32,
        learning_rate: float = 1e-5,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        policy_path: str = "./policy_memory.md",
        policy_tag: str = "policy_memory",
        distill_mode: str = "on_policy",  # "on_policy" or "off_policy"
        # On-policy distillation config
        train_max_tokens: int = 1024,
        train_temperature: float = 1.0,
        kl_coef: float = 1.0,
    ):
        """Build the session and both (teacher/student) conversation prefixes.

        Args:
            model: Base model name passed to the tinker service/tokenizer.
            tinker_url: Optional service base URL (None = client default).
            enable_training: When False, only a sampling client is created.
            lora_rank: LoRA rank for the training client.
            learning_rate: Adam learning rate for optim steps.
            max_tokens: Generation cap for user-facing sampling.
            temperature: Generation temperature for user-facing sampling.
            policy_path: File whose contents become the teacher-only policy prompt.
            policy_tag: XML-ish tag wrapped around the policy text.
            distill_mode: "on_policy" or "off_policy" (see module docstring).
            train_max_tokens: Rollout cap for on-policy training samples.
            train_temperature: Temperature for on-policy training rollouts.
            kl_coef: Scales the per-token (teacher - student) advantages.
        """
        self.model = model
        self.learning_rate = learning_rate
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.train_steps = 0

        self.policy_path = Path(policy_path)
        self.policy_tag = policy_tag
        self.distill_mode = distill_mode
        self.train_max_tokens = train_max_tokens
        self.train_temperature = train_temperature
        self.kl_coef = kl_coef

        # Setup tokenizer and renderer
        # NOTE(review): renderer name is hard-coded to "qwen3_instruct"
        # regardless of `model` — confirm this is intended for all models.
        self.tokenizer = get_tokenizer(model)
        self.renderer = get_renderer("qwen3_instruct", tokenizer=self.tokenizer)
        self.service = tinker.ServiceClient(base_url=tinker_url)

        # Conversation state (no prefix)
        self.messages: list[dict[str, Any]] = []

        # Build prefixes
        self._base_system_prompt = f"You are a helpful coding assistant. cwd: {os.getcwd()}"
        self._teacher_prefix: list[dict[str, Any]] = []
        self._student_prefix: list[dict[str, Any]] = []
        self._rebuild_prefixes()

        # Clients (initialized in init())
        self.enable_training = enable_training
        self.lora_rank = lora_rank
        self.training_client: tinker.TrainingClient | None = None
        self.sampling_client: tinker.SamplingClient | None = None

    def _load_policy_text(self) -> str:
        """Read the policy file, returning "" when missing or unreadable."""
        try:
            return self.policy_path.read_text(encoding="utf-8")
        except FileNotFoundError:
            return ""
        except Exception:
            # Avoid making the entire session crash on a bad policy file.
            return ""

    def _rebuild_prefixes(self) -> None:
        """Recompute teacher/student prefixes from the current policy file.

        The teacher prefix embeds the policy text inside <policy_tag> tags;
        the student prefix uses only the base system prompt.
        """
        policy_text = self._load_policy_text().strip()
        if policy_text:
            teacher_prompt = (
                self._base_system_prompt
                + f"\n\n<{self.policy_tag}>\n{policy_text}\n</{self.policy_tag}>\n"
            )
        else:
            # Empty policy: teacher and student prompts are identical.
            teacher_prompt = self._base_system_prompt

        self._teacher_prefix = self.renderer.create_conversation_prefix_with_tools(
            TOOL_SPECS, teacher_prompt
        )
        self._student_prefix = self.renderer.create_conversation_prefix_with_tools(
            TOOL_SPECS, self._base_system_prompt
        )

    def reload_policy(self) -> None:
        """Reload policy from disk for subsequent generations."""
        self._rebuild_prefixes()

    async def init(self) -> None:
        """Initialize training/sampling clients."""
        if self.enable_training:
            self.training_client = self.service.create_lora_training_client(
                base_model=self.model, rank=self.lora_rank
            )
            # The sampling client serves the current LoRA weights ("current").
            self.sampling_client = self.training_client.save_weights_and_get_sampling_client("current")
        else:
            self.sampling_client = self.service.create_sampling_client(base_model=self.model)

    async def sample(self) -> tuple[dict[str, Any], bool]:
        """Sample a completion using the teacher prefix."""
        if self.sampling_client is None:
            raise RuntimeError("Session not initialized. Call init() first.")

        # Policy can be updated externally; keep it fresh per turn.
        self.reload_policy()

        model_input = self.renderer.build_generation_prompt(self._teacher_prefix + self.messages)
        response = await self.sampling_client.sample_async(
            prompt=model_input,
            num_samples=1,
            sampling_params=types.SamplingParams(
                stop=self.renderer.get_stop_sequences(),
                max_tokens=self.max_tokens,
                temperature=self.temperature,
            ),
        )
        tokens = response.sequences[0].tokens
        # parse_response returns (message, is_complete) — per the annotated
        # return type; the renderer owns the exact semantics.
        return self.renderer.parse_response(tokens)

    def train_on_approval(
        self,
        *,
        prompt_messages: list[dict[str, Any]],
        assistant_message: dict[str, Any],
        scale: float = 1.0,
    ) -> dict[str, float]:
        """Train a single update step from an approved action.

        off_policy: cross-entropy imitation of the approved assistant message.
        on_policy: on-policy distillation (teacher has policy context; student does not).
        """
        if self.training_client is None:
            return {}

        if self.distill_mode == "off_policy":
            return self._train_off_policy(prompt_messages, assistant_message, scale=scale)
        # Any other mode value falls through to on-policy (the default).
        return self._train_on_policy(prompt_messages, scale=scale)

    def _train_off_policy(
        self,
        prompt_messages: list[dict[str, Any]],
        assistant_message: dict[str, Any],
        *,
        scale: float,
    ) -> dict[str, float]:
        """Classic prompt distillation: SFT student(x) -> teacher(p+x) output."""
        if self.training_client is None:
            return {}

        # Supervise only the final assistant message, conditioned on the
        # student prefix (no policy prompt).
        train_messages = self._student_prefix + prompt_messages + [assistant_message]
        model_input, weights = self.renderer.build_supervised_example(
            train_messages, train_on_what=TrainOnWhat.LAST_ASSISTANT_MESSAGE
        )
        if scale != 1.0:
            # Uniformly rescale token weights to up/down-weight this example.
            weights = [float(w) * float(scale) for w in weights]

        datum = datum_from_model_input_weights(model_input, weights, max_length=None)
        fwd_bwd = self.training_client.forward_backward([datum], loss_fn="cross_entropy").result()
        self.training_client.optim_step(
            types.AdamParams(learning_rate=self.learning_rate, beta1=0.9, beta2=0.95, eps=1e-8)
        ).result()

        # Refresh the sampling client so generation sees the updated weights.
        self.sampling_client = self.training_client.save_weights_and_get_sampling_client("current")
        self.train_steps += 1

        metrics = fwd_bwd.metrics or {}
        loss = metrics.get("loss:sum", metrics.get("loss", 0.0))
        num_tokens = int(sum(1 for w in weights if w > 0))
        return {"step": float(self.train_steps), "loss": float(loss), "tokens": float(num_tokens)}

    def _train_on_policy(self, prompt_messages: list[dict[str, Any]], *, scale: float) -> dict[str, float]:
        """On-policy distillation: sample student tokens; teacher (with policy) re-scores them."""
        if self.training_client is None or self.sampling_client is None:
            return {}

        # Keep policy fresh for teacher scoring (policy may be edited between turns).
        self.reload_policy()

        student_prompt = self.renderer.build_generation_prompt(self._student_prefix + prompt_messages)
        student_prompt_len = student_prompt.length

        # 1) Sample a rollout from the student (no policy prompt).
        response = self.sampling_client.sample(
            prompt=student_prompt,
            num_samples=1,
            sampling_params=types.SamplingParams(
                stop=self.renderer.get_stop_sequences(),
                max_tokens=min(self.train_max_tokens, self.max_tokens),
                temperature=self.train_temperature,
            ),
        ).result()
        seq = response.sequences[0]
        tokens = seq.tokens
        student_logps = seq.logprobs or []
        if not tokens or not student_logps:
            # Nothing to score (e.g. sampler returned no logprobs).
            return {}

        # 2) Compute teacher logprobs on the student's exact tokens (teacher has policy prompt).
        teacher_prompt = self.renderer.build_generation_prompt(self._teacher_prefix + prompt_messages)
        teacher_prompt_len = teacher_prompt.length
        full_teacher_input = teacher_prompt.append(types.EncodedTextChunk(tokens=tokens))
        teacher_logprobs_full = self.sampling_client.compute_logprobs(full_teacher_input).result()

        # compute_logprobs is shifted by 1 (logprob[i] predicts token[i+1]).
        teacher_logps = teacher_logprobs_full[
            teacher_prompt_len - 1 : teacher_prompt_len - 1 + len(tokens)
        ]

        # Truncate all three sequences to a common length to guard against
        # off-by-one mismatches between samplers.
        min_len = min(len(tokens), len(student_logps), len(teacher_logps))
        if min_len <= 0:
            return {}
        tokens = tokens[:min_len]
        student_logps = student_logps[:min_len]
        teacher_logps = teacher_logps[:min_len]

        # Reverse KL per token = student - teacher. We optimize negative reverse KL:
        # advantage[t] = scale * kl_coef * (teacher_lp - student_lp)
        advantages: list[float] = []
        total_kl = 0.0
        for s_lp, t_lp in zip(student_logps, teacher_logps):
            total_kl += (s_lp - t_lp)
            advantages.append(float(scale) * float(self.kl_coef) * (t_lp - s_lp))

        # 3) Build datum for importance_sampling.
        # Prompt positions get zero advantage/logprob padding; only the
        # sampled continuation contributes to the loss.
        full_input = student_prompt.append(types.EncodedTextChunk(tokens=tokens[:-1]))
        target_tokens = [0] * (student_prompt_len - 1) + list(tokens)
        padded_logprobs = [0.0] * (student_prompt_len - 1) + list(student_logps)
        padded_advantages = [0.0] * (student_prompt_len - 1) + advantages

        datum = types.Datum(
            model_input=full_input,
            loss_fn_inputs={
                "target_tokens": TensorData.from_torch(torch.tensor(target_tokens)),
                "logprobs": TensorData.from_torch(torch.tensor(padded_logprobs)),
                "advantages": TensorData.from_torch(torch.tensor(padded_advantages)),
            },
        )

        fwd_bwd = self.training_client.forward_backward([datum], loss_fn="importance_sampling").result()
        self.training_client.optim_step(
            types.AdamParams(learning_rate=self.learning_rate, beta1=0.9, beta2=0.95, eps=1e-8)
        ).result()

        # Post-update drift metrics (PPO-style sanity check)
        approx_kl = 0.0
        ratio_mean = 1.0
        try:
            fwd_after = self.training_client.forward([datum], loss_fn="importance_sampling").result()
            out = fwd_after.loss_fn_outputs[0] if fwd_after.loss_fn_outputs else {}
            td = out.get("logprobs") if isinstance(out, dict) else None
            if td is not None:
                new_lp = td.to_torch().flatten().to(torch.float32)
                old_lp = datum.loss_fn_inputs["logprobs"].to_torch().flatten().to(torch.float32)
                adv = datum.loss_fn_inputs["advantages"].to_torch().flatten().to(torch.float32)
                # Only continuation tokens carry nonzero advantages (padding is 0).
                mask = adv != 0
                if mask.any():
                    new_cat = new_lp[mask]
                    old_cat = old_lp[mask]
                    approx_kl = (old_cat - new_cat).mean().item()
                    ratio_mean = torch.exp(new_cat - old_cat).mean().item()
        except Exception:
            # Metrics are best-effort; never fail the training step over them.
            pass

        self.sampling_client = self.training_client.save_weights_and_get_sampling_client("current")
        self.train_steps += 1

        metrics = fwd_bwd.metrics or {}
        loss = metrics.get("loss:sum", metrics.get("loss", 0.0))
        kl_st = total_kl / max(1, min_len)
        adv_mean = sum(advantages) / max(1, len(advantages))
        return {
            "step": float(self.train_steps),
            "loss": float(loss),
            "tokens": float(min_len),
            "kl_student_teacher": float(kl_st),
            "adv_mean": float(adv_mean),
            "approx_kl": float(approx_kl),
            "ratio": float(ratio_mean),
        }

    @property
    def teacher_prefix_messages(self) -> list[dict[str, Any]]:
        """Copy of the teacher (policy-bearing) prefix messages."""
        return list(self._teacher_prefix)

    @property
    def student_prefix_messages(self) -> list[dict[str, Any]]:
        """Copy of the student (policy-free) prefix messages."""
        return list(self._student_prefix)

    def add_user(self, content: str) -> None:
        """Append a user message to the conversation."""
        self.messages.append({"role": "user", "content": content})

    def add_assistant(self, msg: dict[str, Any]) -> None:
        """Append an assistant message (as returned by sample()) verbatim."""
        self.messages.append(msg)

    def add_tool_result(self, tool_call_id: str, result: str) -> None:
        """Append a tool-result message linked to a prior tool call."""
        self.messages.append({"role": "tool", "tool_call_id": tool_call_id, "content": result})

    def clear(self) -> None:
        """Drop all conversation messages (prefixes are untouched)."""
        self.messages.clear()

    def save_checkpoint(self, path: str) -> bool:
        """Save current LoRA weights to disk. Returns True on success."""
        if self.training_client is None:
            return False
        try:
            self.training_client.save_state(path).result()
            return True
        except Exception:
            # Best-effort: the caller surfaces failure via the bool.
            return False

    def load_checkpoint(self, path: str) -> bool:
        """Load LoRA weights from disk. Returns True on success."""
        if self.training_client is None:
            return False
        try:
            self.training_client.load_state(path).result()
            # Update sampling client to use loaded weights
            self.sampling_client = self.training_client.save_weights_and_get_sampling_client("current")
            return True
        except Exception:
            # Best-effort: the caller surfaces failure via the bool.
            return False
|
continualcode/tools.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Core tool implementations and schemas shared by tinkercode harnesses.
|
|
4
|
+
|
|
5
|
+
Tool names intentionally match the "Claude Code"-style set:
|
|
6
|
+
read, write, edit, glob, grep, bash
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import glob as globlib
|
|
12
|
+
import os
|
|
13
|
+
import re
|
|
14
|
+
import subprocess
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
from typing import Any
|
|
17
|
+
|
|
18
|
+
try:
|
|
19
|
+
# Only required by the Tinker renderer; keep optional for simple imports.
|
|
20
|
+
from tinker_cookbook.renderers import ToolSpec
|
|
21
|
+
except Exception: # pragma: no cover
|
|
22
|
+
ToolSpec = dict # type: ignore[misc,assignment]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def tool_read(args: dict[str, Any]) -> str:
    """Return a slice of a text file, one `NNNN| text` row per line.

    Args (in *args* dict):
        path: file to read (required).
        offset: 0-indexed first line to include (default 0).
        limit: maximum number of lines (default: all remaining).

    Returns an "error: ..." string when the file does not exist.
    """
    target = args["path"]
    if not os.path.isfile(target):
        return f"error: file not found: {target}"

    try:
        with open(target, "r", encoding="utf-8") as fh:
            all_lines = fh.readlines()
    except UnicodeDecodeError:
        # Retry leniently so non-UTF-8 files still produce output.
        with open(target, "r", encoding="utf-8", errors="replace") as fh:
            all_lines = fh.readlines()

    start = max(0, int(args.get("offset", 0) or 0))
    raw_limit = args.get("limit")
    count = len(all_lines) if raw_limit is None else max(0, int(raw_limit))

    numbered = [
        f"{start + i + 1:4}| {text}"
        for i, text in enumerate(all_lines[start : start + count])
    ]
    return "".join(numbered)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def tool_write(args: dict[str, Any]) -> str:
    """Write `content` to `path`, creating parent directories as needed.

    Fix: `~` is now expanded once and used for BOTH the mkdir and the file
    write. The previous version expanded it only for the parent mkdir, so
    writing to a "~/..." path created the real parent directory but then
    tried to open a literal "~" path and failed.

    Returns "ok" on success.
    """
    target = Path(args["path"]).expanduser()
    content = args["content"]
    target.parent.mkdir(parents=True, exist_ok=True)
    with open(target, "w", encoding="utf-8") as f:
        f.write(content)
    return "ok"
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def tool_edit(args: dict[str, Any]) -> str:
    """Replace `old` with `new` inside the file at `path`.

    Without all=true, `old` must occur exactly once. Returns "ok" on
    success or an "error: ..." string otherwise.
    """
    target = args["path"]
    if not os.path.isfile(target):
        return f"error: file not found: {target}"

    try:
        with open(target, "r", encoding="utf-8") as fh:
            original = fh.read()
    except UnicodeDecodeError:
        # Fall back to lossy decoding rather than failing outright.
        with open(target, "r", encoding="utf-8", errors="replace") as fh:
            original = fh.read()

    old = args["old"]
    new = args["new"]
    everywhere = bool(args.get("all", False))

    occurrences = original.count(old)
    if occurrences == 0:
        return "error: old_string not found"
    if occurrences > 1 and not everywhere:
        return f"error: old_string appears {occurrences} times, must be unique (use all=true)"

    # count=-1 means "replace all"; otherwise replace only the first hit.
    updated = original.replace(old, new, -1 if everywhere else 1)
    with open(target, "w", encoding="utf-8") as fh:
        fh.write(updated)
    return "ok"
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def tool_glob(args: dict[str, Any]) -> str:
    """List up to 50 paths matching a glob pattern, newest files first.

    `pat` is joined onto the optional base `path` (default "."). Directories
    sort after files (they get mtime 0). Returns "none" when nothing matches.
    """
    root = args.get("path", ".") or "."
    full_pattern = f"{root}/{args['pat']}".replace("//", "/")

    def newest_first_key(entry: str) -> float:
        # Directories (and vanished entries) sink to the bottom.
        return os.path.getmtime(entry) if os.path.isfile(entry) else 0

    matches = globlib.glob(full_pattern, recursive=True)
    matches.sort(key=newest_first_key, reverse=True)
    return "\n".join(matches[:50]) or "none"
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def tool_grep(args: dict[str, Any]) -> str:
    """Regex-search all files under a base path.

    Emits up to 50 hits as `file:line:text` (1-indexed lines, trailing
    whitespace stripped). Returns "none" when nothing matches and an
    "error: ..." string for an invalid pattern.
    """
    try:
        rx = re.compile(args["pat"])
    except re.error:
        return "error: invalid regex pattern"

    root = args.get("path", ".") or "."
    found: list[str] = []

    for candidate in globlib.glob(f"{root}/**", recursive=True):
        if not os.path.isfile(candidate):
            continue
        try:
            with open(candidate, "r", encoding="utf-8", errors="ignore") as fh:
                for num, text in enumerate(fh, 1):
                    if rx.search(text):
                        found.append(f"{candidate}:{num}:{text.rstrip()}")
                        if len(found) >= 50:
                            # Cap output early; no need to scan further.
                            return "\n".join(found)
        except Exception:
            # Skip unreadable entries (permissions, fifos, races).
            continue

    return "\n".join(found) or "none"
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def tool_bash(args: dict[str, Any]) -> str:
    """Run a shell command and return its combined stdout+stderr.

    Generalized: an optional "timeout" argument (seconds) replaces the
    hard-coded 30s limit; the default remains 30, so existing callers see
    identical behavior, including the timeout error message.

    NOTE: runs through the shell (`shell=True`) by design — the harness
    gates this tool behind explicit human approval.
    """
    cmd = args["cmd"]
    timeout_s = int(args.get("timeout", 30) or 30)
    try:
        result = subprocess.run(
            cmd, shell=True, capture_output=True, text=True, timeout=timeout_s
        )
        output = (result.stdout or "") + (result.stderr or "")
        return output.strip() or "(empty output)"
    except subprocess.TimeoutExpired:
        return f"error: command timed out after {timeout_s}s"
    except Exception as e:
        return f"error: {e}"
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
# Dispatch table mapping tool names (as they appear in model tool calls)
# to their implementations. Consumed by execute_tool().
TOOL_FUNCTIONS: dict[str, Any] = {
    "read": tool_read,
    "write": tool_write,
    "edit": tool_edit,
    "glob": tool_glob,
    "grep": tool_grep,
    "bash": tool_bash,
}
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def execute_tool(name: str, args: dict[str, Any]) -> str:
    """Dispatch a tool call by name; every failure becomes an error string.

    Unknown names and implementation exceptions are both reported as
    "error: ..." strings rather than raised — this is the tool boundary.
    """
    try:
        fn = TOOL_FUNCTIONS[name]
    except KeyError:
        return f"error: unknown tool '{name}'"
    try:
        return fn(args)
    except Exception as e:
        return f"error: {e}"
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
# JSON-schema-style tool declarations handed to the renderer so the model
# can emit structured tool calls. Each entry's "name" must match a key in
# TOOL_FUNCTIONS, and "parameters" follows the JSON Schema object form.
TOOL_SPECS: list[ToolSpec] = [
    {
        "name": "read",
        "description": "Read a file and return its contents with line numbers",
        "parameters": {
            "type": "object",
            "properties": {
                "path": {"type": "string", "description": "Path to the file to read"},
                "offset": {"type": "integer", "description": "Line number to start from (0-indexed)"},
                "limit": {"type": "integer", "description": "Maximum number of lines to read"},
            },
            "required": ["path"],
        },
    },
    {
        "name": "write",
        "description": "Write content to a file, creating it if it doesn't exist",
        "parameters": {
            "type": "object",
            "properties": {
                "path": {"type": "string", "description": "Path to the file to write"},
                "content": {"type": "string", "description": "Content to write to the file"},
            },
            "required": ["path", "content"],
        },
    },
    {
        "name": "edit",
        "description": "Edit a file by replacing old with new. old must appear once unless all=true.",
        "parameters": {
            "type": "object",
            "properties": {
                "path": {"type": "string", "description": "Path to the file to edit"},
                "old": {"type": "string", "description": "The exact string to find and replace"},
                "new": {"type": "string", "description": "The string to replace it with"},
                "all": {"type": "boolean", "description": "Replace all occurrences (default false)"},
            },
            "required": ["path", "old", "new"],
        },
    },
    {
        "name": "glob",
        "description": "Find files matching a glob pattern",
        "parameters": {
            "type": "object",
            "properties": {
                "pat": {"type": "string", "description": "Glob pattern (e.g., '**/*.py')"},
                "path": {"type": "string", "description": "Base path to search from"},
            },
            "required": ["pat"],
        },
    },
    {
        "name": "grep",
        "description": "Search for a regex pattern in files",
        "parameters": {
            "type": "object",
            "properties": {
                "pat": {"type": "string", "description": "Regex pattern to search for"},
                "path": {"type": "string", "description": "Base path to search from"},
            },
            "required": ["pat"],
        },
    },
    {
        "name": "bash",
        "description": "Run a shell command",
        "parameters": {
            "type": "object",
            "properties": {
                "cmd": {"type": "string", "description": "The shell command to run"},
            },
            "required": ["cmd"],
        },
    },
]


# Tools that never mutate state; harnesses may auto-approve these without
# asking the human.
READONLY_TOOLS = {"read", "glob", "grep"}
|
|
230
|
+
|