lalamo 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. lalamo/__init__.py +1 -1
  2. lalamo/model_import/__init__.py +8 -0
  3. lalamo/model_import/common.py +111 -0
  4. lalamo/model_import/configs/__init__.py +24 -0
  5. lalamo/model_import/configs/common.py +62 -0
  6. lalamo/model_import/configs/executorch.py +166 -0
  7. lalamo/model_import/configs/huggingface/__init__.py +18 -0
  8. lalamo/model_import/configs/huggingface/common.py +72 -0
  9. lalamo/model_import/configs/huggingface/gemma2.py +122 -0
  10. lalamo/model_import/configs/huggingface/gemma3.py +187 -0
  11. lalamo/model_import/configs/huggingface/llama.py +155 -0
  12. lalamo/model_import/configs/huggingface/mistral.py +132 -0
  13. lalamo/model_import/configs/huggingface/qwen2.py +144 -0
  14. lalamo/model_import/configs/huggingface/qwen3.py +142 -0
  15. lalamo/model_import/loaders/__init__.py +7 -0
  16. lalamo/model_import/loaders/common.py +45 -0
  17. lalamo/model_import/loaders/executorch.py +223 -0
  18. lalamo/model_import/loaders/huggingface.py +304 -0
  19. lalamo/model_import/model_specs/__init__.py +38 -0
  20. lalamo/model_import/model_specs/common.py +118 -0
  21. lalamo/model_import/model_specs/deepseek.py +28 -0
  22. lalamo/model_import/model_specs/gemma.py +76 -0
  23. lalamo/model_import/model_specs/huggingface.py +28 -0
  24. lalamo/model_import/model_specs/llama.py +100 -0
  25. lalamo/model_import/model_specs/mistral.py +59 -0
  26. lalamo/model_import/model_specs/pleias.py +28 -0
  27. lalamo/model_import/model_specs/polaris.py +22 -0
  28. lalamo/model_import/model_specs/qwen.py +336 -0
  29. lalamo/model_import/model_specs/reka.py +28 -0
  30. lalamo/modules/__init__.py +85 -0
  31. lalamo/modules/activations.py +30 -0
  32. lalamo/modules/attention.py +326 -0
  33. lalamo/modules/common.py +133 -0
  34. lalamo/modules/decoder.py +244 -0
  35. lalamo/modules/decoder_layer.py +240 -0
  36. lalamo/modules/embedding.py +299 -0
  37. lalamo/modules/kv_cache.py +196 -0
  38. lalamo/modules/linear.py +603 -0
  39. lalamo/modules/mlp.py +79 -0
  40. lalamo/modules/normalization.py +77 -0
  41. lalamo/modules/rope.py +255 -0
  42. lalamo/modules/utils.py +13 -0
  43. {lalamo-0.2.1.dist-info → lalamo-0.2.3.dist-info}/METADATA +1 -1
  44. lalamo-0.2.3.dist-info/RECORD +53 -0
  45. lalamo-0.2.1.dist-info/RECORD +0 -12
  46. {lalamo-0.2.1.dist-info → lalamo-0.2.3.dist-info}/WHEEL +0 -0
  47. {lalamo-0.2.1.dist-info → lalamo-0.2.3.dist-info}/entry_points.txt +0 -0
  48. {lalamo-0.2.1.dist-info → lalamo-0.2.3.dist-info}/licenses/LICENSE +0 -0
  49. {lalamo-0.2.1.dist-info → lalamo-0.2.3.dist-info}/top_level.txt +0 -0
lalamo/modules/normalization.py ADDED
@@ -0,0 +1,77 @@
+ from dataclasses import dataclass
+ from enum import Enum
+
+ import jax
+ from jax import numpy as jnp
+ from jaxtyping import Array, DTypeLike, Float
+
+ from lalamo.common import ParameterDict
+
+ from .common import LalamoModule, WeightLayout
+
+ __all__ = [
+     "RMSNorm",
+     "RMSNormConfig",
+     "UpcastMode",
+ ]
+
+
+ class UpcastMode(Enum):
+     ONLY_NORMALIZATION = "only_normalization"
+     FULL_LAYER = "full_layer"
+
+
+ @dataclass(frozen=True)
+ class RMSNormConfig:
+     scale_precision: DTypeLike
+     accumulation_precision: DTypeLike
+     epsilon: float
+     scale_offset: float | None
+     upcast_mode: UpcastMode
+
+     def init(self, channels: int) -> "RMSNorm":
+         scales = jnp.ones(channels, dtype=self.scale_precision)
+         return RMSNorm(self, scales=scales)
+
+
+ class RMSNorm(LalamoModule[RMSNormConfig]):
+     scales: Float[Array, " channels"]
+
+     @property
+     def activation_precision(self) -> DTypeLike:
+         return self.config.scale_precision
+
+     @property
+     def input_dim(self) -> int:
+         (result,) = self.scales.shape
+         return result
+
+     def __post_init__(self) -> None:
+         if self.config.scale_precision != self.scales.dtype:
+             raise ValueError(
+                 f"Scales precision {self.scales.dtype} does not match the"
+                 f" specified precision {self.config.scale_precision}",
+             )
+
+     def __call__(self, inputs: Float[Array, " channels"]) -> Float[Array, " channels"]:
+         upcasted_inputs = inputs.astype(self.config.accumulation_precision)
+
+         adjusted_variance = jnp.mean(jnp.square(upcasted_inputs)) + self.config.epsilon
+         normalized_x = upcasted_inputs * jax.lax.rsqrt(adjusted_variance)
+
+         if self.config.upcast_mode == UpcastMode.ONLY_NORMALIZATION:
+             normalized_x = normalized_x.astype(inputs.dtype)
+
+         if self.config.upcast_mode == UpcastMode.FULL_LAYER:
+             adjusted_scales = self.scales.astype(self.config.accumulation_precision)
+         else:
+             adjusted_scales = self.scales
+
+         if self.config.scale_offset is not None:
+             adjusted_scales = adjusted_scales + self.config.scale_offset
+
+         result = normalized_x * adjusted_scales
+         return result.astype(inputs.dtype)
+
+     def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:  # noqa: ARG002
+         return ParameterDict(scales=self.scales)
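The new module implements RMSNorm: it computes x / sqrt(mean(x^2) + epsilon) in the configured accumulation precision, then multiplies by learned per-channel scales, with an optional fixed scale offset (e.g. the (1 + weight) convention used by Gemma-style checkpoints). A minimal usage sketch, assuming these names are re-exported from `lalamo.modules` and using illustrative config values rather than package defaults:

```python
# Minimal sketch; import path and config values are assumptions, not package defaults.
from jax import numpy as jnp

from lalamo.modules import RMSNormConfig, UpcastMode  # assumed re-export

config = RMSNormConfig(
    scale_precision=jnp.float32,
    accumulation_precision=jnp.float32,
    epsilon=1e-6,
    scale_offset=None,  # e.g. 1.0 for Gemma-style (1 + scales) normalization
    upcast_mode=UpcastMode.ONLY_NORMALIZATION,
)
norm = config.init(channels=8)   # scales initialized to ones
outputs = norm(jnp.ones(8))      # RMS-normalize one activation vector
```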
lalamo/modules/rope.py ADDED
@@ -0,0 +1,255 @@
+ # Based on https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py
+ # Original PyTorch code copyright notice:
+ #
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ from dataclasses import dataclass
+
+ import equinox as eqx
+ from jax import numpy as jnp
+ from jaxtyping import Array, DTypeLike, Float, Int
+
+ from lalamo.common import ParameterDict
+
+ from .common import LalamoModule, WeightLayout, register_config_union
+
+ __all__ = [
+     "LinearScalingRoPEConfig",
+     "LlamaRoPEConfig",
+     "PositionalEmbeddings",
+     "RoPE",
+     "RoPEConfigBase",
+     "UnscaledRoPEConfig",
+     "YARNRoPEConfig",
+ ]
+
+
+ class PositionalEmbeddings(eqx.Module):
+     cosines: Float[Array, "tokens head_channels"]
+     sines: Float[Array, "tokens head_channels"]
+
+     @property
+     def head_dim(self) -> int:
+         return self.cosines.shape[-1]
+
+     def rotate_half(self, heads: Float[Array, "tokens head_channels"]) -> Float[Array, "tokens head_channels"]:
+         x1 = heads[..., : self.head_dim // 2]
+         x2 = heads[..., self.head_dim // 2 :]
+         return jnp.concatenate((-x2, x1), axis=-1)
+
+     def apply(self, heads: Float[Array, "tokens head_channels"]) -> Float[Array, "tokens head_channels"]:
+         return heads * self.cosines + self.rotate_half(heads) * self.sines
+
+     def export(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:  # noqa: ARG002
+         return ParameterDict(
+             cosines=self.cosines,
+             sines=self.sines,
+         )
+
+
+ @dataclass(frozen=True)
+ class RoPEConfigBase:
+     precision: DTypeLike
+     base: float
+     max_sequence_length: int
+
+     @property
+     def _attention_scaling_factor(self) -> float:
+         return 1.0
+
+     def _scale_inverse_frequencies(
+         self,
+         inverse_frequencies: Float[Array, " tokens"],
+         head_dim: int,  # noqa: ARG002
+         max_sequence_length: int,  # noqa: ARG002
+     ) -> Float[Array, " tokens"]:
+         return inverse_frequencies
+
+     def init(
+         self,
+         head_dim: int,
+         num_timesteps: int,
+     ) -> "RoPE":
+         timesteps = jnp.arange(num_timesteps, dtype=jnp.float32)
+         channel_indices = jnp.arange(0, head_dim, 2, dtype=jnp.int32)
+         inverse_frequencies = 1.0 / (self.base ** (channel_indices.astype(jnp.float32) / head_dim))
+         inverse_frequencies = self._scale_inverse_frequencies(inverse_frequencies, head_dim, self.max_sequence_length)
+         outer_inverse_frequencies = jnp.outer(timesteps, inverse_frequencies)
+         embeddings = jnp.concatenate((outer_inverse_frequencies, outer_inverse_frequencies), axis=-1)
+         cosines = (jnp.cos(embeddings) * self._attention_scaling_factor).astype(self.precision)
+         sines = (jnp.sin(embeddings) * self._attention_scaling_factor).astype(self.precision)
+         return RoPE(config=self, cosines=cosines, sines=sines)
+
+
+ class RoPE(LalamoModule[RoPEConfigBase]):
+     sines: Float[Array, "tokens head_channels"]
+     cosines: Float[Array, "tokens head_channels"]
+
+     @property
+     def activation_precision(self) -> DTypeLike:
+         return self.config.precision
+
+     def __post_init__(self) -> None:
+         if self.cosines.dtype != self.config.precision:
+             raise ValueError(
+                 f"Cosines dtype {self.cosines.dtype} does not match the specified precision {self.config.precision}",
+             )
+         if self.sines.dtype != self.config.precision:
+             raise ValueError(
+                 f"Sines dtype {self.sines.dtype} does not match the specified precision {self.config.precision}",
+             )
+         if self.cosines.shape != self.sines.shape:
+             raise ValueError(
+                 f"Cosines and sines shape mismatch: cosines have shape {self.cosines.shape},"
+                 f" while sines have shape {self.sines.shape}",
+             )
+
+     @property
+     def head_dim(self) -> int:
+         _, result = self.sines.shape
+         return result
+
+     @property
+     def max_sequence_length(self) -> int:
+         result, _ = self.sines.shape
+         return result
+
+     def __call__(self, timesteps: Int[Array, " tokens"]) -> PositionalEmbeddings:
+         return PositionalEmbeddings(
+             cosines=self.cosines[timesteps],
+             sines=self.sines[timesteps],
+         )
+
+     def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:  # noqa: ARG002
+         return ParameterDict(cosines=self.cosines, sines=self.sines)
+
+
+ class UnscaledRoPEConfig(RoPEConfigBase):
+     pass
+
+
+ @dataclass(frozen=True)
+ class LlamaRoPEConfig(RoPEConfigBase):
+     scaling_factor: float
+     original_context_length: int
+     low_frequency_factor: float
+     high_frequency_factor: float
+
+     def _scale_inverse_frequencies(
+         self,
+         inverse_frequencies: Float[Array, " tokens"],
+         head_dim: int,  # noqa: ARG002
+         max_sequence_length: int,  # noqa: ARG002
+     ) -> Float[Array, " tokens"]:
+         low_frequency_wavelength = self.original_context_length / self.low_frequency_factor
+         high_frequency_wavelength = self.original_context_length / self.high_frequency_factor
+
+         wavelengths = 2 * math.pi / inverse_frequencies
+
+         high_frequency_mask = wavelengths < high_frequency_wavelength
+         low_frequency_mask = wavelengths > low_frequency_wavelength
+         mid_frequency_mask = (~high_frequency_mask) & (~low_frequency_mask)
+
+         smoothing_factors = self.original_context_length / wavelengths - self.low_frequency_factor
+         smoothing_factors = smoothing_factors / (self.high_frequency_factor - self.low_frequency_factor)
+
+         scaled_frequencies = inverse_frequencies / self.scaling_factor
+         smoothly_scaled_frequencies = (
+             smoothing_factors * inverse_frequencies + (1 - smoothing_factors) * scaled_frequencies
+         )
+
+         result = inverse_frequencies * high_frequency_mask.astype(jnp.float32)
+         result = result + smoothly_scaled_frequencies * mid_frequency_mask.astype(jnp.float32)
+         result = result + scaled_frequencies * low_frequency_mask.astype(jnp.float32)
+
+         return result
+
+
+ @dataclass(frozen=True)
+ class YARNRoPEConfig(RoPEConfigBase):
+     scaling_factor: float
+     beta_fast: float
+     beta_slow: float
+
+     @classmethod
+     def _find_correction_dim(cls, num_rotations: float, dim: int, base: float, max_position_embeddings: int) -> float:
+         """Inverse dimension formula to find the dimension based on the number of rotations"""
+         return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
+
+     @classmethod
+     def _find_correction_range(
+         cls,
+         low_rot: float,
+         high_rot: float,
+         dim: int,
+         base: float,
+         max_position_embeddings: int,
+     ) -> tuple[int, int]:
+         """Find dimension range bounds based on rotations"""
+         low = math.floor(cls._find_correction_dim(low_rot, dim, base, max_position_embeddings))
+         high = math.ceil(cls._find_correction_dim(high_rot, dim, base, max_position_embeddings))
+         return max(low, 0), min(high, dim - 1)
+
+     @classmethod
+     def _linear_ramp_factor(cls, min_value: float, max_value: float, dim: int) -> Float[Array, " head_dim"]:
+         if min_value == max_value:
+             max_value += 0.001  # Prevent singularity
+
+         linear_func = (jnp.arange(dim, dtype=jnp.float32) - min_value) / (max_value - min_value)
+         ramp_func = jnp.clip(linear_func, 0, 1)
+         return ramp_func
+
+     def _scale_inverse_frequencies(
+         self,
+         inverse_frequencies: Float[Array, " tokens"],
+         head_dim: int,
+         max_sequence_length: int,
+     ) -> Float[Array, " tokens"]:
+         scaled_frequencies = inverse_frequencies / self.scaling_factor
+
+         low, high = self._find_correction_range(
+             self.beta_fast,
+             self.beta_slow,
+             head_dim,
+             self.base,
+             max_sequence_length,
+         )
+
+         # Get n-dimensional rotational scaling corrected for extrapolation
+         smoothing_factor = 1 - self._linear_ramp_factor(low, high, head_dim // 2)
+         return scaled_frequencies * (1 - smoothing_factor) + inverse_frequencies * smoothing_factor
+
+     @property
+     def attention_scaling_factor(self) -> float:
+         return 0.1 * math.log(self.scaling_factor) + 1.0
+
+
+ @dataclass(frozen=True)
+ class LinearScalingRoPEConfig(RoPEConfigBase):
+     scaling_factor: float
+
+     def _scale_inverse_frequencies(
+         self,
+         inverse_frequencies: Float[Array, " tokens"],
+         head_dim: int,  # noqa: ARG002
+         max_sequence_length: int,  # noqa: ARG002
+     ) -> Float[Array, " tokens"]:
+         return inverse_frequencies / self.scaling_factor
+
+
+ RoPEConfig = UnscaledRoPEConfig | LlamaRoPEConfig | YARNRoPEConfig | LinearScalingRoPEConfig
+
+ register_config_union(RoPEConfig)
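rope.py precomputes the cosine/sine tables once at `init` (optionally rescaling the inverse frequencies for Llama-3-style, YaRN, or linear context scaling) and then simply indexes them with token positions; `PositionalEmbeddings.apply` performs the standard rotate-half rotation. A minimal usage sketch, assuming `UnscaledRoPEConfig` is re-exported from `lalamo.modules` and using illustrative sizes:

```python
# Minimal sketch; import path and sizes are assumptions for illustration.
from jax import numpy as jnp

from lalamo.modules import UnscaledRoPEConfig  # assumed re-export

config = UnscaledRoPEConfig(
    precision=jnp.float32,
    base=10000.0,
    max_sequence_length=2048,
)
rope = config.init(head_dim=64, num_timesteps=2048)  # precompute cos/sin tables

embeddings = rope(jnp.arange(16))   # look up tables for the first 16 positions
heads = jnp.ones((16, 64))          # (tokens, head_channels)
rotated = embeddings.apply(heads)   # heads * cos + rotate_half(heads) * sin
```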
lalamo/modules/utils.py ADDED
@@ -0,0 +1,13 @@
+ import jax
+ from jaxtyping import Array, Float
+
+ __all__ = [
+     "apply_soft_capping",
+ ]
+
+
+ def apply_soft_capping(
+     values: Float[Array, "*"],
+     soft_cap: float,
+ ) -> Float[Array, "dst_tokens src_tokens"]:
+     return jax.nn.tanh(values / soft_cap) * soft_cap
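`apply_soft_capping` bounds values to the open interval (-soft_cap, soft_cap) via soft_cap * tanh(values / soft_cap); for |values| much smaller than the cap it is close to the identity. Gemma-2-style models apply this kind of capping to attention scores and final logits. An illustrative call, with an assumed cap value:

```python
# Minimal sketch; the cap value is illustrative.
from jax import numpy as jnp

from lalamo.modules.utils import apply_soft_capping

logits = jnp.array([-100.0, -1.0, 0.0, 1.0, 100.0])
capped = apply_soft_capping(logits, soft_cap=30.0)  # bounded to (-30, 30)
```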
{lalamo-0.2.1.dist-info → lalamo-0.2.3.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: lalamo
- Version: 0.2.1
+ Version: 0.2.3
  Summary: JAX library for optimization and export of models for use with the UZU inference engine.
  Requires-Python: <4,>=3.12
  Description-Content-Type: text/markdown
lalamo-0.2.3.dist-info/RECORD ADDED
@@ -0,0 +1,53 @@
+ lalamo/__init__.py,sha256=9K_9yBY3GmAYmuIxMuTCbWQxjFfcKrdIk27drbDFjuo,217
+ lalamo/common.py,sha256=uYLw68V4AF3zlENG3KAIKRpOFXVHv8xX_n0cc3qJnj4,1877
+ lalamo/language_model.py,sha256=GiA_BDQuYCgVBFHljb_ltW_M7g3I1Siwm111M3Jc8MM,9286
+ lalamo/main.py,sha256=K2RLyTcxvBCP0teSsminssj_oUkuQAQ5y9ixa1uOqas,9546
+ lalamo/quantization.py,sha256=8o6ryIZLzzDYQuvBTboPfaVVdfijAKGpTxOcg3GKVD8,2752
+ lalamo/utils.py,sha256=QzkT0_82nd9pS5p0e7yOOdL_ZeKQr_Ftj4kFrWF35R8,1754
+ lalamo/model_import/__init__.py,sha256=Z8pS9rbKKx1QgUy7KZtHxiNWlZhII3mdovT9d37vAxg,168
+ lalamo/model_import/common.py,sha256=sHXEGQUtVb6TRT5FOGtJG9pz1Ohy5v_LtunubVxZKqQ,3303
+ lalamo/model_import/configs/__init__.py,sha256=JYXeco_kfzKZuWqEmG24qxeYWs-FuE1W1kNgoFNrBEw,461
+ lalamo/model_import/configs/common.py,sha256=MKAinEL7WXkijS3IrfiTRgx2l6otpnIaJG_CajosMCU,1803
+ lalamo/model_import/configs/executorch.py,sha256=Kx_T-B5jumfWf9vj20We4FF0GkSkTmIYeWOss88-qYA,5266
+ lalamo/model_import/configs/huggingface/__init__.py,sha256=kWHUnZDwGQCbA3Ucm-FEDr8zZ2yZ3yviPVftlNgMk30,460
+ lalamo/model_import/configs/huggingface/common.py,sha256=p6oEKIT2Ezh_d8eDXYzHaJaqjPriQrAzz2bkEq_HkgY,1698
+ lalamo/model_import/configs/huggingface/gemma2.py,sha256=oIefI_ad-7DtzXmisFczkKPuOQ-KkzMkKWTk9likaMs,4101
+ lalamo/model_import/configs/huggingface/gemma3.py,sha256=1tkkmEs4pF0t0XlDS2Z5mWzcPGrRWb7FcLgazrFDJy8,6434
+ lalamo/model_import/configs/huggingface/llama.py,sha256=_vOalgc24uhMcPyCqyxWOZk80hXxqN-dhMHvBGtbIlc,5444
+ lalamo/model_import/configs/huggingface/mistral.py,sha256=39qsX_Twml8C0xz0CayVZse2uaHJtKS9-54B8nQw_5k,4148
+ lalamo/model_import/configs/huggingface/qwen2.py,sha256=GnO1_DKDewiB4AW8lJu_x30lL-GgB9GYc64rl6XqfYI,4963
+ lalamo/model_import/configs/huggingface/qwen3.py,sha256=UJ-EP0geHmGXnT_Ioy7Z7V4vns_dKz2YpPe-GLPQg20,5029
+ lalamo/model_import/loaders/__init__.py,sha256=Olg7a79phusilNgEa7PTgx1JgQQJLgAVg18T8isp0mw,148
+ lalamo/model_import/loaders/common.py,sha256=2FigeDMUwlMPUebX8DAK2Yh9aLgVtsfTj0S431p7A0o,1782
+ lalamo/model_import/loaders/executorch.py,sha256=nSvpylK8QL3nBk78P3FabLoyA87E3kv5CCpMfvuZe6Q,8886
+ lalamo/model_import/loaders/huggingface.py,sha256=Ze_qB0fSxY8lH4ovH0t8jd5jiteasUWkS9HdgMZXCrs,10523
+ lalamo/model_import/model_specs/__init__.py,sha256=_sJthAH1xXl5B9JPhRqMVP2t5KkhzqmKFHSRlOiFg8s,915
+ lalamo/model_import/model_specs/common.py,sha256=ygfNjwVZBrjNkCVuv66R1vy5hXjgbAJyDc0QJfRfgik,3789
+ lalamo/model_import/model_specs/deepseek.py,sha256=9l3pVyC-ZoIaFG4xWhPDCbKkD2TsND286o0KzO0uxKo,788
+ lalamo/model_import/model_specs/gemma.py,sha256=y4aDeaGGl4JPIanAgPMOlyfD_cx3Q7rpTKgDgx5AsX0,2299
+ lalamo/model_import/model_specs/huggingface.py,sha256=ktDJ_qZxSGmHREydrYQaWi71bXJZiHqzHDoZeORENno,784
+ lalamo/model_import/model_specs/llama.py,sha256=7eXfMwj_VZpeHAuXmPk1jcA_X7iXsJ8AWf6pk_Qy7rg,3226
+ lalamo/model_import/model_specs/mistral.py,sha256=xDX2SyTruGR7A8LI_Ypa6qAP5nVyYhxLffoxS2F6bmI,1649
+ lalamo/model_import/model_specs/pleias.py,sha256=zLRjmT6PXFtykqSYpaRtVObP306urMjF2J6dTKdAbQM,747
+ lalamo/model_import/model_specs/polaris.py,sha256=TiGlXI3j7HP9bs01jdcysBNFxvNKnxTF30wuv5Jg2mQ,768
+ lalamo/model_import/model_specs/qwen.py,sha256=dsCo3uaSPtPPjGuWHerFUY27f5Pv_HFOao2lPjFFHJI,11302
+ lalamo/model_import/model_specs/reka.py,sha256=YtAuM52ImgH24lVuICXDcS39mNNzG_b-ouoAy5uVYLk,766
+ lalamo/modules/__init__.py,sha256=iNzQL_qIG10U157bWmblj9fZNewup0O0aB8IsMXuBPU,2164
+ lalamo/modules/activations.py,sha256=ZgUd3E4VTAVgCZaj9HhYkXiJuiKrWBzK6so5JGnucOc,532
+ lalamo/modules/attention.py,sha256=ZvhQPBpsgZc-RDM1EQZgG6a8CZGYvMthE7g3qDedUVM,12158
+ lalamo/modules/common.py,sha256=6FOmvICxVJTxLll785WY7KVY7ixAvPuzaW_J1trYNj0,3171
+ lalamo/modules/decoder.py,sha256=Erc8k_tjrCCZOMBJM1bL0b96zFjaKDGIuFmtEp-OJVk,9043
+ lalamo/modules/decoder_layer.py,sha256=gVVE48hlkNTvWJIt5oKak7VhEIrYdCaErf7_6_A_9ys,9443
+ lalamo/modules/embedding.py,sha256=6xnNFi_TrB0ymSAmCLwmwjAZW0pchgzjQDjDws22PPw,10684
+ lalamo/modules/kv_cache.py,sha256=GLZ84VTl2QtJYdnADR3j0e4CUmnSAlSvnCb63E8AtIk,7225
+ lalamo/modules/linear.py,sha256=loUGFu3wx-iGqDqGMphQorhqBm7b9lAqT4B0jAmoamk,24087
+ lalamo/modules/mlp.py,sha256=bV8qJTjsQFGv-CA7d32UQFn6BX5zmCKWC5pgm29-W3U,2631
+ lalamo/modules/normalization.py,sha256=BWCHv6ycFJ_qMGfxkusGfay9dWzUlbpuwmjbLy2rI68,2380
+ lalamo/modules/rope.py,sha256=Vdt2J_W0MPDK52nHsroLVCfWMHyHW3AfrKZCZAE4VYs,9369
+ lalamo/modules/utils.py,sha256=5QTdi34kEI5jix7TfTdB0mOYZbzZUul_T1y8eWCA6lQ,262
+ lalamo-0.2.3.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
+ lalamo-0.2.3.dist-info/METADATA,sha256=t6eIuMJLWk08EVESbkb_QfG2uvQxlokJ98lKwQckU6U,2611
+ lalamo-0.2.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ lalamo-0.2.3.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
+ lalamo-0.2.3.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
+ lalamo-0.2.3.dist-info/RECORD,,
@@ -1,12 +0,0 @@
1
- lalamo/__init__.py,sha256=uKBR6vAH2AmdpPqz1q2zVVwQyCpWRWUHAfm-uQg8DAM,217
2
- lalamo/common.py,sha256=uYLw68V4AF3zlENG3KAIKRpOFXVHv8xX_n0cc3qJnj4,1877
3
- lalamo/language_model.py,sha256=GiA_BDQuYCgVBFHljb_ltW_M7g3I1Siwm111M3Jc8MM,9286
4
- lalamo/main.py,sha256=K2RLyTcxvBCP0teSsminssj_oUkuQAQ5y9ixa1uOqas,9546
5
- lalamo/quantization.py,sha256=8o6ryIZLzzDYQuvBTboPfaVVdfijAKGpTxOcg3GKVD8,2752
6
- lalamo/utils.py,sha256=QzkT0_82nd9pS5p0e7yOOdL_ZeKQr_Ftj4kFrWF35R8,1754
7
- lalamo-0.2.1.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
8
- lalamo-0.2.1.dist-info/METADATA,sha256=1qDWPQiCYK_EIeff-oiaF7VeIksGNdZ4nCFikHXGJR4,2611
9
- lalamo-0.2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
10
- lalamo-0.2.1.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
11
- lalamo-0.2.1.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
12
- lalamo-0.2.1.dist-info/RECORD,,