lalamo 0.5.17__py3-none-any.whl → 0.6.1__py3-none-any.whl

Files changed (40)
  1. lalamo/__init__.py +1 -1
  2. lalamo/commands.py +69 -17
  3. lalamo/common.py +14 -1
  4. lalamo/main.py +148 -27
  5. lalamo/message_processor.py +4 -1
  6. lalamo/model_import/common.py +8 -17
  7. lalamo/model_import/decoder_configs/huggingface/lfm2.py +14 -4
  8. lalamo/model_import/decoder_configs/huggingface/llamba.py +2 -2
  9. lalamo/model_import/decoder_configs/huggingface/modern_bert.py +2 -2
  10. lalamo/model_import/huggingface_generation_config.py +21 -3
  11. lalamo/model_import/loaders/executorch.py +2 -2
  12. lalamo/model_import/loaders/huggingface.py +3 -3
  13. lalamo/model_import/model_specs/common.py +4 -2
  14. lalamo/model_import/model_specs/lfm2.py +41 -9
  15. lalamo/models/language_model.py +7 -6
  16. lalamo/modules/activations.py +1 -1
  17. lalamo/modules/classifier.py +11 -24
  18. lalamo/modules/common.py +4 -1
  19. lalamo/modules/decoder.py +5 -11
  20. lalamo/modules/embedding.py +25 -62
  21. lalamo/modules/linear.py +19 -33
  22. lalamo/modules/mlp.py +9 -19
  23. lalamo/modules/mlx_interop.py +1 -1
  24. lalamo/modules/rope.py +1 -1
  25. lalamo/modules/token_mixers/__init__.py +1 -1
  26. lalamo/modules/token_mixers/attention.py +9 -27
  27. lalamo/modules/token_mixers/mamba.py +26 -25
  28. lalamo/modules/token_mixers/short_conv.py +7 -14
  29. lalamo/modules/transformer.py +10 -20
  30. lalamo/modules/transformer_layer.py +8 -20
  31. lalamo/registry_abc.py +4 -4
  32. lalamo/sampling.py +14 -0
  33. lalamo/speculator/estimator.py +3 -3
  34. lalamo/speculator/ngram.py +1 -1
  35. {lalamo-0.5.17.dist-info → lalamo-0.6.1.dist-info}/METADATA +1 -1
  36. {lalamo-0.5.17.dist-info → lalamo-0.6.1.dist-info}/RECORD +40 -40
  37. {lalamo-0.5.17.dist-info → lalamo-0.6.1.dist-info}/WHEEL +1 -1
  38. {lalamo-0.5.17.dist-info → lalamo-0.6.1.dist-info}/entry_points.txt +0 -0
  39. {lalamo-0.5.17.dist-info → lalamo-0.6.1.dist-info}/licenses/LICENSE +0 -0
  40. {lalamo-0.5.17.dist-info → lalamo-0.6.1.dist-info}/top_level.txt +0 -0
lalamo/modules/token_mixers/attention.py CHANGED
@@ -10,7 +10,7 @@ from jax import vmap
  from jaxtyping import Array, Bool, DTypeLike, Float, Int, PRNGKeyArray
 
  from lalamo.common import dummy_array
- from lalamo.modules.common import ParameterTree, PositionalEmbeddingSelector
+ from lalamo.modules.common import ParameterTree, PositionalEmbeddingSelector, require_array, require_tree
  from lalamo.modules.linear import LinearBase, LinearConfig
  from lalamo.modules.normalization import Normalization, NormalizationConfig
  from lalamo.modules.rope import PositionalEmbeddings
@@ -433,33 +433,15 @@ class Attention(TokenMixerBase[AttentionConfig, KVCacheLayer]):
              result["sinks"] = self.sinks
          return result
 
-     def import_weights(
-         self,
-         weights: ParameterTree[Array],
-     ) -> Self:
+     def import_weights(self, weights: ParameterTree[Array]) -> Self:
          assert isinstance(weights, Mapping)
-         assert isinstance(weights["qkv_projection"], Mapping)
-         assert isinstance(weights["out_projection"], Mapping)
-         if self.query_norm is not None:
-             assert isinstance(weights["query_norm"], Mapping)
-             query_norm = self.query_norm.import_weights(weights["query_norm"])
-         else:
-             query_norm = None
-         if self.key_norm is not None:
-             assert isinstance(weights["key_norm"], Mapping)
-             key_norm = self.key_norm.import_weights(weights["key_norm"])
-         else:
-             key_norm = None
-         if self.sinks is not None:
-             assert isinstance(weights["sinks"], Array)
-             sinks = weights["sinks"]
-         else:
-             sinks = None
          return replace(
              self,
-             qkv_projection=self.qkv_projection.import_weights(weights["qkv_projection"]),
-             out_projection=self.out_projection.import_weights(weights["out_projection"]),
-             query_norm=query_norm,
-             key_norm=key_norm,
-             sinks=sinks,
+             qkv_projection=self.qkv_projection.import_weights(require_tree(weights["qkv_projection"])),
+             out_projection=self.out_projection.import_weights(require_tree(weights["out_projection"])),
+             query_norm=self.query_norm.import_weights(require_tree(weights["query_norm"]))
+             if self.query_norm
+             else None,
+             key_norm=self.key_norm.import_weights(require_tree(weights["key_norm"])) if self.key_norm else None,
+             sinks=require_array(weights["sinks"]) if self.sinks is not None else None,
          )
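The require_array / require_tree helpers used above are new in 0.6.1 and replace the per-field isinstance assertions that previously guarded every import_weights implementation. Their bodies are not part of this diff (they land in lalamo/common.py, which grows by 14 lines in this release); a minimal sketch consistent with the call sites, assuming ParameterTree is the usual leaf-or-nested-containers union, might look like:

from collections.abc import Mapping, Sequence

from jaxtyping import Array

# Hypothetical reconstruction for illustration only; the real definitions
# live in lalamo/common.py and appear to be re-exported via lalamo.modules.common.
type ParameterTree[LeafT] = LeafT | Sequence[ParameterTree[LeafT]] | Mapping[str, ParameterTree[LeafT]]


def require_array(node: ParameterTree[Array]) -> Array:
    # Narrow a parameter-tree node to a leaf array, failing loudly otherwise.
    if not isinstance(node, Array):
        raise TypeError(f"expected an array leaf, got {type(node).__name__}")
    return node


def require_tree(node: ParameterTree[Array]) -> Mapping[str, ParameterTree[Array]] | Sequence[ParameterTree[Array]]:
    # Narrow a parameter-tree node to a nested subtree of parameters.
    if isinstance(node, Array):
        raise TypeError("expected a subtree of parameters, got an array leaf")
    return node

Centralizing the narrowing keeps the old runtime checks while giving the type checker a precise return type, which is what lets the import_weights bodies throughout this diff collapse into single replace(...) expressions.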
lalamo/modules/token_mixers/mamba.py CHANGED
@@ -10,7 +10,7 @@ from einops import einsum, rearrange
  from jax import vmap
  from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
- from lalamo.common import ParameterTree, dummy_array
+ from lalamo.common import ParameterTree, dummy_array, require_array, require_tree
  from lalamo.modules.activations import Activation
  from lalamo.modules.common import LalamoModule, PositionalEmbeddingSelector
  from lalamo.modules.linear import LinearBase, LinearConfig
@@ -112,6 +112,7 @@ class SeparableCausalConv(LalamoModule[SeparableCausalConvConfig]):
      def __call__(
          self,
          inputs: Float[Array, "suffix_tokens channels"],
+         length_without_padding: Int[Array, ""] | int | None = None,
          state: Float[Array, "prefix_tokens channels"] | None = None,
          return_updated_state: bool = False,
      ) -> CausalConvResult:
@@ -136,9 +137,23 @@ class SeparableCausalConv(LalamoModule[SeparableCausalConvConfig]):
          if self.biases is not None:
              results = results + self.biases
 
+         if return_updated_state:
+             if length_without_padding is None:
+                 length_without_padding = num_suffix_tokens
+             length_without_padding = jnp.asarray(length_without_padding, dtype=jnp.int32)
+             length_without_padding = jnp.clip(length_without_padding, 0, num_suffix_tokens)
+             updated_state = jax.lax.dynamic_slice_in_dim(
+                 inputs_with_history,
+                 start_index=length_without_padding,
+                 slice_size=self.kernel_size - 1,
+                 axis=0,
+             )
+         else:
+             updated_state = None
+
          return CausalConvResult(
              results,
-             (inputs_with_history if return_updated_state else None),
+             updated_state,
          )
 
      def export_weights(self) -> ParameterTree:
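Why the updated state now starts at length_without_padding: inputs_with_history is the carried state (kernel_size - 1 rows) concatenated with the current suffix, so the kernel_size - 1 rows ending at the last valid token begin exactly at index length_without_padding. Slicing there keeps right-padding out of the carried state; previously the whole inputs_with_history was handed back, so a padded batch could leak padding rows into the next step's state. A toy illustration with made-up numbers:

import jax
import jax.numpy as jnp

# kernel_size = 3, so the carried state is the last 2 valid rows.
kernel_size = 3
prev_state = jnp.zeros((kernel_size - 1, 1))                 # 2 rows of history
suffix = jnp.array([[1.0], [2.0], [3.0], [0.0], [0.0]])      # 3 real tokens + 2 padding rows
inputs_with_history = jnp.concatenate([prev_state, suffix])  # shape (7, 1)

length_without_padding = 3
updated_state = jax.lax.dynamic_slice_in_dim(
    inputs_with_history,
    start_index=length_without_padding,  # rows 3 and 4 of inputs_with_history
    slice_size=kernel_size - 1,
    axis=0,
)
print(updated_state.ravel())  # [2. 3.] -- the last two valid inputs, not the padding

The same length_without_padding is threaded through from the Mamba2 and ShortConv call sites in the hunks that follow.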
@@ -149,16 +164,10 @@ class SeparableCausalConv(LalamoModule[SeparableCausalConvConfig]):
 
      def import_weights(self, weights: ParameterTree[Array]) -> "SeparableCausalConv":
          assert isinstance(weights, Mapping)
-         assert isinstance(weights["weights"], Array)
-         if self.biases is not None:
-             assert isinstance(weights["biases"], Array)
-             biases = weights["biases"]
-         else:
-             biases = None
          return replace(
              self,
-             weights=weights["weights"],
-             biases=biases,
+             weights=require_array(weights["weights"]),
+             biases=require_array(weights["biases"]) if self.biases is not None else None,
          )
 
 
@@ -442,6 +451,7 @@ class Mamba2(TokenMixerBase[Mamba2Config, Mamba2StateLayer]):
 
          conv_output, updated_conv_state = self.conv(
              conv_inputs,
+             length_without_padding,
              state.conv_state,
              return_updated_state=return_updated_state,
          )
@@ -532,22 +542,13 @@ class Mamba2(TokenMixerBase[Mamba2Config, Mamba2StateLayer]):
              "gate_bias": self.gate_bias,
          }
 
-     def import_weights(
-         self,
-         weights: ParameterTree[Array],
-     ) -> Self:
+     def import_weights(self, weights: ParameterTree[Array]) -> Self:
          assert isinstance(weights, Mapping)
-         assert isinstance(weights["in_projection"], Mapping)
-         assert isinstance(weights["out_projection"], Mapping)
-         assert isinstance(weights["conv"], Mapping)
-         assert isinstance(weights["skip_connection_weight"], Array)
-         assert isinstance(weights["gate_bias"], Array)
-
          return replace(
              self,
-             in_projection=self.in_projection.import_weights(weights["in_projection"]),
-             out_projection=self.out_projection.import_weights(weights["out_projection"]),
-             conv=self.conv.import_weights(weights["conv"]),
-             skip_connection_weight=weights["skip_connection_weight"],
-             gate_bias=weights["gate_bias"],
+             in_projection=self.in_projection.import_weights(require_tree(weights["in_projection"])),
+             out_projection=self.out_projection.import_weights(require_tree(weights["out_projection"])),
+             conv=self.conv.import_weights(require_tree(weights["conv"])),
+             skip_connection_weight=require_array(weights["skip_connection_weight"]),
+             gate_bias=require_array(weights["gate_bias"]),
          )
lalamo/modules/token_mixers/short_conv.py CHANGED
@@ -6,7 +6,7 @@ import equinox as eqx
  from jax import vmap
  from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
- from lalamo.common import ParameterTree
+ from lalamo.common import ParameterTree, require_tree
  from lalamo.modules.common import PositionalEmbeddingSelector
  from lalamo.modules.linear import LinearBase, LinearConfig
  from lalamo.modules.rope import PositionalEmbeddings
@@ -116,7 +116,7 @@ class ShortConv(TokenMixerBase[ShortConvConfig, ShortConvStateLayer]):
          positional_embeddings: PositionalEmbeddings | None,
          state: ShortConvStateLayer | None = None,
          return_updated_state: bool = False,
-         length_without_padding: Int[Array, ""] | int | None = None,  # noqa: ARG002
+         length_without_padding: Int[Array, ""] | int | None = None,
      ) -> TokenMixerResult[ShortConvStateLayer]:
          if positional_embeddings is not None:
              raise ValueError("Positional embeddings are not supported for ShortConv.")
@@ -124,7 +124,7 @@ class ShortConv(TokenMixerBase[ShortConvConfig, ShortConvStateLayer]):
          pre_conv_gate, post_conv_gate, x = vmap(self.in_projection)(inputs)
 
          prev_conv_state = state.conv_state if state is not None else None
-         conv_output = self.conv(x * pre_conv_gate, prev_conv_state, return_updated_state)
+         conv_output = self.conv(x * pre_conv_gate, length_without_padding, prev_conv_state, return_updated_state)
 
          (outputs,) = vmap(self.out_projection)(conv_output.outputs * post_conv_gate)
          updated_conv_state = conv_output.state
@@ -151,18 +151,11 @@ class ShortConv(TokenMixerBase[ShortConvConfig, ShortConvStateLayer]):
              "out_projection": self.out_projection.export_weights(),
          }
 
-     def import_weights(
-         self,
-         weights: ParameterTree[Array],
-     ) -> Self:
+     def import_weights(self, weights: ParameterTree[Array]) -> Self:
          assert isinstance(weights, Mapping)
-         assert isinstance(weights["in_projection"], Mapping)
-         assert isinstance(weights["conv"], Mapping)
-         assert isinstance(weights["out_projection"], Mapping)
-
          return replace(
              self,
-             in_projection=self.in_projection.import_weights(weights["in_projection"]),
-             conv=self.conv.import_weights(weights["conv"]),
-             out_projection=self.out_projection.import_weights(weights["out_projection"]),
+             in_projection=self.in_projection.import_weights(require_tree(weights["in_projection"])),
+             conv=self.conv.import_weights(require_tree(weights["conv"])),
+             out_projection=self.out_projection.import_weights(require_tree(weights["out_projection"])),
          )
lalamo/modules/transformer.py CHANGED
@@ -7,7 +7,7 @@ import jax
  from jax import vmap
  from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
- from lalamo.common import ParameterTree
+ from lalamo.common import ParameterTree, require_tree
  from lalamo.modules.token_mixers import AttentionConfig
  from lalamo.modules.utils import vmap_twice
 
@@ -182,7 +182,8 @@ class Transformer(LalamoModule[TransformerConfig]):
      ) -> TransformerResult:
          if inner_features.ndim != 3:
              raise ValueError(
-                 f"inner_features must be a 3D array of size (batch_size, sequence_length, hidden_dim), got {inner_features.shape}",
+                 "inner_features must be a 3D array of size (batch_size, sequence_length, hidden_dim),"
+                 f" got {inner_features.shape}",
              )
          if token_positions.ndim != 2:
              raise ValueError(
@@ -251,35 +252,24 @@ class Transformer(LalamoModule[TransformerConfig]):
              result["local_rope"] = self.local_rope.export_weights()
          return result
 
-     def import_weights(
-         self,
-         weights: ParameterTree[Array],
-     ) -> Self:
+     def import_weights(self, weights: ParameterTree[Array]) -> Self:
          assert isinstance(weights, Mapping)
          assert isinstance(weights["layers"], Sequence)
-         assert isinstance(weights["output_norm"], Mapping)
-
          if self.global_rope:
-             assert isinstance(weights["global_rope"], Mapping)
-             global_rope = self.global_rope.import_weights(weights["global_rope"])
+             global_rope = self.global_rope.import_weights(require_tree(weights["global_rope"]))
          else:
              global_rope = None
-
          if self.local_rope:
-             assert isinstance(weights["local_rope"], Mapping)
-             local_rope = self.local_rope.import_weights(weights["local_rope"])
+             local_rope = self.local_rope.import_weights(require_tree(weights["local_rope"]))
          else:
              local_rope = None
-
-         layers = []
-         for layer, layer_weights in zip(self.layers, weights["layers"], strict=True):
-             assert isinstance(layer_weights, Mapping)
-             layers.append(layer.import_weights(layer_weights))
-
+         layers = [
+             layer.import_weights(require_tree(lw)) for layer, lw in zip(self.layers, weights["layers"], strict=True)
+         ]
          return replace(
              self,
              global_rope=global_rope,
              layers=tuple(layers),
-             output_norm=self.output_norm.import_weights(weights["output_norm"]),
+             output_norm=self.output_norm.import_weights(require_tree(weights["output_norm"])),
              local_rope=local_rope,
          )
lalamo/modules/transformer_layer.py CHANGED
@@ -9,7 +9,7 @@ import jax.numpy as jnp
  from jax import vmap
  from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
- from lalamo.common import ParameterTree
+ from lalamo.common import ParameterTree, require_tree
 
  from .common import ForwardPassMode, LalamoModule, PositionalEmbeddingSelector
  from .mlp import MLPBase, MLPConfig, MLPForwardPassConfig
@@ -293,38 +293,26 @@ class TransformerLayer(LalamoModule[TransformerLayerConfig]):
              result["post_mlp_norm"] = self.post_mlp_norm.export_weights()
          return result
 
-     def import_weights(
-         self,
-         weights: ParameterTree[Array],
-     ) -> Self:
+     def import_weights(self, weights: ParameterTree[Array]) -> Self:
          assert isinstance(weights, Mapping)
-         assert isinstance(weights["mixer"], Mapping)
-         assert isinstance(weights["mlp"], Mapping)
-         assert isinstance(weights["pre_mlp_norm"], Mapping)
-
          if self.post_mixer_norm is not None:
-             assert isinstance(weights["post_mixer_norm"], Mapping)
-             post_mixer_norm = self.post_mixer_norm.import_weights(
-                 weights["post_mixer_norm"],
-             )
+             post_mixer_norm = self.post_mixer_norm.import_weights(require_tree(weights["post_mixer_norm"]))
          else:
              post_mixer_norm = None
          if self.post_mlp_norm is not None:
-             assert isinstance(weights["post_mlp_norm"], Mapping)
-             post_mlp_norm = self.post_mlp_norm.import_weights(weights["post_mlp_norm"])
+             post_mlp_norm = self.post_mlp_norm.import_weights(require_tree(weights["post_mlp_norm"]))
          else:
              post_mlp_norm = None
          if self.pre_mixer_norm is not None:
-             assert isinstance(weights["pre_mixer_norm"], Mapping)
-             pre_mixer_norm = self.pre_mixer_norm.import_weights(weights["pre_mixer_norm"])
+             pre_mixer_norm = self.pre_mixer_norm.import_weights(require_tree(weights["pre_mixer_norm"]))
          else:
              pre_mixer_norm = None
          return replace(
              self,
              pre_mixer_norm=pre_mixer_norm,
-             mixer=self.mixer.import_weights(weights["mixer"]),
+             mixer=self.mixer.import_weights(require_tree(weights["mixer"])),
              post_mixer_norm=post_mixer_norm,
-             pre_mlp_norm=self.pre_mlp_norm.import_weights(weights["pre_mlp_norm"]),
-             mlp=self.mlp.import_weights(weights["mlp"]),
+             pre_mlp_norm=self.pre_mlp_norm.import_weights(require_tree(weights["pre_mlp_norm"])),
+             mlp=self.mlp.import_weights(require_tree(weights["mlp"])),
              post_mlp_norm=post_mlp_norm,
          )
lalamo/registry_abc.py CHANGED
@@ -1,5 +1,5 @@
  from abc import ABC, ABCMeta
- from typing import Any
+ from typing import Any, Self
  from weakref import WeakSet
 
  __all__ = ["RegistryABC", "RegistryMeta"]
@@ -29,7 +29,7 @@ class RegistryMeta(ABCMeta):
 
          # Detect and remember the root exactly once
          if RegistryMeta._ROOT is None and name == "RegistryABC":
-             RegistryMeta._ROOT = cls  # type: ignore[assignment]
+             RegistryMeta._ROOT = cls
              return
 
          root = RegistryMeta._ROOT
@@ -58,6 +58,6 @@ class RegistryABC(ABC, metaclass=RegistryMeta):
      """
 
      @classmethod
-     def __descendants__(cls) -> tuple[type, ...]:
-         reg: WeakSet[type] = getattr(cls, RegistryMeta._REG_ATTR)  # noqa: SLF001
+     def __descendants__(cls) -> tuple[type[Self], ...]:
+         reg: WeakSet[type[Self]] = getattr(cls, RegistryMeta._REG_ATTR)  # noqa: SLF001
          return tuple(reg)
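Annotating __descendants__ with Self makes the registry generic in the receiving class: called on a subclass, it now yields tuple[type[Subclass], ...] instead of the untyped tuple[type, ...] of 0.5.17. A hypothetical registry built on RegistryABC, with made-up loader classes, shows the effect:

from lalamo.registry_abc import RegistryABC


class LoaderBase(RegistryABC):
    """Hypothetical registry root, for illustration only."""


class SafetensorsLoader(LoaderBase): ...


class GGUFLoader(LoaderBase): ...


# Typed as tuple[type[LoaderBase], ...] in 0.6.1, so each descendant can be
# used as a LoaderBase subclass without a cast.
for loader_cls in LoaderBase.__descendants__():
    print(loader_cls.__name__)  # e.g. SafetensorsLoader, GGUFLoader (order not guaranteed)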
lalamo/sampling.py CHANGED
@@ -1,5 +1,6 @@
  from abc import abstractmethod
  from collections.abc import Iterable
+ from math import log
 
  import equinox as eqx
  import jax
@@ -10,6 +11,7 @@ __all__ = [
      "BanTokensPolicy",
      "CompositePolicy",
      "GreedyPolicy",
+     "MinPPolicy",
      "SamplingPolicy",
      "TemperaturePolicy",
      "TopKPolicy",
@@ -64,6 +66,15 @@ class TopPPolicy(SamplingPolicy):
          return jnp.where(to_remove_unsorted, -jnp.inf, logits)
 
 
+ class MinPPolicy(SamplingPolicy):
+     p: float = eqx.field(static=True)
+
+     def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]:
+         max_logit = jnp.max(logits)
+         logit_cutoff = max_logit + log(self.p)
+         return jnp.where(logits >= logit_cutoff, logits, -jnp.inf)
+
+
  class BanTokensPolicy(SamplingPolicy):
      banned_tokens: tuple[int, ...] = eqx.field(static=True)
 
@@ -85,6 +96,7 @@ def make_policy(
      temperature: float | None = None,
      top_k: int | None = None,
      top_p: float | None = None,
+     min_p: float | None = None,
      banned_tokens: Iterable[int] | None = None,
  ) -> SamplingPolicy:
      policies = []
@@ -96,4 +108,6 @@ def make_policy(
          policies.append(TopKPolicy(top_k))
      if top_p is not None:
          policies.append(TopPPolicy(top_p))
+     if min_p is not None:
+         policies.append(MinPPolicy(min_p))
      return CompositePolicy(tuple(policies))
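The new MinPPolicy implements min-p sampling: a token survives only if its probability is at least p times that of the most likely token. Since softmax is monotone, p_i / p_max >= p is equivalent in logit space to logit_i >= max_logit + log(p), which is exactly the cutoff computed in process_logits. A small usage sketch with toy logits (this assumes CompositePolicy chains process_logits across its members, as the policy interface suggests):

import jax.numpy as jnp

from lalamo.sampling import make_policy

# Keep tokens whose probability is at least 25% of the top token's.
policy = make_policy(min_p=0.25)

logits = jnp.array([2.0, 1.0, -1.0])
filtered = policy.process_logits(logits)
# Cutoff = 2.0 + log(0.25) ≈ 0.61, so only the last token is masked:
# filtered ≈ [2.0, 1.0, -inf]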
lalamo/speculator/estimator.py CHANGED
@@ -46,9 +46,9 @@ def estimate_memory_from_batchsize(
      assert hasattr(memory_analysis, "temp_size_in_bytes")
 
      return (
-         memory_analysis.argument_size_in_bytes  # type: ignore (pyright bug)
-         + memory_analysis.output_size_in_bytes  # type: ignore (pyright bug)
-         + memory_analysis.temp_size_in_bytes  # type: ignore (pyright bug)
+         memory_analysis.argument_size_in_bytes
+         + memory_analysis.output_size_in_bytes
+         + memory_analysis.temp_size_in_bytes
      )
 
 
lalamo/speculator/ngram.py CHANGED
@@ -129,7 +129,7 @@ class NGramSpeculator(Speculator):
 
          return (
              memoryview(self.ngram_keys)[idx_start:idx_end],
-             memoryview(self.ngram_values)[idx_start:idx_end],  # type: ignore (typechecker bug)
+             memoryview(self.ngram_values)[idx_start:idx_end].cast("f"),  # noop cast to make typechecker happy
              memoryview(self.ngram_counts)[seq_hash : (seq_hash + 1)],
          )
 
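The .cast("f") that replaces the blanket type: ignore works because memoryview.cast requires one side of the cast to be a byte format, so ngram_values is evidently a byte-format buffer that already holds float32 data: the cast changes no bytes (hence "noop") but re-types the view as memoryview[float] rather than the element type typeshed infers for the raw view. A standalone illustration, with a bytearray standing in for whatever buffer ngram_values actually is:

import struct

buf = bytearray(struct.pack("2f", 0.25, 0.5))  # float32 data in a plain byte buffer
view = memoryview(buf)                         # format "B"; elements are typed as int
floats = view.cast("f")                        # same bytes, now viewed as float32
assert floats.tolist() == [0.25, 0.5]
assert floats.nbytes == view.nbytes            # no copy, no size change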
{lalamo-0.5.17.dist-info → lalamo-0.6.1.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: lalamo
- Version: 0.5.17
+ Version: 0.6.1
  Summary: JAX library for optimization and export of models for use with the UZU inference engine.
  Requires-Python: <4,>=3.12
  Description-Content-Type: text/markdown
{lalamo-0.5.17.dist-info → lalamo-0.6.1.dist-info}/RECORD CHANGED
@@ -1,20 +1,20 @@
- lalamo/__init__.py,sha256=asVMPmQ7BUt7bYlcuNZ7SnOSJDJUiN9QhlU5lRUehSo,1387
- lalamo/commands.py,sha256=rU9T8Mx6s7itpk-dj5ToQ4PUpGPfdmmKlrF02l2kIS0,9967
- lalamo/common.py,sha256=5NUFD26yQgOnEEk3LaQnce8n-VwJxILkEpFesHZhtQU,3820
- lalamo/main.py,sha256=dE7Us9L6sfz9bp5rUSzGHUkG0Uon4xdju9dGGtXidZI,23888
- lalamo/message_processor.py,sha256=bSUAQg7CemLTnBV4LtPxJBicAalruDCA-JXjkTYPZ8U,5797
+ lalamo/__init__.py,sha256=XBfWi6pPtdWFEQRvMxVw8KGoqWxIFq01Z2zBxqNp7BE,1386
+ lalamo/commands.py,sha256=zXyyrLTHhP9wouwtpX4RUZeEF6No-_9ee-y_GWGhw7k,10972
+ lalamo/common.py,sha256=WaNJx20eUX4CBF50aym9lniGAiX-SzBJzDzO5Jh6zXA,4312
+ lalamo/main.py,sha256=Tez84CtMxUi1ySuRSqQElu4Zr1UWs_Gw6HX1xtCZknQ,27383
+ lalamo/message_processor.py,sha256=PMKte9YijT3h9N7DjTNp8H4V45A_qlDqJaubqFevLX8,5924
  lalamo/quantization.py,sha256=8o6ryIZLzzDYQuvBTboPfaVVdfijAKGpTxOcg3GKVD8,2752
- lalamo/registry_abc.py,sha256=ENjXiD_wEH100fNjG-W5Em1L_EQ0Lf0pdRhRGvf3qZk,2197
+ lalamo/registry_abc.py,sha256=qTikqviqqeseNzkjqoyQvL4dEWJYWzN0rI05T-JNTmo,2187
  lalamo/safetensors.py,sha256=kUiTSgx2zhfD1hxV_AA1DOLaKAKzjRd_vOYZCFf0em0,3048
- lalamo/sampling.py,sha256=g_dNiJyZrRqoQIiLid4cr6nRT9N5tSz3GtHr8Bt4n-E,3404
+ lalamo/sampling.py,sha256=GE6Av7zS-pr5Bg7FtOivRce7I0JIYuNYqfqsRe-yjQk,3867
  lalamo/utils.py,sha256=c88IP110gHZJ6hYDq7p36A9u-vLRM_YdavFom56gsNQ,4111
  lalamo/data/__init__.py,sha256=exfhBLxHrg7BWutM0tAln5QuIWlNQmOhaG2noFYxfPI,189
  lalamo/data/huggingface_message.py,sha256=-7lN9eIcETQzt1Pnx3d4d8p3_I7WYMNf4mp1P91N7fI,1115
  lalamo/data/lalamo_completions.py,sha256=U_m3UNSJASUFz3rJq_taZOtL_U4B8Oj-ndkTF-JH-v4,1509
  lalamo/data/utils.py,sha256=B96gLaULyStKYuR8wjFdTpFc6YIDC8EEvGh1eiMe_Ec,338
  lalamo/model_import/__init__.py,sha256=Z8pS9rbKKx1QgUy7KZtHxiNWlZhII3mdovT9d37vAxg,168
- lalamo/model_import/common.py,sha256=wvyGD-iLut_Pm3HjDMI05upqdtCW3HWeoeB0YmiFeqk,12419
- lalamo/model_import/huggingface_generation_config.py,sha256=mot6VQ6ezCtEhN6VjhnvaU-nR5P5T2BuBUgpFNnWJxU,1495
+ lalamo/model_import/common.py,sha256=MIbvK3mxgrDSXea6jujvCOu9Jjyip6MXeTsJjNTBJAU,12325
+ lalamo/model_import/huggingface_generation_config.py,sha256=xicv_kJOfIGlz4gi5fRFIkiAZ9_QRDLRtW8nKMm5tVU,2022
  lalamo/model_import/huggingface_tokenizer_config.py,sha256=xvwdmio7b9nhn2H3uMBVligiYj58JaCFCvHY3-8dBvM,2502
  lalamo/model_import/decoder_configs/__init__.py,sha256=YvlSsJqNEQPCNKcUzCw0MLjt8H3vcfjc4sz1OK7qdIQ,679
  lalamo/model_import/decoder_configs/common.py,sha256=L8PCgF5fIt3RqPlmLiJpBzDguKk9iTjk4XSItxwVG4c,3260
@@ -24,26 +24,26 @@ lalamo/model_import/decoder_configs/huggingface/common.py,sha256=YYIDEQy8x7lqL2q
  lalamo/model_import/decoder_configs/huggingface/gemma2.py,sha256=g8LH_GlSNyL04WWi596zI0rWsD3ahnfNjDk-9zZNcDE,4759
  lalamo/model_import/decoder_configs/huggingface/gemma3.py,sha256=UXiEyNqlD0Czc5Gj3n4hNqNDp9Ml5YzH1XZ6BXj0mgU,10223
  lalamo/model_import/decoder_configs/huggingface/gpt_oss.py,sha256=MBCoPbuWyzbJiBRtHOtpaPHJjQ1UVCAYcVrfIejTnlQ,7446
- lalamo/model_import/decoder_configs/huggingface/lfm2.py,sha256=vrBMxtiKEg0eHNDL_bWM9odlrsab7jlMXEY8vjEB7-c,7595
+ lalamo/model_import/decoder_configs/huggingface/lfm2.py,sha256=tOx4EsDGRd-87E1Q94DkbGlRBeIvBOvapfr9WeUxFYE,8027
  lalamo/model_import/decoder_configs/huggingface/llama.py,sha256=pGuBQTY6qpx6CriWwdsLpuTSRS7ECoTP1kt5pSKRlNQ,8549
- lalamo/model_import/decoder_configs/huggingface/llamba.py,sha256=ANB-vQK8U-zVFubZSTDXXt2S70T5SVOGzf7eOVvPzIQ,5773
+ lalamo/model_import/decoder_configs/huggingface/llamba.py,sha256=NVvr7_3bfcLHGRrHG3b0IylgTt-knH31oLz3yFqrkqQ,5775
  lalamo/model_import/decoder_configs/huggingface/mistral.py,sha256=MDGC0ivzJuUpOC11n8vFdcVzqccUyaRw_hkL74mVlAg,4599
- lalamo/model_import/decoder_configs/huggingface/modern_bert.py,sha256=A8nNIMhPVumvPWIFR3RexRc6XkFyUd_3mmNpmvyPEGE,8816
+ lalamo/model_import/decoder_configs/huggingface/modern_bert.py,sha256=Crh20pjSa35fP22D3J-29mv4yzdrjzW6VhOjY4Tasmg,8801
  lalamo/model_import/decoder_configs/huggingface/qwen2.py,sha256=n3qIANMPbtQsTtk5QEWWFZ6R85eDxR_kaZd0NDlJ3T4,5786
  lalamo/model_import/decoder_configs/huggingface/qwen3.py,sha256=i99mfL2DbeJ0l5aFRV84MTT-PsWf6q-8B-SGPIVGe1w,7522
  lalamo/model_import/loaders/__init__.py,sha256=3THc1wQ4EPBzQkL_4EaKCa7Ev5Z7oczcvc4AHy9v5EI,228
  lalamo/model_import/loaders/common.py,sha256=kkugV-bMQlN1zvGHoj3uc7z0FbXKoMtXEBTvyu4KxK4,1844
- lalamo/model_import/loaders/executorch.py,sha256=t2Ey_mBMNC8bTSTdYWjuGXdPTRoohFlYrqtWyNkBU_8,9219
- lalamo/model_import/loaders/huggingface.py,sha256=qWdzoSvHvb_3prn2kwfxgnYPW2bVB0Q49m_wyRYha8Q,34677
+ lalamo/model_import/loaders/executorch.py,sha256=JCeylxmkXT2iOfVmrvgAyP-9Th-96w3sRtssIW43Ag4,9187
+ lalamo/model_import/loaders/huggingface.py,sha256=4zIKuYd5-BC1nkf6rtuKxnOmefEWafv6yXuKEdxg9p4,34629
  lalamo/model_import/loaders/utils.py,sha256=eiX3WKFRrAfBY-dugodscNInl5o5w3KmVcgma4atpGY,2456
  lalamo/model_import/model_specs/__init__.py,sha256=JISqwJkloQkGD2jvi1MakNEWapIwlNXXVi5giZyXB74,1275
- lalamo/model_import/model_specs/common.py,sha256=8ALKxHrt8uK4XiqjK25NwZj1CC7DM7jlYcFVZPGkFrw,6643
+ lalamo/model_import/model_specs/common.py,sha256=OcE6wzDz4MsETxYdcOvRT6x6_NpsyBeIlK1Zl6qkMMo,6823
  lalamo/model_import/model_specs/deepseek.py,sha256=Umef93_ZBuq93yYsejIRNwj3udoln1gHfrv3SK5jyMo,417
  lalamo/model_import/model_specs/essential_ai.py,sha256=xbHcwRpAWhR9gOgypVzcgunFspoUEk3iNsw-46CVR4o,390
  lalamo/model_import/model_specs/gemma.py,sha256=dwKwOHU1sBJNLFAwtEyydsRUF9QENN3SHtjbfqtOSic,3876
  lalamo/model_import/model_specs/gpt_oss.py,sha256=PLo0QGrXKdX61ReTRdyOaP_EH3Dmj5lp3fpJjZRwRVA,542
  lalamo/model_import/model_specs/huggingface.py,sha256=TEkU8y95_hmUWyF-Q5hn0dE2SvXbApghAsQwhWRu4D0,431
- lalamo/model_import/model_specs/lfm2.py,sha256=uzuFbcj4Wj2OqL7XJE8Q431VYZelS_HkfPFpl7rJuJY,1038
+ lalamo/model_import/model_specs/lfm2.py,sha256=wg4Ggt6BbMO4ScJ6h8tjvBc3IVSrMudESQxjleUF9Ds,2198
  lalamo/model_import/model_specs/llama.py,sha256=TxhKbIBFmGV2NopOg_k3ltsKlJccbxKyu-GQ7hYWCyw,3140
  lalamo/model_import/model_specs/llamba.py,sha256=Ic3sWTv34FLJ4fG6OR_Mc5goGJQR6fa5b2WbVXbn9FA,1471
  lalamo/model_import/model_specs/mirai.py,sha256=eifYVV5-fABiLH6rr82_DiVFtDyqpW0vbvXCYsQQzto,617
@@ -55,27 +55,27 @@ lalamo/model_import/model_specs/reka.py,sha256=dOUYbEMMvovQdzQuBO_DCsjGI39syhoKC
  lalamo/models/__init__.py,sha256=Vn5PcvSqKppIchkSZwQVTn_GpRvOOzZVxo5PUeDl6N8,283
  lalamo/models/classifier.py,sha256=LvL54crCVi4HVSIXuoaSLB_5jtcx74GL7kgdy2Y16Zc,2094
  lalamo/models/common.py,sha256=uU6eCHtIqMeC_aRGVo09NdpAtvQ6RKSbm6pumVvL8pc,2943
- lalamo/models/language_model.py,sha256=QPeVEyhutSze7fSNhvOvwSoYt24QMk-dtTJkos38amY,13465
+ lalamo/models/language_model.py,sha256=HtFS-R4Uqr7SohFstoAZFVrJI293N9cG_LVkXhZxgFI,13546
  lalamo/modules/__init__.py,sha256=OHIQn08jx2c3L2KIQA-7SJ4yVb2E5m6T6FqTHFJTDdM,4006
- lalamo/modules/activations.py,sha256=U3qTQtZawPAUcoqbkIJnmTYcaNiQuSPMLcBeJ398GhI,1022
- lalamo/modules/classifier.py,sha256=_jtJ3INEq1dJP5HpUmcDk9YYzpRYlQ04zvFGaWBV6Lg,12101
- lalamo/modules/common.py,sha256=dqDEOi-C3H4U9iWUisU32RA-wRDCGuaUNGbObRBhyQM,3315
- lalamo/modules/decoder.py,sha256=Opd3QIq1mpGr9P7sLH-Fryitlfp6ESTpcX71vgm89t0,7129
- lalamo/modules/embedding.py,sha256=LLiH8mTu81JSpUTj-XhsrVIUfl_GhapnXxw1yGSUBgM,28428
- lalamo/modules/linear.py,sha256=XfIYhmpk-bwNHIzIgsL48ZUTclHD2KB4uXHMw9NTE-8,42991
- lalamo/modules/mlp.py,sha256=bL3sQ46vCNt1MBRwlzmXZx9nQfRe4axpGe5UOFVanBI,17959
- lalamo/modules/mlx_interop.py,sha256=FdfU_1iES-HQ9r4K0SkYwJTyvE0f-_T5ursNCjPLZKY,467
+ lalamo/modules/activations.py,sha256=25F4XytJMIwPPmUbxiDUrcrdUi4c-O9SUbwv9lnZbuU,992
+ lalamo/modules/classifier.py,sha256=Q5eNzJ68to6JGk8IDZiKv6Rmwh15UyT2xC52tP5njoQ,11767
+ lalamo/modules/common.py,sha256=Rc9zenrUMntDKZydI1tzt1ZIY8ggfyk3ZDB-xi81ibw,3406
+ lalamo/modules/decoder.py,sha256=I30fptNifcdw9OOCU50aZnEqsJ2X4VM9YXdtRkxbqGc,7014
+ lalamo/modules/embedding.py,sha256=PdNy4tGt9F1zve4X73WKNS0DXL-nHUFOlZmGFUAarkQ,27727
+ lalamo/modules/linear.py,sha256=4xIhmeouD7R10lt8KJBLxgypVXYhpGmXdHUc-96Upfk,42871
+ lalamo/modules/mlp.py,sha256=ogxi9q8J38FnuBkAtC7_KTMc7JZG4BRdsAHYprHZNvM,17690
+ lalamo/modules/mlx_interop.py,sha256=kgCm6cPvY62ZNY3icuyKY0bow50j73UdyfVym2WqEUk,483
  lalamo/modules/normalization.py,sha256=cBdOq6OmJssunVeEwFRJD0BDhgFAN7J8gOKwzIUAY8I,3005
- lalamo/modules/rope.py,sha256=rCik7vBNqRXYg3LGbmc1mezPRNbIYMg5cydTFpQy-eU,10157
+ lalamo/modules/rope.py,sha256=HbIv5ESLGNAK47HAtqu1whLLUa20Sb28U8kEs6KclZM,10127
  lalamo/modules/torch_interop.py,sha256=-mujd1zI4ec2w92Hd50RtDa0K3jl6ZSnPxc5r3Fp9nU,916
- lalamo/modules/transformer.py,sha256=4olEO8Eh7U6RwSnaECn39ooPuTKUZp_6QmvO6vdirrQ,10532
- lalamo/modules/transformer_layer.py,sha256=ZYmGR2Ej328l7K-YpV4eEiBk8SzLsw1RiuSiUP94UpY,12731
+ lalamo/modules/transformer.py,sha256=9FD2k_5qwDHYUG5_6M0wVI9-YxfMv0mXlHS-QKiKcP4,10319
+ lalamo/modules/transformer_layer.py,sha256=mOqmfVpT7yfHpU87Koso3lvjH5zc-hgPvgVgk04r6ck,12412
  lalamo/modules/utils.py,sha256=t_TayWT6g5LtYKhJaod-u_COWaI_VbNd3eYek9Nj0lc,441
- lalamo/modules/token_mixers/__init__.py,sha256=z6x8cNjis6xIi_2llIoByKqMF2W4xJ05rDnxitHQ3jU,1139
- lalamo/modules/token_mixers/attention.py,sha256=gkGMFah2OHB_tyJpkshM1KhMnzG6U7Xt273MkBvDk58,16584
+ lalamo/modules/token_mixers/__init__.py,sha256=lwxUl0eG5IvuVc_HOsINP2vtbv9F0cUmSNHFHaEmPGk,1109
+ lalamo/modules/token_mixers/attention.py,sha256=ielw1-KWBfCPCPmzSHgM0TaSUcmSkWKTxrN3N_FsGm4,16144
  lalamo/modules/token_mixers/common.py,sha256=CcrbXXvGU27uxGLh5L-G8VDtcOiW5Wpm13uBEOd6lVg,1986
- lalamo/modules/token_mixers/mamba.py,sha256=fo8xvvmIQss2lKLhav19Jzk1-hTykNp2sjcN6ntcWj4,18789
- lalamo/modules/token_mixers/short_conv.py,sha256=93SmoVsuAtdX4ckAkvhHXHiO67pU6soYFpBZxdPFEwc,5219
+ lalamo/modules/token_mixers/mamba.py,sha256=zV5CnhEbAtJ32V32a2VZGsbjZ-sohMqRbR5kW9XH1AI,19087
+ lalamo/modules/token_mixers/short_conv.py,sha256=k1z9UwcJGag2NHWad7cYiAnhxULtmva9RrdhqVbir18,5085
  lalamo/modules/token_mixers/state/__init__.py,sha256=OKWPmiwszMWgwamewoVHd28owanHAO2j2e30Iivtv-4,384
  lalamo/modules/token_mixers/state/common.py,sha256=dcwBevAdeJpBjf7_YRk7TKrJHsCnpljhfzZy-3h9898,661
  lalamo/modules/token_mixers/state/kv_cache.py,sha256=QfnS3XgSmyDI9MBUbeLI4ABHLxiMcXDbZsqe0fd3KQo,8788
@@ -83,13 +83,13 @@ lalamo/modules/token_mixers/state/mamba_state.py,sha256=LHzJvNE6MkB7nrsZSNto6pxb
  lalamo/modules/token_mixers/state/short_conv_state.py,sha256=osjcDHoeFWQaUoOROzeJe8F1qC8rvqunimGD4CuIDHo,895
  lalamo/speculator/__init__.py,sha256=9-tmZcbCom_lIGpJYn6xLlnEahFLFidpqmgkafmu--k,456
  lalamo/speculator/common.py,sha256=PudF_gkpe5_nQ-57sAC-foE1xCy_H2Axh5KwRoA86lo,587
- lalamo/speculator/estimator.py,sha256=j-zmhy3RxYDmQ7W0FMTmDk3i275r_Vg1s4NCaS4c_SQ,2760
+ lalamo/speculator/estimator.py,sha256=S_TRwMnjWg5qt9le2AYua_Vmo6QkIT-0Si7TjCfC7xc,2670
  lalamo/speculator/inference.py,sha256=5GntUgj0HQLeLn3HIHnVX8EEO0EBzmKeP5-_U7kdFAM,3670
- lalamo/speculator/ngram.py,sha256=95mdfAWhx4d5XOnOwhyhElnvcy6nlUjYhcbJzqDs414,5875
+ lalamo/speculator/ngram.py,sha256=Fy3A-oVxZql3gE5M5ot0hKPu0772-kcEPDvD9MkldpA,5889
  lalamo/speculator/utils.py,sha256=0wZoMMIzzk0Q-3zq5H5f-JBplePNHxywndkrNtOJOyo,1697
- lalamo-0.5.17.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
- lalamo-0.5.17.dist-info/METADATA,sha256=16-W1J0wiwrmgMTgqiE9r3vxKRmZbGgZ-zS7bNACwTA,3113
- lalamo-0.5.17.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- lalamo-0.5.17.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
- lalamo-0.5.17.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
- lalamo-0.5.17.dist-info/RECORD,,
+ lalamo-0.6.1.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
+ lalamo-0.6.1.dist-info/METADATA,sha256=eAuWPVMZl52_KExdalios28l6mOQmKgE3EcIUGUKd4k,3112
+ lalamo-0.6.1.dist-info/WHEEL,sha256=qELbo2s1Yzl39ZmrAibXA2jjPLUYfnVhUNTlyF1rq0Y,92
+ lalamo-0.6.1.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
+ lalamo-0.6.1.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
+ lalamo-0.6.1.dist-info/RECORD,,
{lalamo-0.5.17.dist-info → lalamo-0.6.1.dist-info}/WHEEL CHANGED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (80.9.0)
+ Generator: setuptools (80.10.1)
  Root-Is-Purelib: true
  Tag: py3-none-any