lalamo 0.6.3__tar.gz → 0.6.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {lalamo-0.6.3 → lalamo-0.6.5}/PKG-INFO +1 -1
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/__init__.py +1 -1
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/common.py +55 -1
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/main.py +18 -4
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/speculator/estimator.py +11 -9
- lalamo-0.6.5/lalamo/speculator/inference.py +112 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo.egg-info/PKG-INFO +1 -1
- lalamo-0.6.3/lalamo/speculator/inference.py +0 -101
- {lalamo-0.6.3 → lalamo-0.6.5}/LICENSE +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/README.md +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/commands.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/data/__init__.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/data/huggingface_message.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/data/lalamo_completions.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/data/utils.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/message_processor.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/__init__.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/common.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/decoder_configs/__init__.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/decoder_configs/common.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/decoder_configs/executorch.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/decoder_configs/huggingface/__init__.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/decoder_configs/huggingface/common.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/decoder_configs/huggingface/gemma2.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/decoder_configs/huggingface/gemma3.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/decoder_configs/huggingface/lfm2.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/decoder_configs/huggingface/llama.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/decoder_configs/huggingface/llamba.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/decoder_configs/huggingface/mistral.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/decoder_configs/huggingface/modern_bert.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/decoder_configs/huggingface/qwen2.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/decoder_configs/huggingface/qwen3.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/huggingface_generation_config.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/huggingface_tokenizer_config.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/loaders/__init__.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/loaders/common.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/loaders/executorch.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/loaders/huggingface.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/loaders/utils.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/model_specs/__init__.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/model_specs/common.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/model_specs/deepseek.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/model_specs/essential_ai.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/model_specs/gemma.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/model_specs/gpt_oss.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/model_specs/huggingface.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/model_specs/lfm2.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/model_specs/llama.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/model_specs/llamba.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/model_specs/mirai.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/model_specs/mistral.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/model_specs/pleias.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/model_specs/polaris.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/model_specs/qwen.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/model_import/model_specs/reka.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/models/__init__.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/models/classifier.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/models/common.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/models/language_model.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/modules/__init__.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/modules/activations.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/modules/classifier.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/modules/common.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/modules/decoder.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/modules/embedding.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/modules/linear.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/modules/mlp.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/modules/mlx_interop.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/modules/normalization.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/modules/rope.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/modules/token_mixers/__init__.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/modules/token_mixers/attention.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/modules/token_mixers/common.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/modules/token_mixers/mamba.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/modules/token_mixers/short_conv.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/modules/token_mixers/state/__init__.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/modules/token_mixers/state/common.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/modules/token_mixers/state/kv_cache.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/modules/token_mixers/state/mamba_state.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/modules/token_mixers/state/short_conv_state.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/modules/torch_interop.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/modules/transformer.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/modules/transformer_layer.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/modules/utils.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/quantization.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/registry_abc.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/safetensors.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/sampling.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/speculator/__init__.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/speculator/common.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/speculator/ngram.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/speculator/utils.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo/utils.py +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo.egg-info/SOURCES.txt +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo.egg-info/dependency_links.txt +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo.egg-info/entry_points.txt +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo.egg-info/requires.txt +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/lalamo.egg-info/top_level.txt +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/pyproject.toml +0 -0
- {lalamo-0.6.3 → lalamo-0.6.5}/setup.cfg +0 -0
{lalamo-0.6.3 → lalamo-0.6.5}/lalamo/common.py

```diff
@@ -1,9 +1,11 @@
+import warnings
 from collections import defaultdict
-from collections.abc import Mapping, Sequence
+from collections.abc import Callable, Iterable, Mapping, Sequence
 from typing import cast

 import jax.numpy as jnp
 from jax._src.api import ShapeDtypeStruct
+from jax.errors import JaxRuntimeError
 from jaxtyping import Array, DTypeLike

 from lalamo.utils import MapDictValues, MapSequence
@@ -11,8 +13,10 @@ from lalamo.utils import MapDictValues, MapSequence
 __all__ = [
     "DEFAULT_PRECISION",
     "ArrayLike",
+    "LalamoWarning",
     "ParameterPath",
     "ParameterTree",
+    "decrease_batchsize_on_oom",
     "dummy_array",
     "flatten_parameters",
     "require_array",
@@ -23,6 +27,10 @@ __all__ = [
 DEFAULT_PRECISION: DTypeLike = jnp.bfloat16


+class LalamoWarning(UserWarning):
+    """Custom warning class for Lalamo-specific warnings."""
+
+
 type ArrayLike = Array | ShapeDtypeStruct


@@ -121,3 +129,49 @@ class ParameterPath(str):
         if not self:
             return ParameterPath(str(other))
         return ParameterPath(self + "." + str(other))
+
+
+def decrease_batchsize_on_oom[T](
+    fn: Callable[[int], Iterable[T]],
+    starting_batch_size: int,
+) -> Iterable[T]:
+    """
+    Execute fn(batch_size) with automatic batch size reduction on OOM.
+    Only reduces batch size if OOM happened on the first batch.
+
+    Args:
+        fn: Function that takes batch_size and returns an iterable
+        starting_batch_size: Initial batch size to try
+
+    Yields:
+        Results from fn(batch_size)
+
+    Raises:
+        JaxRuntimeError: If OOM occurs after first batch completes or at batch_size=1
+    """
+    first_batch_completed = False
+    effective_batch_size = starting_batch_size
+
+    while True:
+        try:
+            for result in fn(effective_batch_size):
+                yield result
+
+                # as soon as we yielded we are not allowed to retry anymore
+                # to make sure we don't ever miss/duplicate outputs
+                first_batch_completed = True
+            break
+        except JaxRuntimeError:
+            if first_batch_completed:
+                raise
+            # because OOM's sometimes generate stuff that won't be garbage collected,
+            # we need to be very aggressive with decreasing batchsize here
+            new_bs = max(int(0.7 * effective_batch_size - 1), 1)
+            if new_bs == 1 and effective_batch_size == 1:
+                raise
+            warnings.warn(
+                f"OOM detected. Reducing batch size {effective_batch_size} -> {new_bs}.",
+                LalamoWarning,
+                stacklevel=3,
+            )
+            effective_batch_size = new_bs
```
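The new `decrease_batchsize_on_oom` helper wraps any batch-size-parameterised generator and retries with a smaller batch size if the very first batch hits a JAX OOM. A minimal usage sketch; the `batched_work` generator below is hypothetical and not part of the package:

```python
# Minimal usage sketch of decrease_batchsize_on_oom; `batched_work` is a
# made-up example generator, not part of lalamo.
from lalamo.common import decrease_batchsize_on_oom


def batched_work(batch_size: int):
    # A real caller would run a JAX computation per chunk here, which may
    # raise jax.errors.JaxRuntimeError on OOM.
    items = list(range(10))
    for start in range(0, len(items), batch_size):
        yield items[start : start + batch_size]


# Starts at batch size 8; if the first chunk OOMs, the helper retries with
# roughly 0.7x the previous size and emits a LalamoWarning.
for chunk in decrease_batchsize_on_oom(batched_work, starting_batch_size=8):
    print(chunk)
```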
{lalamo-0.6.3 → lalamo-0.6.5}/lalamo/main.py

```diff
@@ -49,7 +49,10 @@ from lalamo.message_processor import UserMessage
 from lalamo.model_import import REPO_TO_MODEL, ModelSpec
 from lalamo.model_import.common import FileSpec
 from lalamo.models import ClassifierModelConfig, LanguageModelConfig
-from lalamo.speculator.estimator import
+from lalamo.speculator.estimator import (
+    get_default_device_bytes,
+    get_usable_memory_from_bytes,
+)
 from lalamo.speculator.ngram import NGramSpeculator
 from lalamo.speculator.utils import test_speculator
@@ -384,6 +387,7 @@ class CliTraceCallbacks(TraceCallbacks):
         self.stack.close()
         console.print(f"💾 Trace saved to [cyan]{self.output_path}[/cyan]")

+
 @app.command(help="Trace a model.")
 def trace(
     model_path: Annotated[
@@ -557,14 +561,24 @@ def estimate_batchsize(
     ] = None,
 ) -> None:
     if vram_gb is not None:
-
-
+        # note that in practice GPUs use GiB in their docs, e.g. H100 actually has 85GB of memory
+        mem_bytes = vram_gb * 1000 * 1000 * 1000
+    elif (mem_bytes := get_default_device_bytes()) is None:
         err_console.print("Cannot get the default device's memory stats, use --vram-gb")
         raise Exit(1)

+    usable_mem = get_usable_memory_from_bytes(mem_bytes)
+
     callbacks_type = CliEstimateBatchsizeCallbacks

-    _estimate_batchsize(
+    _estimate_batchsize(
+        model_path,
+        usable_mem,
+        max_input_length,
+        max_output_length,
+        num_logits_per_token,
+        callbacks_type,
+    )


 @dataclass
```
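For the `--vram-gb` path, the conversion uses decimal gigabytes, and the result is then discounted by the safety factor from `get_usable_memory_from_bytes` (see the estimator.py diff below). An illustrative calculation with made-up numbers:

```python
# Illustrative arithmetic only; 85 is a hypothetical --vram-gb value.
vram_gb = 85
mem_bytes = vram_gb * 1000 * 1000 * 1000   # 85_000_000_000 bytes (decimal GB)
usable_mem = int(mem_bytes * 0.93)         # same 0.93 factor as get_usable_memory_from_bytes
print(mem_bytes, usable_mem)               # 85000000000 79050000000
```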
{lalamo-0.6.3 → lalamo-0.6.5}/lalamo/speculator/estimator.py

```diff
@@ -10,7 +10,7 @@ import jax.numpy as jnp
 from lalamo.models import LanguageModel


-def get_default_device_memory() -> int | None:
+def get_default_device_bytes() -> int | None:
     dynamic_allocate = False

     preallocate = os.getenv("XLA_PYTHON_CLIENT_PREALLOCATE", "")
@@ -26,20 +26,22 @@ def get_default_device_memory() -> int | None:
     if memory_stats is None or "bytes_limit" not in memory_stats:
         return None

-    # discount by 0.98 because if you try to allocate exactly bytes_limit
-    # jax will _try_ to allocate a bit more, and then throws an OOM
-    bytes_limit = memory_stats["bytes_limit"] * 0.98
-
     mem_fraction_raw = os.getenv("XLA_PYTHON_CLIENT_MEM_FRACTION", "")
     try:
         mem_fraction = float(mem_fraction_raw)
     except ValueError:
         mem_fraction = 0.75  # jax default https://docs.jax.dev/en/latest/gpu_memory_allocation.html

-    #
-    #
-
-
+    # 500mb is seemingly the usually observed overhead; this tries to match the actual capacity of the gpu
+    # so it should correspond to something you'd see in nvidia-smi
+    memory_limit = memory_stats["bytes_limit"] / min(mem_fraction, 1.0) + (500 * 1000 * 1000)
+
+    return get_usable_memory_from_bytes(memory_limit)
+
+
+def get_usable_memory_from_bytes(limit_bytes: int) -> int:
+    # JAX allocates a bit more than it needs, so we discount it by some safety factor
+    return int(limit_bytes * 0.93)


 def estimate_memory_from_batchsize(
```
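The rewritten estimator now back-calculates the card's physical capacity from XLA's `bytes_limit` (which is already scaled by `XLA_PYTHON_CLIENT_MEM_FRACTION`) instead of discounting `bytes_limit` directly. A rough worked example with hypothetical numbers:

```python
# Hypothetical numbers to illustrate the new formula; not measured values.
bytes_limit = 63_750_000_000  # what XLA might report under the default 0.75 mem fraction
mem_fraction = 0.75

# Undo the fraction and add the ~500 MB runtime overhead -> roughly the nvidia-smi capacity.
memory_limit = bytes_limit / min(mem_fraction, 1.0) + (500 * 1000 * 1000)  # 85_500_000_000

usable = int(memory_limit * 0.93)  # get_usable_memory_from_bytes -> 79_515_000_000
print(memory_limit, usable)
```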
lalamo-0.6.5/lalamo/speculator/inference.py

```diff
@@ -0,0 +1,112 @@
+import functools
+from collections.abc import Callable, Iterable
+from itertools import batched, chain, islice
+from typing import NamedTuple
+
+import jax
+import jax.numpy as jnp
+from jax._src.stages import Compiled
+
+from lalamo.common import decrease_batchsize_on_oom
+from lalamo.data.lalamo_completions import LalamoCompletion
+from lalamo.data.utils import get_prefixes_ending_in_user_message
+from lalamo.message_processor import Message
+from lalamo.models import LanguageModel
+
+
+class CollectTracesEvent(NamedTuple):
+    sequences_processed: int
+    tokens_generated: int
+
+
+def inference_collect_traces(
+    model: LanguageModel,
+    conversations: Iterable[Iterable[Message]],
+    num_top_logits_to_collect: int = 8,
+    batch_size: int = 1,
+    max_input_length: int = 1024,
+    max_output_length: int = 1024,
+    tokens_to_generate: int | None = None,
+    progress_callback: Callable[[CollectTracesEvent], None] | None = None,
+) -> Iterable[LalamoCompletion]:
+    def make_generate_tokens_compiled(batch_size: int) -> Compiled:
+        return (
+            jax.jit(
+                functools.partial(
+                    LanguageModel.generate_tokens,
+                    max_output_length=max_output_length,
+                    num_top_logits_to_return=num_top_logits_to_collect,
+                ),
+            )
+            .lower(
+                model,
+                prompt_token_ids=jax.ShapeDtypeStruct((batch_size, max_input_length), jnp.int32),
+                prompt_lengths_without_padding=jax.ShapeDtypeStruct((batch_size,), jnp.int32),
+            )
+            # the autotune levels are (according to https://guides.lw1.at/all-xla-options/#--xla_gpu_autotune_level)
+            # 0 - no autotune, gpu shouldn't be touched
+            # 1 - basic level, gpu should be touched veeery little
+            # 2,3 - gpu touched more and more
+            # 4 (default) - gpu might allocate more memory than the run would require!
+            .compile(compiler_options={"xla_gpu_autotune_level": "0"})
+        )
+
+    prefixes = chain.from_iterable(map(get_prefixes_ending_in_user_message, conversations))
+
+    tokenized_prefixes = map(model.message_processor.tokenize_request, prefixes)
+    filtered_prefixes = filter(lambda conv: len(conv) <= max_input_length, tokenized_prefixes)
+
+    test_batch = list(islice(filtered_prefixes, batch_size))
+
+    def collect_traces_body(batch_size: int) -> Iterable[LalamoCompletion]:
+        tokens_generated, sequences_processed = 0, 0
+        generate_tokens_compiled = make_generate_tokens_compiled(batch_size)
+        for real_batch in batched(chain(test_batch, filtered_prefixes), n=batch_size):
+            batch_padding = batch_size - len(real_batch)
+            batch = (*real_batch, *(([0],) * batch_padding))
+
+            length_without_padding = jnp.array(list(map(len, batch)))
+
+            padded = jnp.array(
+                [
+                    jnp.pad(jnp.array(tokens), (0, max_input_length - len(tokens)), constant_values=0)
+                    for tokens in batch
+                ],
+            )
+
+            generated = generate_tokens_compiled(
+                model,
+                prompt_token_ids=padded,
+                prompt_lengths_without_padding=length_without_padding,
+            )
+
+            assert generated.top_k_token_ids is not None and generated.top_k_token_logits is not None
+
+            for conv_idx in range(len(real_batch)):
+                token_ids = generated.token_ids[conv_idx].tolist()
+                seqlen = next((i + 1 for i, t in enumerate(token_ids) if t in model.stop_token_ids), len(token_ids))
+                if tokens_to_generate is not None:
+                    seqlen = min(seqlen, tokens_to_generate - tokens_generated)
+                tokens_generated += seqlen
+                sequences_processed += 1
+
+                token_ids = token_ids[:seqlen]
+                token_logits_ids = generated.top_k_token_ids[conv_idx, : len(token_ids)].tolist()
+                token_logits_values = generated.top_k_token_logits[conv_idx, : len(token_ids)].tolist()
+                token_logits = [
+                    dict(zip(keys, values, strict=True))
+                    for keys, values in zip(token_logits_ids, token_logits_values, strict=True)
+                ]
+
+                yield LalamoCompletion(batch[conv_idx], token_ids, token_logits)
+
+                if tokens_to_generate is not None and tokens_generated >= tokens_to_generate:
+                    break
+
+            if progress_callback is not None:
+                progress_callback(CollectTracesEvent(sequences_processed, tokens_generated))
+
+            if tokens_to_generate is not None and tokens_generated >= tokens_to_generate:
+                break
+
+    yield from decrease_batchsize_on_oom(collect_traces_body, batch_size)
```
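The rewritten `inference_collect_traces` buffers the first `batch_size` prompts with `islice` so they can be replayed after a retry, rebuilds the compiled generation function per batch size, and wraps the whole loop in `decrease_batchsize_on_oom`. A hypothetical driver sketch; model loading and the conversation source are outside this diff and are taken as parameters:

```python
# Hypothetical driver for the rewritten trace collection; only the names shown
# in the diff above are assumed to exist, everything else is illustrative.
from collections.abc import Iterable

from lalamo.message_processor import Message
from lalamo.models import LanguageModel
from lalamo.speculator.inference import CollectTracesEvent, inference_collect_traces


def collect_traces_with_progress(
    model: LanguageModel,
    conversations: Iterable[Iterable[Message]],
) -> list:
    def on_progress(event: CollectTracesEvent) -> None:
        print(f"{event.sequences_processed} sequences, {event.tokens_generated} tokens")

    # batch_size is only a starting point: an OOM on the first batch makes the
    # decrease_batchsize_on_oom wrapper retry with a smaller size.
    return list(
        inference_collect_traces(
            model,
            conversations,
            batch_size=32,
            max_input_length=1024,
            max_output_length=1024,
            progress_callback=on_progress,
        )
    )
```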
lalamo-0.6.3/lalamo/speculator/inference.py

```diff
@@ -1,101 +0,0 @@
-import functools
-from collections.abc import Callable, Iterable
-from itertools import batched, chain
-from typing import NamedTuple
-
-import jax
-import jax.numpy as jnp
-
-from lalamo.data.lalamo_completions import LalamoCompletion
-from lalamo.data.utils import get_prefixes_ending_in_user_message
-from lalamo.message_processor import Message
-from lalamo.models import LanguageModel
-
-
-class CollectTracesEvent(NamedTuple):
-    sequences_processed: int
-    tokens_generated: int
-
-
-def inference_collect_traces(
-    model: LanguageModel,
-    conversations: Iterable[Iterable[Message]],
-    num_top_logits_to_collect: int = 8,
-    batch_size: int = 1,
-    max_input_length: int = 1024,
-    max_output_length: int = 1024,
-    tokens_to_generate: int | None = None,
-    progress_callback: Callable[[CollectTracesEvent], None] | None = None,
-) -> Iterable[LalamoCompletion]:
-    generate_tokens_compiled = (
-        jax.jit(
-            functools.partial(
-                LanguageModel.generate_tokens,
-                max_output_length=max_output_length,
-                num_top_logits_to_return=num_top_logits_to_collect,
-            ),
-        )
-        .lower(
-            model,
-            prompt_token_ids=jax.ShapeDtypeStruct((batch_size, max_input_length), jnp.int32),
-            prompt_lengths_without_padding=jax.ShapeDtypeStruct((batch_size,), jnp.int32),
-        )
-        # the autotune levels are (according to https://guides.lw1.at/all-xla-options/#--xla_gpu_autotune_level)
-        # 0 - no autotune, gpu shouldn't be touched
-        # 1 - basic level, gpu should be touched veeery little
-        # 2,3 - gpu touched more and more
-        # 4 (default) - gpu might allocate more memory than the run would require!
-        .compile(compiler_options={"xla_gpu_autotune_level": "2"})
-    )
-
-    prefixes = chain.from_iterable(map(get_prefixes_ending_in_user_message, conversations))
-
-    tokenized_prefixes = map(model.message_processor.tokenize_request, prefixes)
-    filtered_prefixes = filter(lambda conv: len(conv) <= max_input_length, tokenized_prefixes)
-
-    tokens_generated, sequences_processed = 0, 0
-
-    for real_batch in batched(filtered_prefixes, n=batch_size):
-        batch_padding = batch_size - len(real_batch)
-        batch = (*real_batch, *(([0],) * batch_padding))
-
-        length_without_padding = jnp.array(list(map(len, batch)))
-
-        padded = jnp.array(
-            [jnp.pad(jnp.array(tokens), (0, max_input_length - len(tokens)), constant_values=0) for tokens in batch],
-        )
-
-        generated = generate_tokens_compiled(
-            model,
-            prompt_token_ids=padded,
-            prompt_lengths_without_padding=length_without_padding,
-        )
-
-        assert generated.top_k_token_ids is not None and generated.top_k_token_logits is not None
-
-        for conv_idx in range(len(real_batch)):
-            token_ids = generated.token_ids[conv_idx].tolist()
-            seqlen = next((i + 1 for i, t in enumerate(token_ids) if t in model.stop_token_ids), len(token_ids))
-            if tokens_to_generate is not None:
-                seqlen = min(seqlen, tokens_to_generate - tokens_generated)
-            tokens_generated += seqlen
-            sequences_processed += 1
-
-            token_ids = token_ids[:seqlen]
-            token_logits_ids = generated.top_k_token_ids[conv_idx, : len(token_ids)].tolist()
-            token_logits_values = generated.top_k_token_logits[conv_idx, : len(token_ids)].tolist()
-            token_logits = [
-                dict(zip(keys, values, strict=True))
-                for keys, values in zip(token_logits_ids, token_logits_values, strict=True)
-            ]
-
-            yield LalamoCompletion(batch[conv_idx], token_ids, token_logits)
-
-            if tokens_to_generate is not None and tokens_generated >= tokens_to_generate:
-                break
-
-        if progress_callback is not None:
-            progress_callback(CollectTracesEvent(sequences_processed, tokens_generated))
-
-        if tokens_to_generate is not None and tokens_generated >= tokens_to_generate:
-            break
```