llmwalk-0.1.2.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,176 @@
+ # Created by https://www.toptal.com/developers/gitignore/api/python
+ # Edit at https://www.toptal.com/developers/gitignore?templates=python
+
+ ### Python ###
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ ### Python Patch ###
+ # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
+ poetry.toml
+
+ # ruff
+ .ruff_cache/
+
+ # LSP config files
+ pyrightconfig.json
+
+ # End of https://www.toptal.com/developers/gitignore/api/python
@@ -0,0 +1,12 @@
+ ## Project notes for Codex/agents
+
+ This repo is a single-file `uv` inline script.
+
+ - Entry point: `llmtree.py`
+ - Dependencies are declared in the `# /// script` block at the top of `llmtree.py`
+ - Prefer running via `uv` so dependencies resolve correctly:
+   - `uv run llmtree.py -- --help`
+   - `uv run llmtree.py -- -p "Your prompt here"`
+
+ Avoid using `pip install` directly for this repo unless you have a specific reason; the intended workflow is `uv run` with the inline dependency block.
+
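The notes above point at a `# /// script` block, i.e. inline script metadata (PEP 723), which `uv run` reads to build an isolated environment before executing the script. `llmtree.py` itself is not part of this archive, so the header below is only a sketch of what such a block looks like, reusing the pins from `pyproject.toml`:

```python
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "mlx-lm==0.28.4",
#     "rich==14.2.0",
#     "sortedcontainers==2.4.0",
# ]
# ///
#
# With a header like this in place, `uv run llmtree.py -- --help` resolves the
# pinned dependencies into a temporary environment and then runs the script.
```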
llmwalk-0.1.2/PKG-INFO ADDED
@@ -0,0 +1,42 @@
+ Metadata-Version: 2.4
+ Name: llmwalk
+ Version: 0.1.2
+ Summary: Explore the answer-space for any prompt and any MLX-supported model.
+ Requires-Python: >=3.10
+ Requires-Dist: mlx-lm==0.28.4
+ Requires-Dist: rich==14.2.0
+ Requires-Dist: sortedcontainers==2.4.0
+ Description-Content-Type: text/markdown
+
+ # llmwalk
+
+ Explore the answer-space for any prompt and any MLX-supported model. See
+ <https://huggingface.co/mlx-community/models> for supported models.
+
+ ![Usage example gif](example1.gif)
+
+ Instead of sampling from the possible tokens each step, llmwalk branches out
+ and completes all of the branches the sampler would consider based on
+ `--top-k`, `--top-p` and `--temperature`, ranking the results by probability
+ as it goes.
+
+ The tree is walked prioritising the most likely branches, until it finds `-n`
+ branches and then it stops. It doesn't enumerate all possibilities, just enough
+ to know for sure it has found the `-n` most likely branches.
+
+ ## Usage
+
+ - `uvx llmwalk -p "In what year was Barack Obama born?"`
+ - `uvx llmwalk -p "Write a haiku about compilers" -n 5`
+ - `uvx llmwalk -p "Give me one word: " --top-k 200 --temperature 0.7`
+
+ ## Options
+
+ - `-p, --prompt TEXT`: Prompt to score (wrapped with the model’s chat template).
+ - `-m, --model MODEL`: MLX-LM model identifier or path (default: `mlx-community/Llama-3.2-1B-Instruct-4bit`), supported models can be found at <https://huggingface.co/mlx-community/models>
+ - `-n N`: Number of answers to show. The search stops once it has `N` finished answers and no unfinished branch can beat the worst of those `N`.
+ - `--min-probability FLOAT`: Any branch whose cumulative probability falls below this is marked finished (`low_probability`) and not expanded further.
+ - `--top-k INT`: At each step, expand at most `k` next tokens (highest probability).
+ - `--top-p FLOAT`: Nucleus cutoff applied *within the top-k tokens* at each step (keep adding tokens until cumulative probability ≥ `p`).
+ - `--temperature FLOAT`: Softmax temperature applied when computing per-step probabilities (`1.0` is the model distribution; must be `> 0`).
+ - `--stats-interval SECONDS`: How often to refresh the live view (`<= 0` disables periodic refresh; still renders at start/end).
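The `--top-k`, `--top-p` and `--temperature` options above all act on the per-step token distribution. The snippet below is a minimal pure-Python sketch of that filtering (temperature-scaled softmax, keep at most `k` candidates, then cut once cumulative probability reaches `p`, always keeping at least one); the `llmwalk.cli` module later in this diff does the same thing with MLX arrays. The toy distribution and token strings are made up for illustration.

```python
import math


def step_candidates(logprobs: dict[str, float], top_k: int, top_p: float, temperature: float) -> list[tuple[str, float]]:
    """Toy per-step filter: temperature softmax, then top-k, then a top-p cutoff."""
    # Temperature-scaled softmax over the whole (toy) vocabulary.
    scaled = {tok: lp / temperature for tok, lp in logprobs.items()}
    z = sum(math.exp(v) for v in scaled.values())
    probs = {tok: math.exp(v) / z for tok, v in scaled.items()}

    # Keep the k most likely tokens, most likely first.
    ranked = sorted(probs.items(), key=lambda kv: -kv[1])[:top_k]

    # Nucleus cutoff within those k: stop once cumulative probability reaches p,
    # but always keep at least one candidate.
    kept: list[tuple[str, float]] = []
    cum = 0.0
    for token, p in ranked:
        if kept and cum >= top_p:
            break
        kept.append((token, p))
        cum += p
    return kept


toy = {"1961": math.log(0.55), "1964": math.log(0.30), "1960": math.log(0.10), " In": math.log(0.05)}
print(step_candidates(toy, top_k=3, top_p=0.8, temperature=1.0))
# Keeps "1961" and "1964": their cumulative probability (~0.85) reaches top_p after two tokens.
```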
@@ -0,0 +1,32 @@
+ # llmwalk
+
+ Explore the answer-space for any prompt and any MLX-supported model. See
+ <https://huggingface.co/mlx-community/models> for supported models.
+
+ ![Usage example gif](example1.gif)
+
+ Instead of sampling from the possible tokens each step, llmwalk branches out
+ and completes all of the branches the sampler would consider based on
+ `--top-k`, `--top-p` and `--temperature`, ranking the results by probability
+ as it goes.
+
+ The tree is walked prioritising the most likely branches, until it finds `-n`
+ branches and then it stops. It doesn't enumerate all possibilities, just enough
+ to know for sure it has found the `-n` most likely branches.
+
+ ## Usage
+
+ - `uvx llmwalk -p "In what year was Barack Obama born?"`
+ - `uvx llmwalk -p "Write a haiku about compilers" -n 5`
+ - `uvx llmwalk -p "Give me one word: " --top-k 200 --temperature 0.7`
+
+ ## Options
+
+ - `-p, --prompt TEXT`: Prompt to score (wrapped with the model’s chat template).
+ - `-m, --model MODEL`: MLX-LM model identifier or path (default: `mlx-community/Llama-3.2-1B-Instruct-4bit`), supported models can be found at <https://huggingface.co/mlx-community/models>
+ - `-n N`: Number of answers to show. The search stops once it has `N` finished answers and no unfinished branch can beat the worst of those `N`.
+ - `--min-probability FLOAT`: Any branch whose cumulative probability falls below this is marked finished (`low_probability`) and not expanded further.
+ - `--top-k INT`: At each step, expand at most `k` next tokens (highest probability).
+ - `--top-p FLOAT`: Nucleus cutoff applied *within the top-k tokens* at each step (keep adding tokens until cumulative probability ≥ `p`).
+ - `--temperature FLOAT`: Softmax temperature applied when computing per-step probabilities (`1.0` is the model distribution; must be `> 0`).
+ - `--stats-interval SECONDS`: How often to refresh the live view (`<= 0` disables periodic refresh; still renders at start/end).
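The search itself is a best-first walk: branches are expanded in order of cumulative probability, and the run ends once no unexpanded branch can still beat the `n`-th best finished answer. The sketch below illustrates that stopping rule with a hard-coded toy continuation table instead of a model; the real implementation in `llmwalk.cli` expands branches with MLX forward passes and per-branch KV caches, but the control flow is the same idea.

```python
import heapq

# Toy "model": each prefix maps to (token, probability) continuations; None ends an answer.
TOY = {
    "": [("4", 0.90), ("Four", 0.08), ("5", 0.02)],
    "4": [(None, 0.95), (".", 0.05)],
    "4.": [(None, 1.0)],
    "Four": [(None, 1.0)],
    "5": [(None, 1.0)],
}


def walk(n: int) -> list[tuple[str, float]]:
    frontier = [(-1.0, "")]  # max-heap on branch probability (negated)
    finished: list[tuple[float, str]] = []
    while frontier:
        neg_p, prefix = heapq.heappop(frontier)
        p = -neg_p
        # Stop once the best unexpanded branch cannot beat the n-th best finished answer.
        if len(finished) >= n and p < sorted(finished, reverse=True)[n - 1][0]:
            break
        for token, q in TOY[prefix]:
            if token is None:
                finished.append((p * q, prefix))
            else:
                heapq.heappush(frontier, (-(p * q), prefix + token))
    return [(text, prob) for prob, text in sorted(finished, reverse=True)[:n]]


print(walk(2))  # best two answers: "4" (~0.85) and "Four" (0.08)
```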
@@ -0,0 +1,3 @@
+ __all__ = ["__version__"]
+
+ __version__ = "0.1.0"
@@ -0,0 +1,6 @@
+ from __future__ import annotations
+
+ from .cli import main
+
+ if __name__ == "__main__":
+     main()
@@ -0,0 +1,412 @@
+ from __future__ import annotations
+
+ import argparse
+ import heapq
+ import sys
+ import time
+ from dataclasses import dataclass
+ from datetime import datetime
+ from functools import lru_cache
+
+ import mlx.core as mx
+ from mlx.nn import Module
+ from mlx_lm import load
+ from mlx_lm.models.cache import KVCache
+ from mlx_lm.tokenizer_utils import TokenizerWrapper
+ from rich.console import Console, Group
+ from rich.live import Live
+ from rich.style import Style
+ from rich.table import Table
+ from rich.text import Text
+ from sortedcontainers import SortedList
+
+ args: argparse.Namespace
+
+ _BAND_COLORS = [
+     "#7f7f7f", # 0-10%: grey
+     "#ff3b30", # 10-20%: red
+     "#ff6a00", # 20-30%: orange
+     "#ff8c00", # 30-40%: dark orange
+     "#ffb000", # 40-50%: amber
+     "#ffd000", # 50-60%: yellow
+     "#d7e500", # 60-70%: yellow-green
+     "#a8e600", # 70-80%: greenish
+     "#4cd964", # 80-90%: green
+     "#00c853", # 90-100%: bright green
+ ]
+ _BAND_STYLES = [Style(color=c) for c in _BAND_COLORS]
+
+
+ @dataclass
+ class OutputToken:
+     token: int
+     prob: float
+
+
+ @dataclass(eq=False)
+ class Branch:
+     parent: Branch | None
+     token: OutputToken | None
+     probability: float = 1.0
+     finish_reason: str | None = None
+     cache: list[KVCache] | None = None
+
+     def answer_tokens(self) -> list[OutputToken]:
+         toks: list[OutputToken] = []
+         cur: Branch | None = self
+         while cur is not None and cur.token is not None:
+             toks.append(cur.token)
+             cur = cur.parent
+         toks.reverse()
+         return toks
+
+
+ def _clone_kv_cache(c: KVCache) -> KVCache:
+     cloned = KVCache()
+     cloned.offset = c.offset
+     cloned.keys = mx.array(c.keys) if c.keys is not None else None
+     cloned.values = mx.array(c.values) if c.values is not None else None
+     return cloned
+
+
+ def _clone_prompt_cache(cache: list[KVCache]) -> list[KVCache]:
+     return [_clone_kv_cache(c) for c in cache]
+
+
+ def _top_tokens_from_logprobs(logprobs: mx.array) -> list[OutputToken]:
+     vocab = int(logprobs.shape[0])
+     k = min(args.top_k, vocab)
+     part = mx.argpartition(logprobs, vocab - k)
+     top_idx = part[vocab - k :]
+     top_lp = mx.take(logprobs, top_idx)
+     order = mx.argsort(top_lp)[::-1]
+     sorted_indices = mx.take(top_idx, order)
+
+     if args.temperature == 1.0:
+         probs = mx.exp(mx.take(logprobs, sorted_indices))
+     else:
+         lse = mx.logsumexp(logprobs / args.temperature, axis=-1)
+         probs = mx.exp(mx.take(logprobs, sorted_indices) / args.temperature - lse)
+
+     mx.eval(sorted_indices, probs)
+     token_ids: list[int] = sorted_indices.astype(mx.int64).tolist()
+     token_probs: list[float] = mx.reshape(probs, (-1,)).tolist()
+
+     output_tokens: list[OutputToken] = []
+     cum_prob = 0.0
+     for token_id, prob in zip(token_ids, token_probs): # type: ignore[call-arg]
+         if output_tokens and cum_prob >= args.top_p:
+             break
+         output_tokens.append(OutputToken(token=token_id, prob=float(prob)))
+         cum_prob += float(prob)
+
+     return output_tokens
+
+
+ class PromptTreeSearch:
+     model: Module
+     tokenizer: TokenizerWrapper
+     prompt: list[int]
+     _frontier: list[tuple[float, int, Branch]]
+     _finished_eos: SortedList[Branch]
+     _heap_counter: int = 0
+     _stopped: bool = False
+
+     tokens: int = 0
+     pruned: int = 0
+
+     _low_watermark: float | None = None
+     _start: datetime | None = None
+     _end: datetime | None = None
+
+     def __init__(self, model: Module, tokenizer: TokenizerWrapper, prompt: list[int]) -> None:
+         self.model = model
+         self.tokenizer = tokenizer
+         self.prompt = prompt
+         self._frontier = []
+         self._finished_eos = SortedList(key=lambda b: -b.probability)
+
+         root = Branch(parent=None, token=None)
+         self.branches = SortedList(key=lambda b: -b.probability)
+         self.branches.add(root)
+         self._push_frontier(root)
+
+     @lru_cache(maxsize=65536)
+     def decode_token(self, token_id: int) -> str:
+         return self.tokenizer.decode([token_id], skip_special_tokens=True) # type: ignore[call-arg]
+
+     def _run_model(self, cache: list[KVCache], input_ids: list[int]) -> mx.array:
+         inputs = mx.array([input_ids], mx.int32)
+         logits = self.model(inputs, cache=cache)[:, -1, :]
+         logits = logits.astype(mx.float32)
+         logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
+         return mx.reshape(logprobs, (-1,))
+
+     @property
+     def active(self) -> int:
+         return len(self._frontier)
+
+     def top_branches(self, n: int) -> list[Branch]:
+         return list(self.branches[:n])
+
+     def _push_frontier(self, branch: Branch) -> None:
+         self._heap_counter += 1
+         heapq.heappush(self._frontier, (-branch.probability, self._heap_counter, branch))
+
+     def _update_low_watermark(self) -> None:
+         if len(self._finished_eos) < args.n:
+             self._low_watermark = None
+             return
+         self._low_watermark = self._finished_eos[args.n - 1].probability
+
+     def stop(self) -> None:
+         self._stopped = True
+
+     def should_stop(self) -> bool:
+         if self._stopped:
+             return True
+         if not self._frontier:
+             return True
+         if self._low_watermark is None:
+             return False
+         best_prob = -self._frontier[0][0]
+         return best_prob < self._low_watermark
+
+     def begin(self) -> None:
+         self._start = datetime.now()
+
+     def step(self) -> None:
+         if self.should_stop():
+             return
+
+         _, _, branch = heapq.heappop(self._frontier)
+
+         if self._low_watermark is not None and branch.probability < self._low_watermark:
+             self.pruned += 1
+             branch.finish_reason = "pruned"
+             branch.cache = None
+             return
+
+         if branch.token is None: # root branch
+             cache_after = self.model.make_cache() if hasattr(self.model, "make_cache") else [] # type: ignore[assignment]
+             logprobs = self._run_model(cache_after, self.prompt)
+         else:
+             if branch.cache is None:
+                 raise RuntimeError("Branch cache missing while expanding leaf")
+             cache_after = _clone_prompt_cache(branch.cache)
+             logprobs = self._run_model(cache_after, [branch.token.token])
+
+         self.branches.remove(branch)
+
+         new_branches: list[Branch] = []
+         frontier_add: list[Branch] = []
+         eos_add: list[Branch] = []
+         for tok in _top_tokens_from_logprobs(logprobs):
+             new_prob = branch.probability * tok.prob
+
+             if new_prob < args.min_probability:
+                 self.pruned += 1
+                 new_branch = Branch(
+                     parent=branch,
+                     token=tok,
+                     probability=new_prob,
+                     finish_reason="low_probability",
+                 )
+                 new_branches.append(new_branch)
+                 continue
+
+             if tok.token in self.tokenizer.eos_token_ids:
+                 new_branch = Branch(
+                     parent=branch,
+                     token=tok,
+                     probability=new_prob,
+                     finish_reason="eos_token",
+                 )
+                 new_branches.append(new_branch)
+                 eos_add.append(new_branch)
+                 continue
+
+             new_branch = Branch(
+                 parent=branch,
+                 token=tok,
+                 probability=new_prob,
+                 cache=cache_after,
+             )
+             new_branches.append(new_branch)
+             frontier_add.append(new_branch)
+
+         for b in new_branches:
+             self.branches.add(b)
+         for b in frontier_add:
+             self._push_frontier(b)
+         for b in eos_add:
+             self._finished_eos.add(b)
+         if eos_add:
+             self._update_low_watermark()
+         self.tokens += 1
+
+         branch.cache = None
+
+
+ def style_for_token_probability(prob: float) -> Style:
+     if prob != prob: # NaN
+         prob = 0.0
+     elif prob < 0.0:
+         prob = 0.0
+     elif prob > 1.0:
+         prob = 1.0
+
+     band = min(int(prob * 10), 9) # 0..9
+     return _BAND_STYLES[band]
+
+
+ def render_probability_legend() -> Text:
+     legend = Text("Legend: ", style="bold", no_wrap=True, overflow="ellipsis")
+     for i in range(9, -1, -1):
+         style = style_for_token_probability((i + 0.5) / 10)
+         if i == 9:
+             label = "90%+"
+         elif i == 0:
+             label = "0–10%"
+         else:
+             label = f"{i * 10}%+"
+
+         legend.append("■", style=style)
+         legend.append(f" {label}")
+         if i != 0:
+             legend.append(" ")
+
+     return legend
+
+
+ _PROBABILITY_LEGEND = render_probability_legend()
+
+
+ def render_branches(walker: PromptTreeSearch) -> Table:
+     table = Table(expand=True)
+     table.add_column("Fin", justify="center", no_wrap=True, width=3)
+     table.add_column("Prob.", justify="right", no_wrap=True, width=8)
+     table.add_column("Answer", ratio=1)
+
+     branches = walker.top_branches(args.n)
+     for i in range(args.n):
+         if i >= len(branches):
+             table.add_row("", "", "")
+             continue
+
+         branch = branches[i]
+         answer_text = Text()
+         for tok in branch.answer_tokens():
+             piece = walker.decode_token(tok.token)
+             if not piece:
+                 continue
+             answer_text.append(piece, style=style_for_token_probability(tok.prob))
+         probability_text = f"{branch.probability * 100:6.2f}%"
+         status: Text
+         if branch.finish_reason == "eos_token":
+             status = Text("✓", style="green")
+         elif branch.finish_reason == "low_probability":
+             status = Text("✓", style="yellow")
+         elif branch.finish_reason == "pruned":
+             status = Text("✓", style="dim")
+         else:
+             status = Text("")
+
+         table.add_row(status, probability_text, answer_text)
+
+     return table
+
+
+ def render_stats_bar(walker: PromptTreeSearch) -> Table:
+     elapsed = (datetime.now() - walker._start).total_seconds() if walker._start else 0.0
+     tps = walker.tokens / elapsed if elapsed > 0 else 0.0
+     left = f"active {walker.active} pruned {walker.pruned} tps {tps:0.1f}"
+     grid = Table.grid(expand=True)
+     grid.add_column(ratio=1)
+     grid.add_column(justify="right", no_wrap=True)
+     grid.add_row(
+         Text(left, overflow="ellipsis", no_wrap=True),
+         Text(
+             f"top_k={args.top_k} top_p={args.top_p} temp={args.temperature}",
+             no_wrap=True,
+         ),
+     )
+     return grid
+
+
+ def render_view(walker: PromptTreeSearch) -> Group:
+     return Group(
+         _PROBABILITY_LEGEND,
+         render_branches(walker),
+         render_stats_bar(walker),
+     )
+
+
+ def run() -> None:
+     load_resp = load(args.model)
+     model = load_resp[0]
+     tokenizer = load_resp[1]
+
+     prompt = tokenizer.apply_chat_template( # type: ignore[call-arg]
+         [{"role": "user", "content": args.prompt}],
+         add_generation_prompt=True,
+     )
+
+     console = Console()
+
+     walker = PromptTreeSearch(model, tokenizer, prompt)
+     walker.begin()
+
+     try:
+         with Live(console=console, transient=False) as live:
+             interval = max(0.1, args.stats_interval)
+             next_render = time.monotonic()
+             live.update(render_view(walker))
+             while not walker.should_stop():
+                 walker.step()
+                 if args.stats_interval > 0 and time.monotonic() >= next_render:
+                     live.update(render_view(walker))
+                     next_render = time.monotonic() + interval
+             live.update(render_view(walker))
+     except KeyboardInterrupt:
+         walker.stop()
+
+
+ def parse_args(argv: list[str] | None) -> argparse.Namespace:
+     parser = argparse.ArgumentParser()
+     parser.add_argument("-p", "--prompt", default="What is 2+2?", help="Prompt to score")
+     parser.add_argument("-m", "--model", default="mlx-community/Llama-3.2-1B-Instruct-4bit")
+     parser.add_argument("-n", default=10, type=int, help="Number of answers to show")
+     parser.add_argument("--min-probability", type=float, default=0.0001)
+     parser.add_argument("--top-k", dest="top_k", default=50, type=int)
+     parser.add_argument(
+         "--top-p",
+         dest="top_p",
+         default=1.0,
+         type=float,
+         help="Nucleus sampling threshold (0 < p <= 1)",
+     )
+     parser.add_argument("--temperature", type=float, default=1.0, help="Softmax temperature (> 0)")
+     parser.add_argument(
+         "--stats-interval",
+         type=float,
+         default=0.1,
+         help="Seconds between live stats bar updates (<=0 disables)",
+     )
+
+     raw = list(sys.argv[1:] if argv is None else argv)
+     filtered = [a for a in raw if a != "--"]
+     parsed = parser.parse_args(filtered)
+
+     if parsed.temperature <= 0:
+         parser.error("--temperature must be > 0")
+     if not (0 < parsed.top_p <= 1):
+         parser.error("--top-p must be in the range (0, 1]")
+
+     return parsed
+
+
+ def main(argv: list[str] | None = None) -> None:
+     global args
+     args = parse_args(argv)
+     run()
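Because `main()` takes an explicit `argv` list (and `parse_args` drops a literal `--`), the CLI can also be driven programmatically. A minimal sketch, assuming `llmwalk` is installed; note that this loads the MLX model and renders the live table to the current terminal:

```python
from llmwalk.cli import main

# Equivalent to: uvx llmwalk -p "Write a haiku about compilers" -n 5
main(["-p", "Write a haiku about compilers", "-n", "5"])
```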
@@ -0,0 +1,26 @@
+ [build-system]
+ requires = ["hatchling>=1.22.0"]
+ build-backend = "hatchling.build"
+
+ [project]
+ name = "llmwalk"
+ version = "0.1.2"
+ description = "Explore the answer-space for any prompt and any MLX-supported model."
+ readme = "README.md"
+ requires-python = ">=3.10"
+ dependencies = [
+     "mlx-lm==0.28.4",
+     "rich==14.2.0",
+     "sortedcontainers==2.4.0",
+ ]
+
+ [project.scripts]
+ llmwalk = "llmwalk.cli:main"
+
+ [tool.hatch.build]
+ exclude = [
+     "example1.gif",
+ ]
+
+ [tool.hatch.build.targets.wheel]
+ packages = ["llmwalk"]