aleph-rlm 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aleph/core.py ADDED
@@ -0,0 +1,874 @@
1
+ """Aleph core implementation.
2
+
3
+ Aleph is a production-oriented implementation of Recursive Language Models
4
+ (RLMs): instead of stuffing massive context into an LLM prompt, Aleph stores the
5
+ context as a variable in a sandboxed REPL and lets the model write code to
6
+ inspect, search, and chunk that context.
7
+
8
+ The root LLM runs a loop:
9
+ 1. It produces either Python code (```python) or a final answer (FINAL(...)).
10
+ 2. Aleph executes the code in the REPL, captures output, and feeds the output back.
11
+ 3. The model iterates until it emits FINAL(answer) or FINAL_VAR(name).
12
+
13
+ Sub-queries are supported via `sub_query(...)`, and deeper recursion via
14
+ `sub_aleph(...)`; both are available as functions inside the REPL.
15
+ """
16
+
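For orientation, a minimal usage sketch of the class defined in this file. The constructor arguments and the `complete_sync` call mirror the signatures below; the import path follows the module path shown in this diff, and the input file and query are purely illustrative.

from aleph.core import Aleph

aleph = Aleph(provider="anthropic", root_model="claude-sonnet-4-20250514")

# The large context is held as a REPL variable (default name "ctx"), not pasted into the prompt.
with open("report.txt", encoding="utf-8") as f:   # illustrative input file
    long_report = f.read()

result = aleph.complete_sync("What were the key risks identified in Q3?", long_report)
print(result.success, result.answer)
print(result.total_tokens, result.total_cost_usd)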
17
+ from __future__ import annotations
18
+
19
+ import asyncio
20
+ import hashlib
21
+ import json
22
+ import re
23
+ import time
24
+ from datetime import datetime
25
+ from typing import Awaitable, Callable, cast
26
+
27
+ from .types import (
28
+ ActionType,
29
+ AlephResponse,
30
+ Budget,
31
+ BudgetStatus,
32
+ ContextCollection,
33
+ ContextMetadata,
34
+ ContextType,
35
+ ContentFormat,
36
+ ExecutionResult,
37
+ ParsedAction,
38
+ SubQueryResult,
39
+ SubQueryFn,
40
+ TrajectoryStep,
41
+ Message,
42
+ SubAlephFn,
43
+ )
44
+ from .providers.base import LLMProvider, ProviderError
45
+ from .providers.registry import get_provider
46
+ from .cache.memory import MemoryCache
47
+ from .repl.sandbox import REPLEnvironment, SandboxConfig
48
+ from .prompts.system import DEFAULT_SYSTEM_PROMPT
49
+
50
+
51
+ _FINAL_RE = re.compile(r"FINAL\((.*?)\)", re.DOTALL)
52
+ _FINAL_VAR_RE = re.compile(r"FINAL_VAR\((.*?)\)", re.DOTALL)
53
+ _CODE_BLOCK_RE = re.compile(r"```(?:python)?\s*\n(.*?)```", re.DOTALL)
54
+
55
+
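The three regular expressions above define the root model's action protocol. A small illustration of what each one extracts, using invented sample model outputs (note that `_parse_response` below checks FINAL_VAR before FINAL, so `FINAL_VAR(name)` is never misread as a plain FINAL answer):

code_reply = "Let me inspect the context.\n```python\nprint(len(ctx))\n```"
print(_CODE_BLOCK_RE.search(code_reply).group(1).strip())        # -> print(len(ctx))

print(_FINAL_RE.search("FINAL(The answer is 42.)").group(1))     # -> The answer is 42.
print(_FINAL_VAR_RE.search("FINAL_VAR(summary)").group(1))       # -> summary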
56
+ class Aleph:
57
+ """Recursive Language Model runner."""
58
+
59
+ def __init__(
60
+ self,
61
+ provider: LLMProvider | str = "anthropic",
62
+ root_model: str = "claude-sonnet-4-20250514",
63
+ sub_model: str | None = None,
64
+ budget: Budget | None = None,
65
+ sandbox_config: SandboxConfig | None = None,
66
+ system_prompt: str | None = None,
67
+ enable_caching: bool = True,
68
+ log_trajectory: bool = True,
69
+ context_var_name: str = "ctx",
70
+ ) -> None:
71
+ """Create an Aleph runner.
72
+
73
+ Args:
74
+ provider: LLM provider instance or provider name.
75
+ root_model: Model used for the root loop.
76
+ sub_model: Model used for sub-queries/sub-aleph (defaults to root_model).
77
+ budget: Resource limits (tokens/iterations/depth/wall-time/sub-queries).
78
+ sandbox_config: REPL sandbox limits and allowed imports.
79
+ system_prompt: Custom system prompt template.
80
+ enable_caching: Enable memoization for sub-queries.
81
+ log_trajectory: Record a full trajectory in the response.
82
+ context_var_name: Variable name used to expose context in the REPL.
83
+ """
84
+ if isinstance(provider, str):
85
+ self.provider = get_provider(provider)
86
+ else:
87
+ self.provider = provider
88
+
89
+ self.root_model = root_model
90
+ self.sub_model = sub_model or root_model
91
+ self.budget = budget or Budget()
92
+ self.sandbox_config = sandbox_config or SandboxConfig()
93
+ self.system_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT
94
+ self.enable_caching = enable_caching
95
+ self.log_trajectory = log_trajectory
96
+ self.context_var_name = context_var_name
97
+
98
+ self._cache: MemoryCache[str] | None = MemoryCache() if enable_caching else None
99
+
100
+ async def complete(self, query: str, context: ContextType, **kwargs: object) -> AlephResponse:
101
+ """Answer `query` using `context` via the RLM loop."""
102
+
103
+ # Allow per-call overrides
104
+ root_model = cast(str, kwargs.get("root_model", self.root_model))
105
+ sub_model = cast(str, kwargs.get("sub_model", self.sub_model))
106
+
107
+ budget_obj = kwargs.get("budget", self.budget)
108
+ budget = budget_obj if isinstance(budget_obj, Budget) else self.budget
109
+
110
+ temperature_obj = kwargs.get("temperature", 0.0)
111
+ if isinstance(temperature_obj, (int, float)):
112
+ temperature = float(temperature_obj)
113
+ elif isinstance(temperature_obj, str):
114
+ try:
115
+ temperature = float(temperature_obj)
116
+ except ValueError:
117
+ temperature = 0.0
118
+ else:
119
+ temperature = 0.0
120
+
121
+ start_time = time.time()
122
+ budget_status = BudgetStatus(depth_current=0)
123
+ trajectory: list[TrajectoryStep] = []
124
+
125
+ # Global step counter across root steps + subcalls
126
+ step_counter = 0
127
+ step_lock = asyncio.Lock()
128
+
129
+ # A helper to allocate a new step number in a concurrency-safe way
130
+ async def next_step_number() -> int:
131
+ nonlocal step_counter
132
+ async with step_lock:
133
+ step_counter += 1
134
+ return step_counter
135
+
136
+ # Run the root call
137
+ response = await self._run(
138
+ query=query,
139
+ context=context,
140
+ depth=0,
141
+ root_model=root_model,
142
+ sub_model=sub_model,
143
+ budget=budget,
144
+ budget_status=budget_status,
145
+ start_time=start_time,
146
+ trajectory=trajectory,
147
+ temperature=temperature,
148
+ next_step_number=next_step_number,
149
+ )
150
+
151
+ # Stats are accumulated in budget_status by _run and already reflected in the response.
152
+ return response
153
+
154
+ def complete_sync(self, query: str, context: ContextType, **kwargs: object) -> AlephResponse:
155
+ """Synchronous wrapper around :meth:`complete`.
156
+
157
+ Note: cannot be called from within an existing asyncio event loop.
158
+ """
159
+
160
+ try:
161
+ loop = asyncio.get_running_loop()
162
+ except RuntimeError:
163
+ loop = None
164
+
165
+ if loop is not None and loop.is_running():
166
+ raise RuntimeError("complete_sync() cannot be called from a running event loop. Use `await aleph.complete(...)`.")
167
+
168
+ return asyncio.run(self.complete(query, context, **kwargs))
169
+
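As the docstring above notes, `complete_sync` refuses to run inside an already-running event loop; from async code the coroutine API is used directly. A hedged sketch, where the document and question arguments are placeholders:

import asyncio

async def handle_request(document: str, question: str) -> str:
    aleph = Aleph()                                   # default provider and models
    resp = await aleph.complete(question, document)   # use the coroutine inside async code
    return resp.answer

# Outside any event loop, the blocking wrapper is fine instead:
# answer = Aleph().complete_sync(question, document).answer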
170
+ # ---------------------------------------------------------------------
171
+ # Internal execution
172
+ # ---------------------------------------------------------------------
173
+
174
+ async def _run(
175
+ self,
176
+ query: str,
177
+ context: ContextType,
178
+ depth: int,
179
+ root_model: str,
180
+ sub_model: str,
181
+ budget: Budget,
182
+ budget_status: BudgetStatus,
183
+ start_time: float,
184
+ trajectory: list[TrajectoryStep],
185
+ temperature: float,
186
+ next_step_number: Callable[[], Awaitable[int]],
187
+ ) -> AlephResponse:
188
+ """Internal runner used for recursion."""
189
+
190
+ # Depth check early
191
+ if budget.max_depth is not None and depth > budget.max_depth:
192
+ return AlephResponse(
193
+ answer="",
194
+ success=False,
195
+ total_iterations=0,
196
+ max_depth_reached=depth,
197
+ total_tokens=budget_status.tokens_used,
198
+ total_cost_usd=budget_status.cost_used,
199
+ wall_time_seconds=time.time() - start_time,
200
+ trajectory=trajectory,
201
+ error=f"Max depth exceeded: depth={depth} > max_depth={budget.max_depth}",
202
+ error_type="budget_exceeded",
203
+ )
204
+
205
+ budget_status.depth_current = max(budget_status.depth_current, depth)
206
+
207
+ # Analyze context and create REPL
208
+ meta = self._analyze_context(context)
209
+ loop = asyncio.get_running_loop()
210
+ repl = REPLEnvironment(
211
+ context=context,
212
+ context_var_name=self.context_var_name,
213
+ config=self.sandbox_config,
214
+ loop=loop,
215
+ )
216
+
217
+ # Inject sub_query and sub_aleph
218
+ repl.inject_sub_query(self._make_sub_query(
219
+ depth=depth,
220
+ sub_model=sub_model,
221
+ budget=budget,
222
+ budget_status=budget_status,
223
+ start_time=start_time,
224
+ trajectory=trajectory,
225
+ next_step_number=next_step_number,
226
+ temperature=temperature,
227
+ ))
228
+ repl.inject_sub_aleph(self._make_sub_aleph(
229
+ depth=depth,
230
+ root_model=root_model,
231
+ sub_model=sub_model,
232
+ budget=budget,
233
+ budget_status=budget_status,
234
+ start_time=start_time,
235
+ trajectory=trajectory,
236
+ temperature=temperature,
237
+ next_step_number=next_step_number,
238
+ ))
239
+
240
+ messages = self._build_initial_messages(query, meta)
241
+
242
+ max_iterations = budget.max_iterations or 100
243
+ local_iterations = 0
244
+
245
+ max_depth_reached = depth
246
+
247
+ while local_iterations < max_iterations:
248
+ local_iterations += 1
249
+ # Global iteration counter (across recursion)
250
+ budget_status.iterations_used += 1
251
+ budget_status.wall_time_used = time.time() - start_time
252
+
253
+ exceeded, reason = budget_status.exceeds(budget)
254
+ if exceeded:
255
+ return AlephResponse(
256
+ answer="",
257
+ success=False,
258
+ total_iterations=budget_status.iterations_used,
259
+ max_depth_reached=max_depth_reached,
260
+ total_tokens=budget_status.tokens_used,
261
+ total_cost_usd=budget_status.cost_used,
262
+ wall_time_seconds=time.time() - start_time,
263
+ trajectory=trajectory,
264
+ error=reason,
265
+ error_type="budget_exceeded",
266
+ )
267
+
268
+ # Keep messages within context window (best-effort)
269
+ self._trim_messages(messages, model=root_model)
270
+
271
+ # Calculate remaining wall-time for timeout enforcement
272
+ remaining_time: float | None = None
273
+ if budget.max_wall_time_seconds is not None:
274
+ remaining_time = budget.max_wall_time_seconds - budget_status.wall_time_used
275
+ if remaining_time <= 0:
276
+ return AlephResponse(
277
+ answer="",
278
+ success=False,
279
+ total_iterations=budget_status.iterations_used,
280
+ max_depth_reached=max_depth_reached,
281
+ total_tokens=budget_status.tokens_used,
282
+ total_cost_usd=budget_status.cost_used,
283
+ wall_time_seconds=time.time() - start_time,
284
+ trajectory=trajectory,
285
+ error="Wall-time budget exhausted before provider call",
286
+ error_type="budget_exceeded",
287
+ )
288
+
289
+ # Call provider with wall-time enforcement
290
+ try:
291
+ out_limit = self.provider.get_output_limit(root_model)
292
+ max_tokens = min(out_limit, 8192)
293
+ provider_coro = self.provider.complete(
294
+ messages=messages,
295
+ model=root_model,
296
+ max_tokens=max_tokens,
297
+ temperature=temperature,
298
+ )
299
+ if remaining_time is not None:
300
+ llm_text, in_tok, out_tok, cost = await asyncio.wait_for(
301
+ provider_coro, timeout=remaining_time
302
+ )
303
+ else:
304
+ llm_text, in_tok, out_tok, cost = await provider_coro
305
+ except asyncio.TimeoutError:
306
+ return AlephResponse(
307
+ answer="",
308
+ success=False,
309
+ total_iterations=budget_status.iterations_used,
310
+ max_depth_reached=max_depth_reached,
311
+ total_tokens=budget_status.tokens_used,
312
+ total_cost_usd=budget_status.cost_used,
313
+ wall_time_seconds=time.time() - start_time,
314
+ trajectory=trajectory,
315
+ error="Wall-time budget exceeded during provider call",
316
+ error_type="budget_exceeded",
317
+ )
318
+ except ProviderError as e:
319
+ return AlephResponse(
320
+ answer="",
321
+ success=False,
322
+ total_iterations=budget_status.iterations_used,
323
+ max_depth_reached=max_depth_reached,
324
+ total_tokens=budget_status.tokens_used,
325
+ total_cost_usd=budget_status.cost_used,
326
+ wall_time_seconds=time.time() - start_time,
327
+ trajectory=trajectory,
328
+ error=str(e),
329
+ error_type="provider_error",
330
+ )
331
+ except Exception as e:
332
+ return AlephResponse(
333
+ answer="",
334
+ success=False,
335
+ total_iterations=budget_status.iterations_used,
336
+ max_depth_reached=max_depth_reached,
337
+ total_tokens=budget_status.tokens_used,
338
+ total_cost_usd=budget_status.cost_used,
339
+ wall_time_seconds=time.time() - start_time,
340
+ trajectory=trajectory,
341
+ error=f"Unexpected provider error: {e}",
342
+ error_type="provider_error",
343
+ )
344
+
345
+ budget_status.tokens_used += int(in_tok + out_tok)
346
+ budget_status.cost_used += float(cost)
347
+ budget_status.wall_time_used = time.time() - start_time
348
+
349
+ # Stop immediately if the call pushed us over budget.
350
+ exceeded, reason = budget_status.exceeds(budget)
351
+ if exceeded:
352
+ if self.log_trajectory:
353
+ step_no = await next_step_number()
354
+ trajectory.append(
355
+ TrajectoryStep(
356
+ step_number=step_no,
357
+ depth=depth,
358
+ timestamp=datetime.now(),
359
+ prompt_tokens=int(in_tok),
360
+ prompt_summary=(messages[-1].get("content", "")[:500]),
361
+ action=self._parse_response(llm_text),
362
+ result="[BUDGET_EXCEEDED_AFTER_PROVIDER_CALL]",
363
+ result_tokens=int(out_tok),
364
+ cumulative_tokens=budget_status.tokens_used,
365
+ cumulative_cost=budget_status.cost_used,
366
+ )
367
+ )
368
+
369
+ return AlephResponse(
370
+ answer="",
371
+ success=False,
372
+ total_iterations=budget_status.iterations_used,
373
+ max_depth_reached=max_depth_reached,
374
+ total_tokens=budget_status.tokens_used,
375
+ total_cost_usd=budget_status.cost_used,
376
+ wall_time_seconds=time.time() - start_time,
377
+ trajectory=trajectory,
378
+ error=reason,
379
+ error_type="budget_exceeded",
380
+ )
381
+
382
+ action = self._parse_response(llm_text)
383
+
384
+ # FINAL(answer)
385
+ if action.action_type == ActionType.FINAL_ANSWER:
386
+ answer = self._extract_final(llm_text)
387
+ if self.log_trajectory:
388
+ step_no = await next_step_number()
389
+ trajectory.append(
390
+ TrajectoryStep(
391
+ step_number=step_no,
392
+ depth=depth,
393
+ timestamp=datetime.now(),
394
+ prompt_tokens=int(in_tok),
395
+ prompt_summary=(messages[-1].get("content", "")[:500]),
396
+ action=action,
397
+ result=answer,
398
+ result_tokens=int(out_tok),
399
+ cumulative_tokens=budget_status.tokens_used,
400
+ cumulative_cost=budget_status.cost_used,
401
+ )
402
+ )
403
+ return AlephResponse(
404
+ answer=answer,
405
+ success=True,
406
+ total_iterations=budget_status.iterations_used,
407
+ max_depth_reached=max_depth_reached,
408
+ total_tokens=budget_status.tokens_used,
409
+ total_cost_usd=budget_status.cost_used,
410
+ wall_time_seconds=time.time() - start_time,
411
+ trajectory=trajectory,
412
+ )
413
+
414
+ # FINAL_VAR(name)
415
+ if action.action_type == ActionType.FINAL_VAR:
416
+ var_name = self._extract_final_var(llm_text)
417
+ value = repl.get_variable(var_name)
418
+ if value is None:
419
+ answer = f"[ERROR: Variable '{var_name}' not found]"
420
+ else:
421
+ answer = str(value)
422
+ if self.log_trajectory:
423
+ step_no = await next_step_number()
424
+ trajectory.append(
425
+ TrajectoryStep(
426
+ step_number=step_no,
427
+ depth=depth,
428
+ timestamp=datetime.now(),
429
+ prompt_tokens=int(in_tok),
430
+ prompt_summary=(messages[-1].get("content", "")[:500]),
431
+ action=action,
432
+ result=answer,
433
+ result_tokens=int(out_tok),
434
+ cumulative_tokens=budget_status.tokens_used,
435
+ cumulative_cost=budget_status.cost_used,
436
+ )
437
+ )
438
+ return AlephResponse(
439
+ answer=answer,
440
+ success=True,
441
+ total_iterations=budget_status.iterations_used,
442
+ max_depth_reached=max_depth_reached,
443
+ total_tokens=budget_status.tokens_used,
444
+ total_cost_usd=budget_status.cost_used,
445
+ wall_time_seconds=time.time() - start_time,
446
+ trajectory=trajectory,
447
+ )
448
+
449
+ # CODE
450
+ if action.action_type == ActionType.CODE_BLOCK:
451
+ exec_result = await repl.execute_async(action.content)
452
+
453
+ # Log trajectory step
454
+ if self.log_trajectory:
455
+ step_no = await next_step_number()
456
+ result_text_for_tokens = exec_result.stdout
457
+ if exec_result.stderr:
458
+ result_text_for_tokens += "\n" + exec_result.stderr
459
+
460
+ trajectory.append(
461
+ TrajectoryStep(
462
+ step_number=step_no,
463
+ depth=depth,
464
+ timestamp=datetime.now(),
465
+ prompt_tokens=int(in_tok),
466
+ prompt_summary=(messages[-1].get("content", "")[:500]),
467
+ action=action,
468
+ result=exec_result,
469
+ result_tokens=self.provider.count_tokens(result_text_for_tokens, root_model),
470
+ cumulative_tokens=budget_status.tokens_used,
471
+ cumulative_cost=budget_status.cost_used,
472
+ )
473
+ )
474
+
475
+ # Add the assistant response and REPL output back into the conversation
476
+ messages.append({"role": "assistant", "content": llm_text})
477
+ messages.append(
478
+ {
479
+ "role": "user",
480
+ "content": self._format_repl_result(exec_result),
481
+ }
482
+ )
483
+
484
+ # Track depth reached if subcalls happened
485
+ max_depth_reached = max(max_depth_reached, budget_status.depth_current)
486
+ continue
487
+
488
+ # CONTINUE / unknown
489
+ if self.log_trajectory:
490
+ step_no = await next_step_number()
491
+ trajectory.append(
492
+ TrajectoryStep(
493
+ step_number=step_no,
494
+ depth=depth,
495
+ timestamp=datetime.now(),
496
+ prompt_tokens=int(in_tok),
497
+ prompt_summary=(messages[-1].get("content", "")[:500]),
498
+ action=action,
499
+ result="[CONTINUE]",
500
+ result_tokens=int(out_tok),
501
+ cumulative_tokens=budget_status.tokens_used,
502
+ cumulative_cost=budget_status.cost_used,
503
+ )
504
+ )
505
+ messages.append({"role": "assistant", "content": llm_text})
506
+ messages.append(
507
+ {
508
+ "role": "user",
509
+ "content": "Continue. When you have the answer, use FINAL(answer) or FINAL_VAR(variable_name).",
510
+ }
511
+ )
512
+
513
+ max_depth_reached = max(max_depth_reached, budget_status.depth_current)
514
+
515
+ return AlephResponse(
516
+ answer="",
517
+ success=False,
518
+ total_iterations=budget_status.iterations_used,
519
+ max_depth_reached=max_depth_reached,
520
+ total_tokens=budget_status.tokens_used,
521
+ total_cost_usd=budget_status.cost_used,
522
+ wall_time_seconds=time.time() - start_time,
523
+ trajectory=trajectory,
524
+ error="Max iterations reached without a final answer",
525
+ error_type="max_iterations",
526
+ )
527
+
528
+ # ---------------------------------------------------------------------
529
+ # Sub-calls (sub_query, sub_aleph)
530
+ # ---------------------------------------------------------------------
531
+
532
+ def _make_sub_query(
533
+ self,
534
+ depth: int,
535
+ sub_model: str,
536
+ budget: Budget,
537
+ budget_status: BudgetStatus,
538
+ start_time: float,
539
+ trajectory: list[TrajectoryStep],
540
+ next_step_number: Callable[[], Awaitable[int]],
541
+ temperature: float,
542
+ ) -> SubQueryFn:
543
+ """Create an async sub_query function for the REPL."""
544
+
545
+ async def sub_query(prompt: str, context_slice: str | None = None) -> str:
546
+ # Budget checks
547
+ budget_status.wall_time_used = time.time() - start_time
548
+
549
+ if budget.max_sub_queries is not None and budget_status.sub_queries_used >= budget.max_sub_queries:
550
+ return "[ERROR: Sub-query budget exceeded]"
551
+ if budget.max_depth is not None and (depth + 1) > budget.max_depth:
552
+ return "[ERROR: Max recursion depth reached]"
553
+
554
+ # Cache key
555
+ cache_key = None
556
+ if self._cache is not None:
557
+ h = hashlib.sha256()
558
+ h.update(sub_model.encode())
559
+ h.update(b"\0")
560
+ h.update(prompt.encode())
561
+ h.update(b"\0")
562
+ if context_slice:
563
+ h.update(context_slice.encode())
564
+ cache_key = f"subq:{h.hexdigest()}"
565
+ cached = self._cache.get(cache_key)
566
+ if isinstance(cached, str):
567
+ return cached
568
+
569
+ messages: list[Message] = [{"role": "user", "content": prompt}]
570
+ if context_slice:
571
+ messages[0]["content"] = f"{prompt}\n\nContext:\n{context_slice}"
572
+
573
+ out_limit = self.provider.get_output_limit(sub_model)
574
+ max_tokens = min(out_limit, 4096)
575
+
576
+ try:
577
+ text, in_tok, out_tok, cost = await self.provider.complete(
578
+ messages=messages,
579
+ model=sub_model,
580
+ max_tokens=max_tokens,
581
+ temperature=temperature,
582
+ )
583
+ except Exception as e:
584
+ return f"[ERROR: sub_query failed: {e}]"
585
+
586
+ budget_status.tokens_used += int(in_tok + out_tok)
587
+ budget_status.cost_used += float(cost)
588
+ budget_status.sub_queries_used += 1
589
+ budget_status.depth_current = max(budget_status.depth_current, depth + 1)
590
+ budget_status.wall_time_used = time.time() - start_time
591
+
592
+ exceeded, reason = budget_status.exceeds(budget)
593
+ if exceeded:
594
+ return f"[ERROR: Budget exceeded after sub_query: {reason}]"
595
+
596
+ # Log sub-query as trajectory step
597
+ if self.log_trajectory:
598
+ step_no = await next_step_number()
599
+ trajectory.append(
600
+ TrajectoryStep(
601
+ step_number=step_no,
602
+ depth=depth + 1,
603
+ timestamp=datetime.now(),
604
+ prompt_tokens=int(in_tok),
605
+ prompt_summary=prompt[:500],
606
+ action=ParsedAction(
607
+ action_type=ActionType.TOOL_CALL,
608
+ content="sub_query",
609
+ raw_response="sub_query(...)",
610
+ ),
611
+ result=SubQueryResult(
612
+ answer=text,
613
+ tokens_input=int(in_tok),
614
+ tokens_output=int(out_tok),
615
+ cost_usd=float(cost),
616
+ model_used=sub_model,
617
+ depth=depth + 1,
618
+ ),
619
+ result_tokens=int(out_tok),
620
+ cumulative_tokens=budget_status.tokens_used,
621
+ cumulative_cost=budget_status.cost_used,
622
+ )
623
+ )
624
+
625
+ if cache_key and self._cache is not None:
626
+ self._cache.set(cache_key, text)
627
+
628
+ return text
629
+
630
+ return sub_query
631
+
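Inside the sandboxed REPL, the injected `sub_query` lets the root model delegate a focused question over a slice of the context. A sketch of code the root model might emit, assuming the sandbox exposes the injected function as a plain call (how REPLEnvironment wraps the coroutine is defined elsewhere, not in this file); the slice bounds are illustrative:

# code emitted by the root model; `ctx` is the context variable injected by the sandbox
chunk = ctx[:20_000]                      # look at the first ~20k characters
summary = sub_query(
    "Summarize the main points of this excerpt.",
    context_slice=chunk,
)
print(summary)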
632
+ def _make_sub_aleph(
633
+ self,
634
+ depth: int,
635
+ root_model: str,
636
+ sub_model: str,
637
+ budget: Budget,
638
+ budget_status: BudgetStatus,
639
+ start_time: float,
640
+ trajectory: list[TrajectoryStep],
641
+ temperature: float,
642
+ next_step_number: Callable[[], Awaitable[int]],
643
+ ) -> SubAlephFn:
644
+ """Create an async sub_aleph function for the REPL."""
645
+
646
+ async def sub_aleph(query: str, context: ContextType | None = None) -> AlephResponse:
647
+ budget_status.wall_time_used = time.time() - start_time
648
+
649
+ if budget.max_depth is not None and (depth + 1) > budget.max_depth:
650
+ return AlephResponse(
651
+ answer="",
652
+ success=False,
653
+ total_iterations=0,
654
+ max_depth_reached=depth + 1,
655
+ total_tokens=budget_status.tokens_used,
656
+ total_cost_usd=budget_status.cost_used,
657
+ wall_time_seconds=time.time() - start_time,
658
+ trajectory=trajectory,
659
+ error="Max recursion depth reached",
660
+ error_type="budget_exceeded",
661
+ )
662
+
663
+ # Use provided context or default to empty string
664
+ sub_ctx: ContextType = context if context is not None else ""
665
+
666
+ resp = await self._run(
667
+ query=query,
668
+ context=sub_ctx,
669
+ depth=depth + 1,
670
+ root_model=root_model,
671
+ sub_model=sub_model,
672
+ budget=budget,
673
+ budget_status=budget_status,
674
+ start_time=start_time,
675
+ trajectory=trajectory,
676
+ temperature=temperature,
677
+ next_step_number=next_step_number,
678
+ )
679
+
680
+ return resp
681
+
682
+ return sub_aleph
683
+
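`sub_aleph` goes one step further: it launches a full recursive Aleph run at depth + 1 over a sub-context and returns an AlephResponse. A sketch of how the root model might use it, under the same assumption that the sandbox makes the injected function directly callable:

# recurse over one slice of a very large context; slice bounds are illustrative
part = ctx[100_000:200_000]
resp = sub_aleph("List every date mentioned in this section.", context=part)
print(resp.success, resp.answer)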
684
+ # ---------------------------------------------------------------------
685
+ # Prompting / parsing
686
+ # ---------------------------------------------------------------------
687
+
688
+ def _build_initial_messages(self, query: str, meta: ContextMetadata) -> list[Message]:
689
+ system = self.system_prompt.format(
690
+ context_var=self.context_var_name,
691
+ context_format=meta.format.value,
692
+ context_size_chars=meta.size_chars,
693
+ context_size_lines=meta.size_lines,
694
+ context_size_tokens=meta.size_tokens_estimate,
695
+ context_preview=meta.sample_preview,
696
+ structure_hint=meta.structure_hint or "N/A",
697
+ )
698
+ return [
699
+ {"role": "system", "content": system},
700
+ {"role": "user", "content": query},
701
+ ]
702
+
703
+ def _parse_response(self, text: str) -> ParsedAction:
704
+ if _FINAL_VAR_RE.search(text):
705
+ return ParsedAction(ActionType.FINAL_VAR, "", text)
706
+ if _FINAL_RE.search(text):
707
+ return ParsedAction(ActionType.FINAL_ANSWER, "", text)
708
+
709
+ m = _CODE_BLOCK_RE.search(text)
710
+ if m:
711
+ return ParsedAction(ActionType.CODE_BLOCK, m.group(1).strip(), text)
712
+
713
+ return ParsedAction(ActionType.CONTINUE, "", text)
714
+
715
+ def _extract_final(self, text: str) -> str:
716
+ m = _FINAL_RE.search(text)
717
+ if not m:
718
+ return text.strip()
719
+ return m.group(1).strip()
720
+
721
+ def _extract_final_var(self, text: str) -> str:
722
+ m = _FINAL_VAR_RE.search(text)
723
+ if not m:
724
+ return ""
725
+ raw = m.group(1).strip()
726
+ # Allow FINAL_VAR("name") or FINAL_VAR('name')
727
+ if len(raw) >= 2 and ((raw[0] == raw[-1] == '"') or (raw[0] == raw[-1] == "'")):
728
+ raw = raw[1:-1].strip()
729
+ return raw
730
+
731
+ def _format_repl_result(self, result: ExecutionResult) -> str:
732
+ parts: list[str] = []
733
+ if result.stdout:
734
+ parts.append(result.stdout)
735
+ if result.stderr:
736
+ parts.append("[STDERR]\n" + result.stderr)
737
+ if result.error and (not result.stderr):
738
+ parts.append("[ERROR]\n" + result.error)
739
+ if result.return_value is not None:
740
+ parts.append(f"[RETURN_VALUE]\n{result.return_value}")
741
+
742
+ out = "\n".join(parts).strip()
743
+ if not out:
744
+ out = "(no output)"
745
+ return f"```output\n{out}\n```"
746
+
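For reference, the feedback format produced above: an execution whose stdout is "3 matches found", with no stderr, error, or return value, is fed back to the model as:

```output
3 matches found
```

Stderr, errors, and expression results are appended under the [STDERR], [ERROR], and [RETURN_VALUE] markers, as the branches above show.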
747
+ def _trim_messages(self, messages: list[Message], model: str) -> None:
748
+ """Best-effort trimming to avoid blowing past the context limit."""
749
+
750
+ context_limit = self.provider.get_context_limit(model)
751
+ # Reserve room for the model's output.
752
+ reserve = min(self.provider.get_output_limit(model), 8192)
753
+ target = max(1_000, context_limit - reserve)
754
+
755
+ # Rough token count
756
+ def msg_tokens(ms: list[Message]) -> int:
757
+ return sum(self.provider.count_tokens(m.get("content", ""), model) for m in ms)
758
+
759
+ if msg_tokens(messages) <= target:
760
+ return
761
+
762
+ # Always keep the system message. Keep the last few turns.
763
+ system = messages[0:1]
764
+ tail = messages[1:][-8:]
765
+ pruned = system + tail
766
+ # If still too large, progressively drop the oldest remaining turns (always keeping the system message)
767
+ while len(pruned) > 2 and msg_tokens(pruned) > target:
768
+ pruned = system + pruned[2:]
769
+
770
+ messages.clear()
771
+ messages.extend(pruned)
772
+
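A worked example of the budget arithmetic above, using invented limits: for a model whose context window is 200,000 tokens and whose output limit is 8,192 tokens, reserve = min(8192, 8192) = 8,192 and target = max(1000, 200000 - 8192) = 191,808 tokens. Only when the estimated token count of the messages exceeds that target are older turns dropped, and the system message is always retained.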
773
+ # ---------------------------------------------------------------------
774
+ # Context analysis
775
+ # ---------------------------------------------------------------------
776
+
777
+ def _analyze_context(self, context: ContextType) -> ContextMetadata:
778
+ """Compute lightweight metadata for the root prompt."""
779
+
780
+ # Multi-doc collection
781
+ if isinstance(context, ContextCollection):
782
+ total_bytes = 0
783
+ total_chars = 0
784
+ total_lines = 0
785
+ preview = ""
786
+ structure = f"ContextCollection with {len(context.items)} items"
787
+
788
+ for i, (name, item) in enumerate(context.items):
789
+ item_meta = self._analyze_context(item)
790
+ total_bytes += item_meta.size_bytes
791
+ total_chars += item_meta.size_chars
792
+ total_lines += item_meta.size_lines
793
+ if i == 0:
794
+ preview = f"[{name}]\n" + item_meta.sample_preview
795
+
796
+ est_tokens = total_chars // 4 if total_chars else 0
797
+ context.total_size_bytes = total_bytes
798
+ context.total_size_tokens_estimate = est_tokens
799
+
800
+ return ContextMetadata(
801
+ format=ContentFormat.MIXED,
802
+ size_bytes=total_bytes,
803
+ size_chars=total_chars,
804
+ size_lines=total_lines,
805
+ size_tokens_estimate=est_tokens,
806
+ structure_hint=structure,
807
+ sample_preview=preview[:500],
808
+ )
809
+
810
+ # Plain text
811
+ if isinstance(context, str):
812
+ return ContextMetadata(
813
+ format=ContentFormat.TEXT,
814
+ size_bytes=len(context.encode("utf-8", errors="ignore")),
815
+ size_chars=len(context),
816
+ size_lines=context.count("\n") + 1,
817
+ size_tokens_estimate=max(1, len(context) // 4) if context else 0,
818
+ structure_hint=None,
819
+ sample_preview=context[:500],
820
+ )
821
+
822
+ # Bytes
823
+ if isinstance(context, (bytes, bytearray)):
824
+ b = bytes(context)
825
+ preview = b[:200].decode("utf-8", errors="replace")
826
+ return ContextMetadata(
827
+ format=ContentFormat.BINARY,
828
+ size_bytes=len(b),
829
+ size_chars=len(preview),
830
+ size_lines=preview.count("\n") + 1,
831
+ size_tokens_estimate=max(1, len(preview) // 4) if preview else 0,
832
+ structure_hint="binary payload (preview decoded as utf-8)",
833
+ sample_preview=preview[:500],
834
+ )
835
+
836
+ # JSON-like
837
+ if isinstance(context, dict):
838
+ text = json.dumps(context, indent=2, ensure_ascii=False)
839
+ keys = list(context.keys())
840
+ hint = f"JSON object with keys: {keys[:10]}" if keys else "JSON object"
841
+ return ContextMetadata(
842
+ format=ContentFormat.JSON,
843
+ size_bytes=len(text.encode("utf-8", errors="ignore")),
844
+ size_chars=len(text),
845
+ size_lines=text.count("\n") + 1,
846
+ size_tokens_estimate=max(1, len(text) // 4) if text else 0,
847
+ structure_hint=hint,
848
+ sample_preview=text[:500],
849
+ )
850
+
851
+ if isinstance(context, list):
852
+ text = json.dumps(context[:100], indent=2, ensure_ascii=False)
853
+ hint = f"JSON array (showing first {min(len(context), 100)} of {len(context)})"
854
+ return ContextMetadata(
855
+ format=ContentFormat.JSON,
856
+ size_bytes=len(text.encode("utf-8", errors="ignore")),
857
+ size_chars=len(text),
858
+ size_lines=text.count("\n") + 1,
859
+ size_tokens_estimate=max(1, len(text) // 4) if text else 0,
860
+ structure_hint=hint,
861
+ sample_preview=text[:500],
862
+ )
863
+
864
+ # Fallback
865
+ text = str(context)
866
+ return ContextMetadata(
867
+ format=ContentFormat.TEXT,
868
+ size_bytes=len(text.encode("utf-8", errors="ignore")),
869
+ size_chars=len(text),
870
+ size_lines=text.count("\n") + 1,
871
+ size_tokens_estimate=max(1, len(text) // 4) if text else 0,
872
+ structure_hint=f"Python object: {type(context).__name__}",
873
+ sample_preview=text[:500],
874
+ )