lalamo 0.6.4__tar.gz → 0.6.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. {lalamo-0.6.4 → lalamo-0.6.5}/PKG-INFO +1 -1
  2. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/__init__.py +1 -1
  3. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/common.py +55 -1
  4. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/speculator/estimator.py +1 -1
  5. lalamo-0.6.5/lalamo/speculator/inference.py +112 -0
  6. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo.egg-info/PKG-INFO +1 -1
  7. lalamo-0.6.4/lalamo/speculator/inference.py +0 -101
  8. {lalamo-0.6.4 → lalamo-0.6.5}/LICENSE +0 -0
  9. {lalamo-0.6.4 → lalamo-0.6.5}/README.md +0 -0
  10. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/commands.py +0 -0
  11. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/data/__init__.py +0 -0
  12. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/data/huggingface_message.py +0 -0
  13. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/data/lalamo_completions.py +0 -0
  14. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/data/utils.py +0 -0
  15. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/main.py +0 -0
  16. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/message_processor.py +0 -0
  17. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/__init__.py +0 -0
  18. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/common.py +0 -0
  19. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/decoder_configs/__init__.py +0 -0
  20. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/decoder_configs/common.py +0 -0
  21. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/decoder_configs/executorch.py +0 -0
  22. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/decoder_configs/huggingface/__init__.py +0 -0
  23. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/decoder_configs/huggingface/common.py +0 -0
  24. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/decoder_configs/huggingface/gemma2.py +0 -0
  25. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/decoder_configs/huggingface/gemma3.py +0 -0
  26. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +0 -0
  27. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/decoder_configs/huggingface/lfm2.py +0 -0
  28. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/decoder_configs/huggingface/llama.py +0 -0
  29. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/decoder_configs/huggingface/llamba.py +0 -0
  30. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/decoder_configs/huggingface/mistral.py +0 -0
  31. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/decoder_configs/huggingface/modern_bert.py +0 -0
  32. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/decoder_configs/huggingface/qwen2.py +0 -0
  33. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/decoder_configs/huggingface/qwen3.py +0 -0
  34. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/huggingface_generation_config.py +0 -0
  35. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/huggingface_tokenizer_config.py +0 -0
  36. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/loaders/__init__.py +0 -0
  37. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/loaders/common.py +0 -0
  38. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/loaders/executorch.py +0 -0
  39. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/loaders/huggingface.py +0 -0
  40. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/loaders/utils.py +0 -0
  41. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/model_specs/__init__.py +0 -0
  42. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/model_specs/common.py +0 -0
  43. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/model_specs/deepseek.py +0 -0
  44. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/model_specs/essential_ai.py +0 -0
  45. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/model_specs/gemma.py +0 -0
  46. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/model_specs/gpt_oss.py +0 -0
  47. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/model_specs/huggingface.py +0 -0
  48. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/model_specs/lfm2.py +0 -0
  49. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/model_specs/llama.py +0 -0
  50. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/model_specs/llamba.py +0 -0
  51. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/model_specs/mirai.py +0 -0
  52. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/model_specs/mistral.py +0 -0
  53. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/model_specs/pleias.py +0 -0
  54. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/model_specs/polaris.py +0 -0
  55. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/model_specs/qwen.py +0 -0
  56. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/model_import/model_specs/reka.py +0 -0
  57. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/models/__init__.py +0 -0
  58. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/models/classifier.py +0 -0
  59. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/models/common.py +0 -0
  60. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/models/language_model.py +0 -0
  61. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/modules/__init__.py +0 -0
  62. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/modules/activations.py +0 -0
  63. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/modules/classifier.py +0 -0
  64. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/modules/common.py +0 -0
  65. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/modules/decoder.py +0 -0
  66. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/modules/embedding.py +0 -0
  67. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/modules/linear.py +0 -0
  68. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/modules/mlp.py +0 -0
  69. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/modules/mlx_interop.py +0 -0
  70. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/modules/normalization.py +0 -0
  71. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/modules/rope.py +0 -0
  72. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/modules/token_mixers/__init__.py +0 -0
  73. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/modules/token_mixers/attention.py +0 -0
  74. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/modules/token_mixers/common.py +0 -0
  75. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/modules/token_mixers/mamba.py +0 -0
  76. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/modules/token_mixers/short_conv.py +0 -0
  77. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/modules/token_mixers/state/__init__.py +0 -0
  78. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/modules/token_mixers/state/common.py +0 -0
  79. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/modules/token_mixers/state/kv_cache.py +0 -0
  80. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/modules/token_mixers/state/mamba_state.py +0 -0
  81. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/modules/token_mixers/state/short_conv_state.py +0 -0
  82. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/modules/torch_interop.py +0 -0
  83. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/modules/transformer.py +0 -0
  84. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/modules/transformer_layer.py +0 -0
  85. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/modules/utils.py +0 -0
  86. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/quantization.py +0 -0
  87. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/registry_abc.py +0 -0
  88. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/safetensors.py +0 -0
  89. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/sampling.py +0 -0
  90. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/speculator/__init__.py +0 -0
  91. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/speculator/common.py +0 -0
  92. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/speculator/ngram.py +0 -0
  93. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/speculator/utils.py +0 -0
  94. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo/utils.py +0 -0
  95. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo.egg-info/SOURCES.txt +0 -0
  96. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo.egg-info/dependency_links.txt +0 -0
  97. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo.egg-info/entry_points.txt +0 -0
  98. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo.egg-info/requires.txt +0 -0
  99. {lalamo-0.6.4 → lalamo-0.6.5}/lalamo.egg-info/top_level.txt +0 -0
  100. {lalamo-0.6.4 → lalamo-0.6.5}/pyproject.toml +0 -0
  101. {lalamo-0.6.4 → lalamo-0.6.5}/setup.cfg +0 -0
{lalamo-0.6.4 → lalamo-0.6.5}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lalamo
-Version: 0.6.4
+Version: 0.6.5
 Summary: JAX library for optimization and export of models for use with the UZU inference engine.
 Requires-Python: <4,>=3.12
 Description-Content-Type: text/markdown
{lalamo-0.6.4 → lalamo-0.6.5}/lalamo/__init__.py
@@ -32,7 +32,7 @@ from lalamo.speculator import (
     SpeculatorTrainingEvent,
 )
 
-__version__ = "0.6.4"
+__version__ = "0.6.5"
 
 __all__ = [
     "AssistantMessage",
{lalamo-0.6.4 → lalamo-0.6.5}/lalamo/common.py
@@ -1,9 +1,11 @@
+import warnings
 from collections import defaultdict
-from collections.abc import Mapping, Sequence
+from collections.abc import Callable, Iterable, Mapping, Sequence
 from typing import cast
 
 import jax.numpy as jnp
 from jax._src.api import ShapeDtypeStruct
+from jax.errors import JaxRuntimeError
 from jaxtyping import Array, DTypeLike
 
 from lalamo.utils import MapDictValues, MapSequence
@@ -11,8 +13,10 @@ from lalamo.utils import MapDictValues, MapSequence
 __all__ = [
     "DEFAULT_PRECISION",
     "ArrayLike",
+    "LalamoWarning",
     "ParameterPath",
     "ParameterTree",
+    "decrease_batchsize_on_oom",
     "dummy_array",
     "flatten_parameters",
     "require_array",
@@ -23,6 +27,10 @@ __all__ = [
 DEFAULT_PRECISION: DTypeLike = jnp.bfloat16
 
 
+class LalamoWarning(UserWarning):
+    """Custom warning class for Lalamo-specific warnings."""
+
+
 type ArrayLike = Array | ShapeDtypeStruct
 
 
@@ -121,3 +129,49 @@ class ParameterPath(str):
         if not self:
             return ParameterPath(str(other))
         return ParameterPath(self + "." + str(other))
+
+
+def decrease_batchsize_on_oom[T](
+    fn: Callable[[int], Iterable[T]],
+    starting_batch_size: int,
+) -> Iterable[T]:
+    """
+    Execute fn(batch_size) with automatic batch size reduction on OOM.
+    Only reduces batch size if OOM happened on the first batch.
+
+    Args:
+        fn: Function that takes batch_size and returns an iterable
+        starting_batch_size: Initial batch size to try
+
+    Yields:
+        Results from fn(batch_size)
+
+    Raises:
+        JaxRuntimeError: If OOM occurs after first batch completes or at batch_size=1
+    """
+    first_batch_completed = False
+    effective_batch_size = starting_batch_size
+
+    while True:
+        try:
+            for result in fn(effective_batch_size):
+                yield result
+
+                # as soon as we yielded we are not allowed to retry anymore
+                # to make sure we don't ever miss/duplicate outputs
+                first_batch_completed = True
+            break
+        except JaxRuntimeError:
+            if first_batch_completed:
+                raise
+            # because OOM's sometimes generate stuff that won't be garbage collected,
+            # we need to be very aggressive with decreasing batchsize here
+            new_bs = max(int(0.7 * effective_batch_size - 1), 1)
+            if new_bs == 1 and effective_batch_size == 1:
+                raise
+            warnings.warn(
+                f"OOM detected. Reducing batch size {effective_batch_size} -> {new_bs}.",
+                LalamoWarning,
+                stacklevel=3,
+            )
+            effective_batch_size = new_bs
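A minimal usage sketch for the new helper (the run_batches generator and its simulated OOM threshold are hypothetical, invented for illustration; decrease_batchsize_on_oom, LalamoWarning, and JaxRuntimeError are the real names from the hunks above):

import warnings
from collections.abc import Iterable

from jax.errors import JaxRuntimeError

from lalamo.common import LalamoWarning, decrease_batchsize_on_oom


def run_batches(batch_size: int) -> Iterable[int]:
    # Hypothetical worker: pretend any batch size above 4 exhausts device memory.
    if batch_size > 4:
        raise JaxRuntimeError(f"simulated RESOURCE_EXHAUSTED at batch_size={batch_size}")
    yield from range(0, 16, batch_size)  # stands in for real per-batch results


warnings.simplefilter("always", LalamoWarning)  # surface each reduction notice
results = list(decrease_batchsize_on_oom(run_batches, starting_batch_size=16))
# Retry schedule here: 16 -> int(0.7 * 16 - 1) = 10 -> 6 -> 3; fn(3) fits and runs to completion.

Because LalamoWarning subclasses UserWarning, the notices can also be silenced selectively with warnings.filterwarnings("ignore", category=LalamoWarning).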
{lalamo-0.6.4 → lalamo-0.6.5}/lalamo/speculator/estimator.py
@@ -41,7 +41,7 @@ def get_default_device_bytes() -> int | None:
 
 def get_usable_memory_from_bytes(limit_bytes: int) -> int:
     # JAX allocates a bit more than it needs, so we discount it by some safety factor
-    return int(limit_bytes * 0.95)
+    return int(limit_bytes * 0.93)
 
 
 def estimate_memory_from_batchsize(
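For a sense of scale, a worked example with an assumed 8 GiB device limit (the limit value is illustrative, not from the package): the tightened factor yields int(8 * 2**30 * 0.93) = 7_988_639_170 usable bytes, about 7.44 GiB, versus 7.60 GiB under the previous 0.95, i.e. a slightly more conservative budget for the batch-size estimator built on top of it.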
lalamo-0.6.5/lalamo/speculator/inference.py (new file)
@@ -0,0 +1,112 @@
+import functools
+from collections.abc import Callable, Iterable
+from itertools import batched, chain, islice
+from typing import NamedTuple
+
+import jax
+import jax.numpy as jnp
+from jax._src.stages import Compiled
+
+from lalamo.common import decrease_batchsize_on_oom
+from lalamo.data.lalamo_completions import LalamoCompletion
+from lalamo.data.utils import get_prefixes_ending_in_user_message
+from lalamo.message_processor import Message
+from lalamo.models import LanguageModel
+
+
+class CollectTracesEvent(NamedTuple):
+    sequences_processed: int
+    tokens_generated: int
+
+
+def inference_collect_traces(
+    model: LanguageModel,
+    conversations: Iterable[Iterable[Message]],
+    num_top_logits_to_collect: int = 8,
+    batch_size: int = 1,
+    max_input_length: int = 1024,
+    max_output_length: int = 1024,
+    tokens_to_generate: int | None = None,
+    progress_callback: Callable[[CollectTracesEvent], None] | None = None,
+) -> Iterable[LalamoCompletion]:
+    def make_generate_tokens_compiled(batch_size: int) -> Compiled:
+        return (
+            jax.jit(
+                functools.partial(
+                    LanguageModel.generate_tokens,
+                    max_output_length=max_output_length,
+                    num_top_logits_to_return=num_top_logits_to_collect,
+                ),
+            )
+            .lower(
+                model,
+                prompt_token_ids=jax.ShapeDtypeStruct((batch_size, max_input_length), jnp.int32),
+                prompt_lengths_without_padding=jax.ShapeDtypeStruct((batch_size,), jnp.int32),
+            )
+            # the autotune levels are (according to https://guides.lw1.at/all-xla-options/#--xla_gpu_autotune_level)
+            # 0 - no autotune, gpu shouldn't be touched
+            # 1 - basic level, gpu should be touched veeery little
+            # 2,3 - gpu touched more and more
+            # 4 (default) - gpu might allocate more memory than the run would require!
+            .compile(compiler_options={"xla_gpu_autotune_level": "0"})
+        )
+
+    prefixes = chain.from_iterable(map(get_prefixes_ending_in_user_message, conversations))
+
+    tokenized_prefixes = map(model.message_processor.tokenize_request, prefixes)
+    filtered_prefixes = filter(lambda conv: len(conv) <= max_input_length, tokenized_prefixes)
+
+    test_batch = list(islice(filtered_prefixes, batch_size))
+
+    def collect_traces_body(batch_size: int) -> Iterable[LalamoCompletion]:
+        tokens_generated, sequences_processed = 0, 0
+        generate_tokens_compiled = make_generate_tokens_compiled(batch_size)
+        for real_batch in batched(chain(test_batch, filtered_prefixes), n=batch_size):
+            batch_padding = batch_size - len(real_batch)
+            batch = (*real_batch, *(([0],) * batch_padding))
+
+            length_without_padding = jnp.array(list(map(len, batch)))
+
+            padded = jnp.array(
+                [
+                    jnp.pad(jnp.array(tokens), (0, max_input_length - len(tokens)), constant_values=0)
+                    for tokens in batch
+                ],
+            )
+
+            generated = generate_tokens_compiled(
+                model,
+                prompt_token_ids=padded,
+                prompt_lengths_without_padding=length_without_padding,
+            )
+
+            assert generated.top_k_token_ids is not None and generated.top_k_token_logits is not None
+
+            for conv_idx in range(len(real_batch)):
+                token_ids = generated.token_ids[conv_idx].tolist()
+                seqlen = next((i + 1 for i, t in enumerate(token_ids) if t in model.stop_token_ids), len(token_ids))
+                if tokens_to_generate is not None:
+                    seqlen = min(seqlen, tokens_to_generate - tokens_generated)
+                tokens_generated += seqlen
+                sequences_processed += 1
+
+                token_ids = token_ids[:seqlen]
+                token_logits_ids = generated.top_k_token_ids[conv_idx, : len(token_ids)].tolist()
+                token_logits_values = generated.top_k_token_logits[conv_idx, : len(token_ids)].tolist()
+                token_logits = [
+                    dict(zip(keys, values, strict=True))
+                    for keys, values in zip(token_logits_ids, token_logits_values, strict=True)
+                ]
+
+                yield LalamoCompletion(batch[conv_idx], token_ids, token_logits)
+
+                if tokens_to_generate is not None and tokens_generated >= tokens_to_generate:
+                    break
+
+            if progress_callback is not None:
+                progress_callback(CollectTracesEvent(sequences_processed, tokens_generated))
+
+            if tokens_to_generate is not None and tokens_generated >= tokens_to_generate:
+                break
+
+    yield from decrease_batchsize_on_oom(collect_traces_body, batch_size)
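For reference, a consumption sketch for the new inference_collect_traces (constructing the model and conversations is outside this diff, so both appear as placeholders; the keyword names match the signature above):

from lalamo.speculator.inference import CollectTracesEvent, inference_collect_traces


def on_progress(event: CollectTracesEvent) -> None:
    print(f"sequences={event.sequences_processed}, tokens={event.tokens_generated}")


# `model` must be a LanguageModel and `conversations` an iterable of message
# sequences; how to build them is not shown in this diff.
for completion in inference_collect_traces(
    model,
    conversations,
    batch_size=32,  # starting size only; decrease_batchsize_on_oom may shrink it
    tokens_to_generate=100_000,
    progress_callback=on_progress,
):
    ...  # each item is a LalamoCompletion with prompt tokens, generated ids, and top-k logits

Since results are yielded lazily, batch-size retries can only happen before the first completion is produced, matching the guarantee documented in decrease_batchsize_on_oom.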
{lalamo-0.6.4 → lalamo-0.6.5}/lalamo.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lalamo
-Version: 0.6.4
+Version: 0.6.5
 Summary: JAX library for optimization and export of models for use with the UZU inference engine.
 Requires-Python: <4,>=3.12
 Description-Content-Type: text/markdown
lalamo-0.6.4/lalamo/speculator/inference.py (deleted)
@@ -1,101 +0,0 @@
-import functools
-from collections.abc import Callable, Iterable
-from itertools import batched, chain
-from typing import NamedTuple
-
-import jax
-import jax.numpy as jnp
-
-from lalamo.data.lalamo_completions import LalamoCompletion
-from lalamo.data.utils import get_prefixes_ending_in_user_message
-from lalamo.message_processor import Message
-from lalamo.models import LanguageModel
-
-
-class CollectTracesEvent(NamedTuple):
-    sequences_processed: int
-    tokens_generated: int
-
-
-def inference_collect_traces(
-    model: LanguageModel,
-    conversations: Iterable[Iterable[Message]],
-    num_top_logits_to_collect: int = 8,
-    batch_size: int = 1,
-    max_input_length: int = 1024,
-    max_output_length: int = 1024,
-    tokens_to_generate: int | None = None,
-    progress_callback: Callable[[CollectTracesEvent], None] | None = None,
-) -> Iterable[LalamoCompletion]:
-    generate_tokens_compiled = (
-        jax.jit(
-            functools.partial(
-                LanguageModel.generate_tokens,
-                max_output_length=max_output_length,
-                num_top_logits_to_return=num_top_logits_to_collect,
-            ),
-        )
-        .lower(
-            model,
-            prompt_token_ids=jax.ShapeDtypeStruct((batch_size, max_input_length), jnp.int32),
-            prompt_lengths_without_padding=jax.ShapeDtypeStruct((batch_size,), jnp.int32),
-        )
-        # the autotune levels are (according to https://guides.lw1.at/all-xla-options/#--xla_gpu_autotune_level)
-        # 0 - no autotune, gpu shouldn't be touched
-        # 1 - basic level, gpu should be touched veeery little
-        # 2,3 - gpu touched more and more
-        # 4 (default) - gpu might allocate more memory than the run would require!
-        .compile(compiler_options={"xla_gpu_autotune_level": "2"})
-    )
-
-    prefixes = chain.from_iterable(map(get_prefixes_ending_in_user_message, conversations))
-
-    tokenized_prefixes = map(model.message_processor.tokenize_request, prefixes)
-    filtered_prefixes = filter(lambda conv: len(conv) <= max_input_length, tokenized_prefixes)
-
-    tokens_generated, sequences_processed = 0, 0
-
-    for real_batch in batched(filtered_prefixes, n=batch_size):
-        batch_padding = batch_size - len(real_batch)
-        batch = (*real_batch, *(([0],) * batch_padding))
-
-        length_without_padding = jnp.array(list(map(len, batch)))
-
-        padded = jnp.array(
-            [jnp.pad(jnp.array(tokens), (0, max_input_length - len(tokens)), constant_values=0) for tokens in batch],
-        )
-
-        generated = generate_tokens_compiled(
-            model,
-            prompt_token_ids=padded,
-            prompt_lengths_without_padding=length_without_padding,
-        )
-
-        assert generated.top_k_token_ids is not None and generated.top_k_token_logits is not None
-
-        for conv_idx in range(len(real_batch)):
-            token_ids = generated.token_ids[conv_idx].tolist()
-            seqlen = next((i + 1 for i, t in enumerate(token_ids) if t in model.stop_token_ids), len(token_ids))
-            if tokens_to_generate is not None:
-                seqlen = min(seqlen, tokens_to_generate - tokens_generated)
-            tokens_generated += seqlen
-            sequences_processed += 1
-
-            token_ids = token_ids[:seqlen]
-            token_logits_ids = generated.top_k_token_ids[conv_idx, : len(token_ids)].tolist()
-            token_logits_values = generated.top_k_token_logits[conv_idx, : len(token_ids)].tolist()
-            token_logits = [
-                dict(zip(keys, values, strict=True))
-                for keys, values in zip(token_logits_ids, token_logits_values, strict=True)
-            ]
-
-            yield LalamoCompletion(batch[conv_idx], token_ids, token_logits)
-
-            if tokens_to_generate is not None and tokens_generated >= tokens_to_generate:
-                break
-
-        if progress_callback is not None:
-            progress_callback(CollectTracesEvent(sequences_processed, tokens_generated))
-
-        if tokens_to_generate is not None and tokens_generated >= tokens_to_generate:
-            break