llmwalk-0.1.2.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llmwalk-0.1.2/.gitignore +176 -0
- llmwalk-0.1.2/AGENTS.md +12 -0
- llmwalk-0.1.2/PKG-INFO +42 -0
- llmwalk-0.1.2/README.md +32 -0
- llmwalk-0.1.2/llmwalk/__init__.py +3 -0
- llmwalk-0.1.2/llmwalk/__main__.py +6 -0
- llmwalk-0.1.2/llmwalk/cli.py +412 -0
- llmwalk-0.1.2/pyproject.toml +26 -0
llmwalk-0.1.2/.gitignore
ADDED
@@ -0,0 +1,176 @@
+# Created by https://www.toptal.com/developers/gitignore/api/python
+# Edit at https://www.toptal.com/developers/gitignore?templates=python
+
+### Python ###
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+### Python Patch ###
+# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
+poetry.toml
+
+# ruff
+.ruff_cache/
+
+# LSP config files
+pyrightconfig.json
+
+# End of https://www.toptal.com/developers/gitignore/api/python
llmwalk-0.1.2/AGENTS.md
ADDED
@@ -0,0 +1,12 @@
+## Project notes for Codex/agents
+
+This repo is a single-file `uv` inline script.
+
+- Entry point: `llmtree.py`
+- Dependencies are declared in the `# /// script` block at the top of `llmtree.py`
+- Prefer running via `uv` so dependencies resolve correctly:
+  - `uv run llmtree.py -- --help`
+  - `uv run llmtree.py -- -p "Your prompt here"`
+
+Avoid using `pip install` directly for this repo unless you have a specific reason; the intended workflow is `uv run` with the inline dependency block.
+
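
The `# /// script` block that AGENTS.md refers to is a PEP 723 inline-metadata header that `uv run` reads to resolve dependencies. The header of `llmtree.py` itself is not part of this diff, so the following is only an illustrative sketch, using the dependency pins declared in PKG-INFO and pyproject.toml below as assumptions:

```python
# Hypothetical top of an inline uv script (PEP 723); the real llmtree.py header
# is not shown in this diff, so the pinned versions here are assumptions.
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "mlx-lm==0.28.4",
#     "rich==14.2.0",
#     "sortedcontainers==2.4.0",
# ]
# ///
```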
llmwalk-0.1.2/PKG-INFO
ADDED
@@ -0,0 +1,42 @@
+Metadata-Version: 2.4
+Name: llmwalk
+Version: 0.1.2
+Summary: Explore the answer-space for any prompt and any MLX-supported model.
+Requires-Python: >=3.10
+Requires-Dist: mlx-lm==0.28.4
+Requires-Dist: rich==14.2.0
+Requires-Dist: sortedcontainers==2.4.0
+Description-Content-Type: text/markdown
+
+# llmwalk
+
+Explore the answer-space for any prompt and any MLX-supported model. See
+<https://huggingface.co/mlx-community/models> for supported models.
+
+
+
+Instead of sampling from the possible tokens each step, llmwalk branches out
+and completes all of the branches the sampler would consider based on
+`--top-k`, `--top-p` and `--temperature`, ranking the results by probability
+as it goes.
+
+The tree is walked prioritising the most likely branches, until it finds `-n`
+branches and then it stops. It doesn't enumerate all possibilities, just enough
+to know for sure it has found the `-n` most likely branches.
+
+## Usage
+
+- `uvx llmwalk -p "In what year was Barack Obama born?"`
+- `uvx llmwalk -p "Write a haiku about compilers" -n 5`
+- `uvx llmwalk -p "Give me one word: " --top-k 200 --temperature 0.7`
+
+## Options
+
+- `-p, --prompt TEXT`: Prompt to score (wrapped with the model’s chat template).
+- `-m, --model MODEL`: MLX-LM model identifier or path (default: `mlx-community/Llama-3.2-1B-Instruct-4bit`), supported models can be found at <https://huggingface.co/mlx-community/models>
+- `-n N`: Number of answers to show. The search stops once it has `N` finished answers and no unfinished branch can beat the worst of those `N`.
+- `--min-probability FLOAT`: Any branch whose cumulative probability falls below this is marked finished (`low_probability`) and not expanded further.
+- `--top-k INT`: At each step, expand at most `k` next tokens (highest probability).
+- `--top-p FLOAT`: Nucleus cutoff applied *within the top-k tokens* at each step (keep adding tokens until cumulative probability ≥ `p`).
+- `--temperature FLOAT`: Softmax temperature applied when computing per-step probabilities (`1.0` is the model distribution; must be `> 0`).
+- `--stats-interval SECONDS`: How often to refresh the live view (`<= 0` disables periodic refresh; still renders at start/end).
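
The `--temperature` option described in the PKG-INFO options list rescales each per-step distribution before the top-k/top-p cutoffs are applied. A minimal, self-contained illustration of that rescaling follows; the probabilities are toy numbers, and the package itself performs the equivalent computation on MLX log-probabilities (see `llmwalk/cli.py` below):

```python
import math

def rescale(probs: list[float], temperature: float = 1.0) -> list[float]:
    """Apply a softmax temperature to an already-normalised distribution.

    temperature=1.0 leaves the model distribution unchanged; values below 1.0
    sharpen it, values above 1.0 flatten it. The inputs are toy numbers used
    purely for illustration.
    """
    scaled = [math.log(p) / temperature for p in probs]
    z = sum(math.exp(s) for s in scaled)
    return [math.exp(s) / z for s in scaled]

print(rescale([0.7, 0.2, 0.1], temperature=1.0))  # unchanged
print(rescale([0.7, 0.2, 0.1], temperature=0.5))  # sharper: ~[0.91, 0.07, 0.02]
print(rescale([0.7, 0.2, 0.1], temperature=2.0))  # flatter: ~[0.52, 0.28, 0.20]
```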
llmwalk-0.1.2/README.md
ADDED
@@ -0,0 +1,32 @@
+# llmwalk
+
+Explore the answer-space for any prompt and any MLX-supported model. See
+<https://huggingface.co/mlx-community/models> for supported models.
+
+
+
+Instead of sampling from the possible tokens each step, llmwalk branches out
+and completes all of the branches the sampler would consider based on
+`--top-k`, `--top-p` and `--temperature`, ranking the results by probability
+as it goes.
+
+The tree is walked prioritising the most likely branches, until it finds `-n`
+branches and then it stops. It doesn't enumerate all possibilities, just enough
+to know for sure it has found the `-n` most likely branches.
+
+## Usage
+
+- `uvx llmwalk -p "In what year was Barack Obama born?"`
+- `uvx llmwalk -p "Write a haiku about compilers" -n 5`
+- `uvx llmwalk -p "Give me one word: " --top-k 200 --temperature 0.7`
+
+## Options
+
+- `-p, --prompt TEXT`: Prompt to score (wrapped with the model’s chat template).
+- `-m, --model MODEL`: MLX-LM model identifier or path (default: `mlx-community/Llama-3.2-1B-Instruct-4bit`), supported models can be found at <https://huggingface.co/mlx-community/models>
+- `-n N`: Number of answers to show. The search stops once it has `N` finished answers and no unfinished branch can beat the worst of those `N`.
+- `--min-probability FLOAT`: Any branch whose cumulative probability falls below this is marked finished (`low_probability`) and not expanded further.
+- `--top-k INT`: At each step, expand at most `k` next tokens (highest probability).
+- `--top-p FLOAT`: Nucleus cutoff applied *within the top-k tokens* at each step (keep adding tokens until cumulative probability ≥ `p`).
+- `--temperature FLOAT`: Softmax temperature applied when computing per-step probabilities (`1.0` is the model distribution; must be `> 0`).
+- `--stats-interval SECONDS`: How often to refresh the live view (`<= 0` disables periodic refresh; still renders at start/end).
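
The walk the README describes is a best-first search: unfinished branches live in a max-heap keyed by cumulative probability, finished branches are kept sorted, and the search stops once the best unfinished branch can no longer beat the `n`-th best finished answer. A minimal sketch of that stopping rule over a hand-made toy distribution (the numbers are invented for illustration; the real tool gets them from an MLX model, as in `llmwalk/cli.py` below):

```python
import heapq

# Toy "model": each prefix maps to a few next tokens with made-up probabilities.
TOY = {
    (): [("4", 0.7), ("5", 0.2), ("four", 0.1)],
    ("4",): [("<eos>", 0.9), (".", 0.1)],
    ("5",): [("<eos>", 1.0)],
    ("four",): [("<eos>", 1.0)],
    ("4", "."): [("<eos>", 1.0)],
}

def walk(n: int = 2):
    frontier = [(-1.0, ())]  # max-heap via negated cumulative probability
    finished = []            # (probability, prefix) pairs, kept sorted descending
    while frontier:
        best_unfinished = -frontier[0][0]
        if len(finished) >= n and best_unfinished < finished[n - 1][0]:
            break  # no unfinished branch can beat the n-th best finished answer
        neg_p, prefix = heapq.heappop(frontier)
        for token, p in TOY.get(prefix, []):
            q = -neg_p * p
            if token == "<eos>":
                finished.append((q, prefix))
                finished.sort(reverse=True)
            else:
                heapq.heappush(frontier, (-q, prefix + (token,)))
    return finished[:n]

print(walk())  # roughly [(0.63, ('4',)), (0.2, ('5',))]
```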
llmwalk-0.1.2/llmwalk/cli.py
ADDED
@@ -0,0 +1,412 @@
+from __future__ import annotations
+
+import argparse
+import heapq
+import sys
+import time
+from dataclasses import dataclass
+from datetime import datetime
+from functools import lru_cache
+
+import mlx.core as mx
+from mlx.nn import Module
+from mlx_lm import load
+from mlx_lm.models.cache import KVCache
+from mlx_lm.tokenizer_utils import TokenizerWrapper
+from rich.console import Console, Group
+from rich.live import Live
+from rich.style import Style
+from rich.table import Table
+from rich.text import Text
+from sortedcontainers import SortedList
+
+args: argparse.Namespace
+
+_BAND_COLORS = [
+    "#7f7f7f",  # 0-10%: grey
+    "#ff3b30",  # 10-20%: red
+    "#ff6a00",  # 20-30%: orange
+    "#ff8c00",  # 30-40%: dark orange
+    "#ffb000",  # 40-50%: amber
+    "#ffd000",  # 50-60%: yellow
+    "#d7e500",  # 60-70%: yellow-green
+    "#a8e600",  # 70-80%: greenish
+    "#4cd964",  # 80-90%: green
+    "#00c853",  # 90-100%: bright green
+]
+_BAND_STYLES = [Style(color=c) for c in _BAND_COLORS]
+
+
+@dataclass
+class OutputToken:
+    token: int
+    prob: float
+
+
+@dataclass(eq=False)
+class Branch:
+    parent: Branch | None
+    token: OutputToken | None
+    probability: float = 1.0
+    finish_reason: str | None = None
+    cache: list[KVCache] | None = None
+
+    def answer_tokens(self) -> list[OutputToken]:
+        toks: list[OutputToken] = []
+        cur: Branch | None = self
+        while cur is not None and cur.token is not None:
+            toks.append(cur.token)
+            cur = cur.parent
+        toks.reverse()
+        return toks
+
+
+def _clone_kv_cache(c: KVCache) -> KVCache:
+    cloned = KVCache()
+    cloned.offset = c.offset
+    cloned.keys = mx.array(c.keys) if c.keys is not None else None
+    cloned.values = mx.array(c.values) if c.values is not None else None
+    return cloned
+
+
+def _clone_prompt_cache(cache: list[KVCache]) -> list[KVCache]:
+    return [_clone_kv_cache(c) for c in cache]
+
+
+def _top_tokens_from_logprobs(logprobs: mx.array) -> list[OutputToken]:
+    vocab = int(logprobs.shape[0])
+    k = min(args.top_k, vocab)
+    part = mx.argpartition(logprobs, vocab - k)
+    top_idx = part[vocab - k :]
+    top_lp = mx.take(logprobs, top_idx)
+    order = mx.argsort(top_lp)[::-1]
+    sorted_indices = mx.take(top_idx, order)
+
+    if args.temperature == 1.0:
+        probs = mx.exp(mx.take(logprobs, sorted_indices))
+    else:
+        lse = mx.logsumexp(logprobs / args.temperature, axis=-1)
+        probs = mx.exp(mx.take(logprobs, sorted_indices) / args.temperature - lse)
+
+    mx.eval(sorted_indices, probs)
+    token_ids: list[int] = sorted_indices.astype(mx.int64).tolist()
+    token_probs: list[float] = mx.reshape(probs, (-1,)).tolist()
+
+    output_tokens: list[OutputToken] = []
+    cum_prob = 0.0
+    for token_id, prob in zip(token_ids, token_probs):  # type: ignore[call-arg]
+        if output_tokens and cum_prob >= args.top_p:
+            break
+        output_tokens.append(OutputToken(token=token_id, prob=float(prob)))
+        cum_prob += float(prob)
+
+    return output_tokens
+
+
+class PromptTreeSearch:
+    model: Module
+    tokenizer: TokenizerWrapper
+    prompt: list[int]
+    _frontier: list[tuple[float, int, Branch]]
+    _finished_eos: SortedList[Branch]
+    _heap_counter: int = 0
+    _stopped: bool = False
+
+    tokens: int = 0
+    pruned: int = 0
+
+    _low_watermark: float | None = None
+    _start: datetime | None = None
+    _end: datetime | None = None
+
+    def __init__(self, model: Module, tokenizer: TokenizerWrapper, prompt: list[int]) -> None:
+        self.model = model
+        self.tokenizer = tokenizer
+        self.prompt = prompt
+        self._frontier = []
+        self._finished_eos = SortedList(key=lambda b: -b.probability)
+
+        root = Branch(parent=None, token=None)
+        self.branches = SortedList(key=lambda b: -b.probability)
+        self.branches.add(root)
+        self._push_frontier(root)
+
+    @lru_cache(maxsize=65536)
+    def decode_token(self, token_id: int) -> str:
+        return self.tokenizer.decode([token_id], skip_special_tokens=True)  # type: ignore[call-arg]
+
+    def _run_model(self, cache: list[KVCache], input_ids: list[int]) -> mx.array:
+        inputs = mx.array([input_ids], mx.int32)
+        logits = self.model(inputs, cache=cache)[:, -1, :]
+        logits = logits.astype(mx.float32)
+        logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
+        return mx.reshape(logprobs, (-1,))
+
+    @property
+    def active(self) -> int:
+        return len(self._frontier)
+
+    def top_branches(self, n: int) -> list[Branch]:
+        return list(self.branches[:n])
+
+    def _push_frontier(self, branch: Branch) -> None:
+        self._heap_counter += 1
+        heapq.heappush(self._frontier, (-branch.probability, self._heap_counter, branch))
+
+    def _update_low_watermark(self) -> None:
+        if len(self._finished_eos) < args.n:
+            self._low_watermark = None
+            return
+        self._low_watermark = self._finished_eos[args.n - 1].probability
+
+    def stop(self) -> None:
+        self._stopped = True
+
+    def should_stop(self) -> bool:
+        if self._stopped:
+            return True
+        if not self._frontier:
+            return True
+        if self._low_watermark is None:
+            return False
+        best_prob = -self._frontier[0][0]
+        return best_prob < self._low_watermark
+
+    def begin(self) -> None:
+        self._start = datetime.now()
+
+    def step(self) -> None:
+        if self.should_stop():
+            return
+
+        _, _, branch = heapq.heappop(self._frontier)
+
+        if self._low_watermark is not None and branch.probability < self._low_watermark:
+            self.pruned += 1
+            branch.finish_reason = "pruned"
+            branch.cache = None
+            return
+
+        if branch.token is None:  # root branch
+            cache_after = self.model.make_cache() if hasattr(self.model, "make_cache") else []  # type: ignore[assignment]
+            logprobs = self._run_model(cache_after, self.prompt)
+        else:
+            if branch.cache is None:
+                raise RuntimeError("Branch cache missing while expanding leaf")
+            cache_after = _clone_prompt_cache(branch.cache)
+            logprobs = self._run_model(cache_after, [branch.token.token])
+
+        self.branches.remove(branch)
+
+        new_branches: list[Branch] = []
+        frontier_add: list[Branch] = []
+        eos_add: list[Branch] = []
+        for tok in _top_tokens_from_logprobs(logprobs):
+            new_prob = branch.probability * tok.prob
+
+            if new_prob < args.min_probability:
+                self.pruned += 1
+                new_branch = Branch(
+                    parent=branch,
+                    token=tok,
+                    probability=new_prob,
+                    finish_reason="low_probability",
+                )
+                new_branches.append(new_branch)
+                continue
+
+            if tok.token in self.tokenizer.eos_token_ids:
+                new_branch = Branch(
+                    parent=branch,
+                    token=tok,
+                    probability=new_prob,
+                    finish_reason="eos_token",
+                )
+                new_branches.append(new_branch)
+                eos_add.append(new_branch)
+                continue
+
+            new_branch = Branch(
+                parent=branch,
+                token=tok,
+                probability=new_prob,
+                cache=cache_after,
+            )
+            new_branches.append(new_branch)
+            frontier_add.append(new_branch)
+
+        for b in new_branches:
+            self.branches.add(b)
+        for b in frontier_add:
+            self._push_frontier(b)
+        for b in eos_add:
+            self._finished_eos.add(b)
+        if eos_add:
+            self._update_low_watermark()
+        self.tokens += 1
+
+        branch.cache = None
+
+
+def style_for_token_probability(prob: float) -> Style:
+    if prob != prob:  # NaN
+        prob = 0.0
+    elif prob < 0.0:
+        prob = 0.0
+    elif prob > 1.0:
+        prob = 1.0
+
+    band = min(int(prob * 10), 9)  # 0..9
+    return _BAND_STYLES[band]
+
+
+def render_probability_legend() -> Text:
+    legend = Text("Legend: ", style="bold", no_wrap=True, overflow="ellipsis")
+    for i in range(9, -1, -1):
+        style = style_for_token_probability((i + 0.5) / 10)
+        if i == 9:
+            label = "90%+"
+        elif i == 0:
+            label = "0–10%"
+        else:
+            label = f"{i * 10}%+"
+
+        legend.append("■", style=style)
+        legend.append(f" {label}")
+        if i != 0:
+            legend.append(" ")
+
+    return legend
+
+
+_PROBABILITY_LEGEND = render_probability_legend()
+
+
+def render_branches(walker: PromptTreeSearch) -> Table:
+    table = Table(expand=True)
+    table.add_column("Fin", justify="center", no_wrap=True, width=3)
+    table.add_column("Prob.", justify="right", no_wrap=True, width=8)
+    table.add_column("Answer", ratio=1)
+
+    branches = walker.top_branches(args.n)
+    for i in range(args.n):
+        if i >= len(branches):
+            table.add_row("", "", "")
+            continue
+
+        branch = branches[i]
+        answer_text = Text()
+        for tok in branch.answer_tokens():
+            piece = walker.decode_token(tok.token)
+            if not piece:
+                continue
+            answer_text.append(piece, style=style_for_token_probability(tok.prob))
+        probability_text = f"{branch.probability * 100:6.2f}%"
+        status: Text
+        if branch.finish_reason == "eos_token":
+            status = Text("✓", style="green")
+        elif branch.finish_reason == "low_probability":
+            status = Text("✓", style="yellow")
+        elif branch.finish_reason == "pruned":
+            status = Text("✓", style="dim")
+        else:
+            status = Text("")
+
+        table.add_row(status, probability_text, answer_text)
+
+    return table
+
+
+def render_stats_bar(walker: PromptTreeSearch) -> Table:
+    elapsed = (datetime.now() - walker._start).total_seconds() if walker._start else 0.0
+    tps = walker.tokens / elapsed if elapsed > 0 else 0.0
+    left = f"active {walker.active} pruned {walker.pruned} tps {tps:0.1f}"
+    grid = Table.grid(expand=True)
+    grid.add_column(ratio=1)
+    grid.add_column(justify="right", no_wrap=True)
+    grid.add_row(
+        Text(left, overflow="ellipsis", no_wrap=True),
+        Text(
+            f"top_k={args.top_k} top_p={args.top_p} temp={args.temperature}",
+            no_wrap=True,
+        ),
+    )
+    return grid
+
+
+def render_view(walker: PromptTreeSearch) -> Group:
+    return Group(
+        _PROBABILITY_LEGEND,
+        render_branches(walker),
+        render_stats_bar(walker),
+    )
+
+
+def run() -> None:
+    load_resp = load(args.model)
+    model = load_resp[0]
+    tokenizer = load_resp[1]
+
+    prompt = tokenizer.apply_chat_template(  # type: ignore[call-arg]
+        [{"role": "user", "content": args.prompt}],
+        add_generation_prompt=True,
+    )
+
+    console = Console()
+
+    walker = PromptTreeSearch(model, tokenizer, prompt)
+    walker.begin()
+
+    try:
+        with Live(console=console, transient=False) as live:
+            interval = max(0.1, args.stats_interval)
+            next_render = time.monotonic()
+            live.update(render_view(walker))
+            while not walker.should_stop():
+                walker.step()
+                if args.stats_interval > 0 and time.monotonic() >= next_render:
+                    live.update(render_view(walker))
+                    next_render = time.monotonic() + interval
+            live.update(render_view(walker))
+    except KeyboardInterrupt:
+        walker.stop()
+
+
+def parse_args(argv: list[str] | None) -> argparse.Namespace:
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-p", "--prompt", default="What is 2+2?", help="Prompt to score")
+    parser.add_argument("-m", "--model", default="mlx-community/Llama-3.2-1B-Instruct-4bit")
+    parser.add_argument("-n", default=10, type=int, help="Number of answers to show")
+    parser.add_argument("--min-probability", type=float, default=0.0001)
+    parser.add_argument("--top-k", dest="top_k", default=50, type=int)
+    parser.add_argument(
+        "--top-p",
+        dest="top_p",
+        default=1.0,
+        type=float,
+        help="Nucleus sampling threshold (0 < p <= 1)",
+    )
+    parser.add_argument("--temperature", type=float, default=1.0, help="Softmax temperature (> 0)")
+    parser.add_argument(
+        "--stats-interval",
+        type=float,
+        default=0.1,
+        help="Seconds between live stats bar updates (<=0 disables)",
+    )
+
+    raw = list(sys.argv[1:] if argv is None else argv)
+    filtered = [a for a in raw if a != "--"]
+    parsed = parser.parse_args(filtered)
+
+    if parsed.temperature <= 0:
+        parser.error("--temperature must be > 0")
+    if not (0 < parsed.top_p <= 1):
+        parser.error("--top-p must be in the range (0, 1]")
+
+    return parsed
+
+
+def main(argv: list[str] | None = None) -> None:
+    global args
+    args = parse_args(argv)
+    run()
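
Because `main()` above accepts an optional argument list, the CLI can also be driven programmatically rather than through the `llmwalk` console script; a small usage sketch (the prompt string is just an example):

```python
from llmwalk.cli import main

# Equivalent to: uvx llmwalk -p "Write a haiku about compilers" -n 5
main(["-p", "Write a haiku about compilers", "-n", "5"])
```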
llmwalk-0.1.2/pyproject.toml
ADDED
@@ -0,0 +1,26 @@
+[build-system]
+requires = ["hatchling>=1.22.0"]
+build-backend = "hatchling.build"
+
+[project]
+name = "llmwalk"
+version = "0.1.2"
+description = "Explore the answer-space for any prompt and any MLX-supported model."
+readme = "README.md"
+requires-python = ">=3.10"
+dependencies = [
+    "mlx-lm==0.28.4",
+    "rich==14.2.0",
+    "sortedcontainers==2.4.0",
+]
+
+[project.scripts]
+llmwalk = "llmwalk.cli:main"
+
+[tool.hatch.build]
+exclude = [
+    "example1.gif",
+]
+
+[tool.hatch.build.targets.wheel]
+packages = ["llmwalk"]