lalamo 0.2.3__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lalamo/__init__.py CHANGED
@@ -1,7 +1,7 @@
1
1
  from lalamo.model_import import REPO_TO_MODEL, ModelSpec, import_model
2
2
  from lalamo.modules import Decoder
3
3
 
4
- __version__ = "0.2.3"
4
+ __version__ = "0.2.4"
5
5
 
6
6
  __all__ = [
7
7
  "REPO_TO_MODEL",
@@ -3,7 +3,6 @@ from enum import Enum
3
3
  from pathlib import Path
4
4
 
5
5
  import jax.numpy as jnp
6
- import torch
7
6
  from jaxtyping import Array, DTypeLike
8
7
  from safetensors.flax import load_file as load_safetensors
9
8
 
@@ -17,9 +16,9 @@ __all__ = [
17
16
  "ModelSpec",
18
17
  "TokenizerFileSpec",
19
18
  "UseCase",
20
- "huggingface_weight_files",
21
19
  "awq_model_spec",
22
20
  "build_quantized_models",
21
+ "huggingface_weight_files",
23
22
  ]
24
23
 
25
24
 
@@ -36,6 +35,9 @@ class WeightsType(Enum):
36
35
  def load(self, filename: Path | str, float_dtype: DTypeLike) -> dict[str, jnp.ndarray]:
37
36
  if self == WeightsType.SAFETENSORS:
38
37
  return {k: cast_if_float(v, float_dtype) for k, v in load_safetensors(filename).items()}
38
+
39
+ import torch
40
+
39
41
  torch_weights = torch.load(filename, map_location="cpu", weights_only=True)
40
42
  return {k: cast_if_float(torch_to_jax(v), float_dtype) for k, v in torch_weights.items()}
41
43
 
@@ -72,11 +74,15 @@ def huggingface_weight_files(num_shards: int) -> tuple[str, ...]:
72
74
  return tuple(f"model-{i:05d}-of-{num_shards:05d}.safetensors" for i in range(1, num_shards + 1))
73
75
 
74
76
 
75
- def awq_model_spec(model_spec: ModelSpec, repo: str, quantization: QuantizationMode = QuantizationMode.UINT4) -> ModelSpec:
77
+ def awq_model_spec(
78
+ model_spec: ModelSpec,
79
+ repo: str,
80
+ quantization: QuantizationMode = QuantizationMode.UINT4,
81
+ ) -> ModelSpec:
76
82
  return ModelSpec(
77
83
  vendor=model_spec.vendor,
78
84
  family=model_spec.family,
79
- name="{}-AWQ".format(model_spec.name),
85
+ name=f"{model_spec.name}-AWQ",
80
86
  size=model_spec.size,
81
87
  quantization=quantization,
82
88
  repo=repo,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lalamo
3
- Version: 0.2.3
3
+ Version: 0.2.4
4
4
  Summary: JAX library for optimization and export of models for use with the UZU inference engine.
5
5
  Requires-Python: <4,>=3.12
6
6
  Description-Content-Type: text/markdown
@@ -1,4 +1,4 @@
1
- lalamo/__init__.py,sha256=9K_9yBY3GmAYmuIxMuTCbWQxjFfcKrdIk27drbDFjuo,217
1
+ lalamo/__init__.py,sha256=mEbbq3bHm0JhMKHAv7egBwAjSSpWpmxYeTeo9df4X8o,217
2
2
  lalamo/common.py,sha256=uYLw68V4AF3zlENG3KAIKRpOFXVHv8xX_n0cc3qJnj4,1877
3
3
  lalamo/language_model.py,sha256=GiA_BDQuYCgVBFHljb_ltW_M7g3I1Siwm111M3Jc8MM,9286
4
4
  lalamo/main.py,sha256=K2RLyTcxvBCP0teSsminssj_oUkuQAQ5y9ixa1uOqas,9546
@@ -22,7 +22,7 @@ lalamo/model_import/loaders/common.py,sha256=2FigeDMUwlMPUebX8DAK2Yh9aLgVtsfTj0S
22
22
  lalamo/model_import/loaders/executorch.py,sha256=nSvpylK8QL3nBk78P3FabLoyA87E3kv5CCpMfvuZe6Q,8886
23
23
  lalamo/model_import/loaders/huggingface.py,sha256=Ze_qB0fSxY8lH4ovH0t8jd5jiteasUWkS9HdgMZXCrs,10523
24
24
  lalamo/model_import/model_specs/__init__.py,sha256=_sJthAH1xXl5B9JPhRqMVP2t5KkhzqmKFHSRlOiFg8s,915
25
- lalamo/model_import/model_specs/common.py,sha256=ygfNjwVZBrjNkCVuv66R1vy5hXjgbAJyDc0QJfRfgik,3789
25
+ lalamo/model_import/model_specs/common.py,sha256=Ob3yTMDczKUHMWBH0PaClbSvHJhKfZ-zbv2Z04YqMVg,3806
26
26
  lalamo/model_import/model_specs/deepseek.py,sha256=9l3pVyC-ZoIaFG4xWhPDCbKkD2TsND286o0KzO0uxKo,788
27
27
  lalamo/model_import/model_specs/gemma.py,sha256=y4aDeaGGl4JPIanAgPMOlyfD_cx3Q7rpTKgDgx5AsX0,2299
28
28
  lalamo/model_import/model_specs/huggingface.py,sha256=ktDJ_qZxSGmHREydrYQaWi71bXJZiHqzHDoZeORENno,784
@@ -45,9 +45,9 @@ lalamo/modules/mlp.py,sha256=bV8qJTjsQFGv-CA7d32UQFn6BX5zmCKWC5pgm29-W3U,2631
45
45
  lalamo/modules/normalization.py,sha256=BWCHv6ycFJ_qMGfxkusGfay9dWzUlbpuwmjbLy2rI68,2380
46
46
  lalamo/modules/rope.py,sha256=Vdt2J_W0MPDK52nHsroLVCfWMHyHW3AfrKZCZAE4VYs,9369
47
47
  lalamo/modules/utils.py,sha256=5QTdi34kEI5jix7TfTdB0mOYZbzZUul_T1y8eWCA6lQ,262
48
- lalamo-0.2.3.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
49
- lalamo-0.2.3.dist-info/METADATA,sha256=t6eIuMJLWk08EVESbkb_QfG2uvQxlokJ98lKwQckU6U,2611
50
- lalamo-0.2.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
51
- lalamo-0.2.3.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
52
- lalamo-0.2.3.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
53
- lalamo-0.2.3.dist-info/RECORD,,
48
+ lalamo-0.2.4.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
49
+ lalamo-0.2.4.dist-info/METADATA,sha256=mTCoEZB9eNgl86j-CSoT8YmFboXd5SUm4IC0YLgxBuk,2611
50
+ lalamo-0.2.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
51
+ lalamo-0.2.4.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
52
+ lalamo-0.2.4.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
53
+ lalamo-0.2.4.dist-info/RECORD,,
File without changes