lalamo 0.6.4.tar.gz → 0.6.6.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. {lalamo-0.6.4 → lalamo-0.6.6}/PKG-INFO +1 -1
  2. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/__init__.py +1 -1
  3. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/commands.py +247 -14
  4. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/common.py +33 -0
  5. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/data/__init__.py +3 -2
  6. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/data/huggingface_message.py +4 -5
  7. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/main.py +274 -9
  8. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/message_processor.py +19 -1
  9. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/common.py +17 -1
  10. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/model_specs/mistral.py +5 -0
  11. lalamo-0.6.6/lalamo/model_import/remote_registry.py +44 -0
  12. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/models/__init__.py +3 -0
  13. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/models/common.py +22 -0
  14. lalamo-0.6.6/lalamo/models/compile_helpers.py +58 -0
  15. lalamo-0.6.6/lalamo/models/language_model.py +638 -0
  16. lalamo-0.6.6/lalamo/models/lm_helpers.py +198 -0
  17. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/modules/decoder.py +4 -0
  18. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/modules/token_mixers/mamba.py +345 -105
  19. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/speculator/__init__.py +0 -2
  20. lalamo-0.6.6/lalamo/speculator/inference.py +75 -0
  21. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo.egg-info/PKG-INFO +1 -1
  22. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo.egg-info/SOURCES.txt +3 -1
  23. {lalamo-0.6.4 → lalamo-0.6.6}/pyproject.toml +3 -3
  24. lalamo-0.6.4/lalamo/models/language_model.py +0 -352
  25. lalamo-0.6.4/lalamo/speculator/estimator.py +0 -127
  26. lalamo-0.6.4/lalamo/speculator/inference.py +0 -101
  27. {lalamo-0.6.4 → lalamo-0.6.6}/LICENSE +0 -0
  28. {lalamo-0.6.4 → lalamo-0.6.6}/README.md +0 -0
  29. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/data/lalamo_completions.py +0 -0
  30. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/data/utils.py +0 -0
  31. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/__init__.py +0 -0
  32. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/decoder_configs/__init__.py +0 -0
  33. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/decoder_configs/common.py +0 -0
  34. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/decoder_configs/executorch.py +0 -0
  35. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/decoder_configs/huggingface/__init__.py +0 -0
  36. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/decoder_configs/huggingface/common.py +0 -0
  37. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/decoder_configs/huggingface/gemma2.py +0 -0
  38. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/decoder_configs/huggingface/gemma3.py +0 -0
  39. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +0 -0
  40. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/decoder_configs/huggingface/lfm2.py +0 -0
  41. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/decoder_configs/huggingface/llama.py +0 -0
  42. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/decoder_configs/huggingface/llamba.py +0 -0
  43. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/decoder_configs/huggingface/mistral.py +0 -0
  44. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/decoder_configs/huggingface/modern_bert.py +0 -0
  45. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/decoder_configs/huggingface/qwen2.py +0 -0
  46. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/decoder_configs/huggingface/qwen3.py +0 -0
  47. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/huggingface_generation_config.py +0 -0
  48. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/huggingface_tokenizer_config.py +0 -0
  49. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/loaders/__init__.py +0 -0
  50. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/loaders/common.py +0 -0
  51. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/loaders/executorch.py +0 -0
  52. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/loaders/huggingface.py +0 -0
  53. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/loaders/utils.py +0 -0
  54. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/model_specs/__init__.py +0 -0
  55. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/model_specs/common.py +0 -0
  56. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/model_specs/deepseek.py +0 -0
  57. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/model_specs/essential_ai.py +0 -0
  58. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/model_specs/gemma.py +0 -0
  59. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/model_specs/gpt_oss.py +0 -0
  60. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/model_specs/huggingface.py +0 -0
  61. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/model_specs/lfm2.py +0 -0
  62. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/model_specs/llama.py +0 -0
  63. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/model_specs/llamba.py +0 -0
  64. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/model_specs/mirai.py +0 -0
  65. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/model_specs/pleias.py +0 -0
  66. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/model_specs/polaris.py +0 -0
  67. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/model_specs/qwen.py +0 -0
  68. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/model_import/model_specs/reka.py +0 -0
  69. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/models/classifier.py +0 -0
  70. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/modules/__init__.py +0 -0
  71. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/modules/activations.py +0 -0
  72. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/modules/classifier.py +0 -0
  73. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/modules/common.py +0 -0
  74. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/modules/embedding.py +0 -0
  75. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/modules/linear.py +0 -0
  76. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/modules/mlp.py +0 -0
  77. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/modules/mlx_interop.py +0 -0
  78. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/modules/normalization.py +0 -0
  79. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/modules/rope.py +0 -0
  80. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/modules/token_mixers/__init__.py +0 -0
  81. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/modules/token_mixers/attention.py +0 -0
  82. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/modules/token_mixers/common.py +0 -0
  83. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/modules/token_mixers/short_conv.py +0 -0
  84. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/modules/token_mixers/state/__init__.py +0 -0
  85. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/modules/token_mixers/state/common.py +0 -0
  86. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/modules/token_mixers/state/kv_cache.py +0 -0
  87. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/modules/token_mixers/state/mamba_state.py +0 -0
  88. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/modules/token_mixers/state/short_conv_state.py +0 -0
  89. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/modules/torch_interop.py +0 -0
  90. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/modules/transformer.py +0 -0
  91. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/modules/transformer_layer.py +0 -0
  92. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/modules/utils.py +0 -0
  93. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/quantization.py +0 -0
  94. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/registry_abc.py +0 -0
  95. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/safetensors.py +0 -0
  96. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/sampling.py +0 -0
  97. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/speculator/common.py +0 -0
  98. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/speculator/ngram.py +0 -0
  99. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/speculator/utils.py +0 -0
  100. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo/utils.py +0 -0
  101. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo.egg-info/dependency_links.txt +0 -0
  102. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo.egg-info/entry_points.txt +0 -0
  103. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo.egg-info/requires.txt +0 -0
  104. {lalamo-0.6.4 → lalamo-0.6.6}/lalamo.egg-info/top_level.txt +0 -0
  105. {lalamo-0.6.4 → lalamo-0.6.6}/setup.cfg +0 -0
--- lalamo-0.6.4/PKG-INFO
+++ lalamo-0.6.6/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lalamo
-Version: 0.6.4
+Version: 0.6.6
 Summary: JAX library for optimization and export of models for use with the UZU inference engine.
 Requires-Python: <4,>=3.12
 Description-Content-Type: text/markdown
--- lalamo-0.6.4/lalamo/__init__.py
+++ lalamo-0.6.6/lalamo/__init__.py
@@ -32,7 +32,7 @@ from lalamo.speculator import (
     SpeculatorTrainingEvent,
 )
 
-__version__ = "0.6.4"
+__version__ = "0.6.6"
 
 __all__ = [
     "AssistantMessage",
--- lalamo-0.6.4/lalamo/commands.py
+++ lalamo-0.6.6/lalamo/commands.py
@@ -1,16 +1,22 @@
 import json
+import shutil
+import tempfile
 from collections.abc import Callable, Iterable
 from dataclasses import dataclass
 from enum import Enum
 from itertools import chain
 from pathlib import Path
 
+import polars as pl
+import requests
+import thefuzz.process
 from jaxtyping import DTypeLike
 
-from lalamo.common import flatten_parameters
-from lalamo.data import import_hf_parquet
+from lalamo.common import flatten_parameters, get_default_device_bytes
+from lalamo.data import load_hf_parquet, shuffle_dataset
+from lalamo.data.huggingface_message import HFMessage
 from lalamo.data.lalamo_completions import LalamoCompletion
-from lalamo.message_processor import Message
+from lalamo.message_processor import AssistantMessage, Message
 from lalamo.model_import import ModelMetadata, ModelSpec, import_model
 from lalamo.model_import.common import (
     DownloadingFileEvent,
@@ -20,15 +26,107 @@ from lalamo.model_import.common import (
     InitializingModelEvent,
     StatusEvent,
 )
+from lalamo.model_import.remote_registry import RegistryModel, RegistryModelFile
 from lalamo.models import LanguageModelConfig
+from lalamo.models.common import BatchSizesComputedEvent, InferenceConfig
+from lalamo.models.lm_helpers import estimate_batchsize_from_bytes
 from lalamo.modules import config_converter
 from lalamo.safetensors import safe_write
-from lalamo.speculator.estimator import EstimateBatchsizeFromMemoryEvent, estimate_batchsize_from_memory
 from lalamo.speculator.inference import CollectTracesEvent, inference_collect_traces
 from lalamo.speculator.ngram import NGramSpeculator
 from lalamo.speculator.utils import SpeculatorTrainingEvent, train_speculator
 
 
+@dataclass
+class PullCallbacks:
+    model_spec: RegistryModel
+    output_dir: Path
+    overwrite: bool
+
+    def started(self) -> None:
+        pass
+
+    def output_dir_exists(self) -> None:
+        raise RuntimeError(f"{self.output_dir=} already exists, refusing to overwrite!")
+
+    def downloading(self, file_spec: RegistryModelFile) -> None:
+        pass
+
+    def finished_downloading(self, file_spec: RegistryModelFile) -> None:
+        pass
+
+    def finished(self) -> None:
+        pass
+
+
+def _download_file(url: str, dest_path: Path) -> None:
+    response = requests.get(url, stream=True, timeout=60)
+    response.raise_for_status()
+
+    with open(dest_path, "wb") as f:
+        for chunk in response.iter_content(chunk_size=8192):
+            if chunk:
+                f.write(chunk)
+
+
+def _suggest_similar_models(query: str, available_models: list[RegistryModel], limit: int = 3) -> list[str]:
+    repo_ids = [m.repo_id for m in available_models]
+    matches = thefuzz.process.extract(query, repo_ids, limit=limit)
+    return [match[0] for match in matches if match[1] >= 50]
+
+
+def pull(
+    model_spec: RegistryModel,
+    output_dir: Path,
+    callbacks_type: Callable[
+        [
+            RegistryModel,
+            Path,
+            bool,
+        ],
+        PullCallbacks,
+    ] = PullCallbacks,
+    overwrite: bool = False,
+) -> None:
+    callbacks = callbacks_type(model_spec, output_dir, overwrite)
+
+    if output_dir.exists():
+        callbacks.output_dir_exists()
+
+    callbacks.started()
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        temp_path = Path(temp_dir)
+
+        for file_spec in model_spec.files:
+            callbacks.downloading(file_spec)
+
+            # Security: validate filename to prevent path traversal attacks
+            safe_name = Path(file_spec.name).name
+            if not safe_name or safe_name != file_spec.name:
+                raise RuntimeError(
+                    f"Invalid filename from registry: {file_spec.name!r}. "
+                    f"Filenames must not contain path separators or traversal sequences.",
+                )
+
+            file_path = temp_path / safe_name
+            try:
+                _download_file(file_spec.url, file_path)
+            except requests.RequestException as e:
+                raise RuntimeError(f"Failed to download {safe_name}: {e}") from e
+
+            callbacks.finished_downloading(file_spec)
+
+        output_dir.mkdir(parents=True, exist_ok=True)
+        for file_spec in model_spec.files:
+            safe_name = Path(file_spec.name).name
+            src = temp_path / safe_name
+            dst = output_dir / safe_name
+            shutil.move(str(src), str(dst))
+
+    callbacks.finished()
+
+
 class Precision(Enum):
     FLOAT32 = "float32"
     FLOAT16 = "float16"
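
For illustration, a minimal usage sketch of the new pull() entry point. The RegistryModel and RegistryModelFile constructor arguments are assumptions inferred from the attributes the code above reads (repo_id, files, name, url); the real definitions live in the new lalamo/model_import/remote_registry.py.

from pathlib import Path

from lalamo.commands import pull
from lalamo.model_import.remote_registry import RegistryModel, RegistryModelFile

# Hypothetical registry entry; field names mirror the attribute accesses in pull().
model = RegistryModel(
    repo_id="example-org/example-model",
    files=[
        RegistryModelFile(name="config.json", url="https://example.com/config.json"),
        RegistryModelFile(name="model.safetensors", url="https://example.com/model.safetensors"),
    ],
)

# Files land in a temporary directory first and are only moved into
# output_dir after every download has succeeded.
pull(model, Path("models/example-model"))
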
@@ -244,16 +342,19 @@ def estimate_batchsize(
     model = LanguageModelConfig.load_model(model_path)
     callbacks.finished_loading_model()
 
-    def progress_callback(event: EstimateBatchsizeFromMemoryEvent) -> None:
-        callbacks.estimating_batchsize(event.lo, event.hi)
-
-    bs = estimate_batchsize_from_memory(
-        model,
-        max_input_length,
-        max_output_length,
-        num_logits_per_token,
+    def memory_per_batchsize(batch_size: int) -> int:
+        inference_config = InferenceConfig(
+            max_output_length=max_output_length,
+            padded_length=max_input_length,
+            num_top_logits_to_return=num_logits_per_token,
+            batch_size=batch_size,
+        )
+        return model.estimate_memory_consumption(inference_config=inference_config)
+
+    bs = estimate_batchsize_from_bytes(
+        memory_per_batchsize,
         mem,
-        progress_callback,
+        lambda event: callbacks.estimating_batchsize(event.lo, event.hi),
     )
 
     callbacks.finished_estimating_batchsize(bs)
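
estimate_batchsize_from_bytes itself (new in lalamo/models/lm_helpers.py) is not shown in this diff. A plausible sketch of the search it performs, assuming a monotone memory estimate: exponential bracketing followed by bisection. SearchEvent is a hypothetical stand-in for the event whose lo/hi fields the callback reads.

from collections.abc import Callable
from dataclasses import dataclass


@dataclass
class SearchEvent:  # hypothetical stand-in for the real event type
    lo: int
    hi: int | None


def estimate_batchsize_sketch(
    memory_per_batchsize: Callable[[int], int],
    budget_bytes: int,
    progress: Callable[[SearchEvent], None],
) -> int:
    # Grow lo exponentially until doubling it would exceed the budget.
    lo = 1
    while memory_per_batchsize(lo * 2) <= budget_bytes:
        lo *= 2
        progress(SearchEvent(lo, None))
    hi = lo * 2
    # Bisect: lo is assumed to fit, hi does not.
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if memory_per_batchsize(mid) <= budget_bytes:
            lo = mid
        else:
            hi = mid
        progress(SearchEvent(lo, hi))
    return lo
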
@@ -329,7 +430,11 @@ def collect_traces(
     callbacks.finished_loading_model()
 
     callbacks.loading_dataset()
-    dataset = iter(import_hf_parquet(dataset_path))
+    dataframe = shuffle_dataset(load_hf_parquet(dataset_path))
+    conversations = dataframe.get_column("conversation")
+    dataset = iter(
+        [HFMessage.from_dict(message).as_message() for message in conversation] for conversation in conversations
+    )
     dataset = chain([next(dataset)], dataset)  # iterator is lazy, force it to actually open the file
     callbacks.finished_loading_dataset()
 
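
The chain([next(dataset)], dataset) line above is the standard idiom for forcing the first element out of a lazy generator (so file-opening errors surface immediately) without losing that element. A self-contained illustration:

from itertools import chain


def rows():
    # Stand-in for a generator that only touches the file once iterated.
    yield from ["row-1", "row-2", "row-3"]


dataset = iter(rows())
# Pull the first element eagerly, then stitch it back in front so
# consumers still see the complete stream.
dataset = chain([next(dataset)], dataset)
assert list(dataset) == ["row-1", "row-2", "row-3"]
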
@@ -427,3 +532,131 @@ def train(
     with open(output_path, "wb") as fd:
         fd.write(speculator.serialize())
     callbacks.finished_saving_speculator()
+
+
+@dataclass
+class GenerateRepliesCallbacks:
+    model_path: Path
+    dataset_path: Path
+    output_path: Path
+    max_vram: int | None
+    batch_size: int | None
+    total_rows: int
+
+    def loading_model(self) -> None:
+        pass
+
+    def finished_loading_model(self) -> None:
+        pass
+
+    def loading_dataset(self) -> None:
+        pass
+
+    def finished_loading_dataset(self) -> None:
+        pass
+
+    def estimating_batchsize(self, sequence_length: int, lo: int, hi: int | None) -> None:
+        pass
+
+    def batch_sizes_estimated(self) -> None:
+        pass
+
+    def batch_sizes_computed(self, event: BatchSizesComputedEvent) -> None:
+        pass
+
+    def generation_progress(self, rows_processed: int) -> None:
+        pass
+
+    def finished_generation(self) -> None:
+        pass
+
+
+def generate_replies(
+    model_path: Path,
+    dataset_path: Path,
+    output_path: Path,
+    max_vram: int | None,
+    max_output_length: int = 8192,
+    batch_size: int | None = None,
+    callbacks_type: Callable[
+        [
+            Path,
+            Path,
+            Path,
+            int | None,
+            int | None,
+            int,
+        ],
+        GenerateRepliesCallbacks,
+    ] = GenerateRepliesCallbacks,
+) -> None:
+    # figure out max_vram if neither batch_size nor max_vram is set
+    if max_vram is None and batch_size is None:
+        max_vram = get_default_device_bytes()
+        if max_vram is None:
+            raise ValueError(
+                "Unable to determine default device memory capacity; please specify either --vram-gb or --batch-size",
+            )
+
+    # Count rows without loading full dataset
+    total_rows = pl.scan_parquet(dataset_path).select(pl.len()).collect().item()
+
+    callbacks = callbacks_type(
+        model_path,
+        dataset_path,
+        output_path,
+        max_vram,
+        batch_size,
+        total_rows,
+    )
+
+    callbacks.loading_model()
+    model = LanguageModelConfig.load_model(model_path)
+    callbacks.finished_loading_model()
+
+    callbacks.loading_dataset()
+    dataframe = load_hf_parquet(dataset_path).collect()
+    conversations = dataframe.get_column("conversation")
+    dataset = iter(
+        [HFMessage.from_dict(message).as_message() for message in conversation] for conversation in conversations
+    )
+    try:
+        first_row = next(dataset)
+    except StopIteration:
+        callbacks.finished_loading_dataset()
+        output_path.parent.mkdir(parents=True, exist_ok=True)
+        pl.DataFrame({"response": [], "chain_of_thought": []}).write_parquet(output_path)
+        return
+    dataset = chain([first_row], dataset)  # iterator is lazy, force it to actually open the file
+    callbacks.finished_loading_dataset()
+
+    inference_config = InferenceConfig(max_output_length=max_output_length, batch_size=batch_size)
+
+    callbacks.batch_sizes_estimated()
+
+    replies: list[tuple[int, AssistantMessage]] = []
+    for rows_processed, (idx, reply) in enumerate(
+        model.reply_many(
+            dataset,
+            inference_config=inference_config,
+            vram_bytes=max_vram,
+            batch_sizes_callback=callbacks.batch_sizes_computed,
+        ),
+    ):
+        replies.append((idx, reply))
+        callbacks.generation_progress(rows_processed)
+
+    # Sort by original index to restore input order
+    replies.sort(key=lambda x: x[0])
+
+    df = pl.DataFrame(
+        {
+            "response": [reply.response for _, reply in replies],
+            "chain_of_thought": [reply.chain_of_thought for _, reply in replies],
+        },
+    )
+
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    df.write_parquet(output_path)
+
+    callbacks.finished_generation()
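
A usage sketch of the new generate_replies command; the paths are placeholders, and max_vram is a byte count on the same scale as get_default_device_bytes().

from pathlib import Path

from lalamo.commands import generate_replies

generate_replies(
    model_path=Path("models/example-model"),
    dataset_path=Path("data/prompts.parquet"),
    output_path=Path("data/replies.parquet"),
    max_vram=None,  # falls back to get_default_device_bytes()
    max_output_length=2048,
    batch_size=None,  # derived from max_vram when left unset
)
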
--- lalamo-0.6.4/lalamo/common.py
+++ lalamo-0.6.6/lalamo/common.py
@@ -1,7 +1,9 @@
+import os
 from collections import defaultdict
 from collections.abc import Mapping, Sequence
 from typing import cast
 
+import jax
 import jax.numpy as jnp
 from jax._src.api import ShapeDtypeStruct
 from jaxtyping import Array, DTypeLike
@@ -11,6 +13,7 @@ from lalamo.utils import MapDictValues, MapSequence
 __all__ = [
     "DEFAULT_PRECISION",
     "ArrayLike",
+    "LalamoWarning",
     "ParameterPath",
     "ParameterTree",
     "dummy_array",
@@ -23,6 +26,10 @@ __all__ = [
 DEFAULT_PRECISION: DTypeLike = jnp.bfloat16
 
 
+class LalamoWarning(UserWarning):
+    """Custom warning class for Lalamo-specific warnings."""
+
+
 type ArrayLike = Array | ShapeDtypeStruct
 
 
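
Only the warning class is added here; a plausible call site (the message is illustrative, not taken from this diff):

import warnings

from lalamo.common import LalamoWarning

warnings.warn("falling back to an unoptimized code path", LalamoWarning, stacklevel=2)

# Because it is a distinct subclass, callers can filter lalamo warnings specifically:
warnings.filterwarnings("ignore", category=LalamoWarning)
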
@@ -121,3 +128,29 @@ class ParameterPath(str):
         if not self:
             return ParameterPath(str(other))
         return ParameterPath(self + "." + str(other))
+
+
+def get_default_device_bytes() -> int | None:
+    dynamic_allocate = False
+
+    preallocate = os.getenv("XLA_PYTHON_CLIENT_PREALLOCATE", "")
+    dynamic_allocate |= preallocate.strip().lower() in {"0", "false", "no", "off"}
+
+    allocator = os.getenv("XLA_PYTHON_CLIENT_ALLOCATOR", "")
+    dynamic_allocate |= allocator.strip().lower() in {"platform", "cuda_malloc_async"}
+
+    if dynamic_allocate:
+        return None
+
+    memory_stats = jax.local_devices()[0].memory_stats()
+    if memory_stats is None or "bytes_limit" not in memory_stats:
+        return None
+
+    # 500 MB is the typically observed runtime overhead
+    memory_limit = memory_stats["bytes_limit"] - (500 * 1000 * 1000)
+
+    return memory_limit
+
+
+def get_usable_memory_from_bytes(limit_bytes: int) -> int:
+    return int(limit_bytes * 0.95)
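
A small sketch of how the allocator environment variables short-circuit get_default_device_bytes(): when XLA preallocation is disabled (or a dynamic allocator is selected), bytes_limit no longer reflects real device capacity, so the function reports None.

import os

# JAX honors these variables only at startup, so set them before anything
# imports jax; get_default_device_bytes() re-reads them on every call.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

from lalamo.common import get_default_device_bytes

assert get_default_device_bytes() is None  # dynamic allocation => no reliable limit
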
--- lalamo-0.6.4/lalamo/data/__init__.py
+++ lalamo-0.6.6/lalamo/data/__init__.py
@@ -1,7 +1,8 @@
-from .huggingface_message import import_hf_parquet
+from .huggingface_message import load_hf_parquet, shuffle_dataset
 from .utils import get_prefixes_ending_in_user_message
 
 __all__ = [
     "get_prefixes_ending_in_user_message",
-    "import_hf_parquet",
+    "load_hf_parquet",
+    "shuffle_dataset",
 ]
--- lalamo-0.6.4/lalamo/data/huggingface_message.py
+++ lalamo-0.6.6/lalamo/data/huggingface_message.py
@@ -1,4 +1,3 @@
-from collections.abc import Iterable
 from dataclasses import dataclass
 from pathlib import Path
 from typing import ClassVar, Self
@@ -30,10 +29,10 @@ class HFMessage:
         raise ValueError(f"Cannot convert {other} message")
 
 
-def import_hf_parquet(path: Path | str) -> Iterable[list[Message]]:
+def load_hf_parquet(path: Path | str) -> pl.LazyFrame:
     path = Path(path)
+    return pl.scan_parquet(path)
 
-    dataframe = pl.scan_parquet(path).collect()
 
-    for conversation in dataframe.get_column("conversation").shuffle(1337):
-        yield [HFMessage.from_dict(message).as_message() for message in conversation]
+
+def shuffle_dataset(frame: pl.LazyFrame, seed: int = 1337) -> pl.DataFrame:
+    return frame.collect().sample(fraction=1.0, shuffle=True, seed=seed)
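
The split of import_hf_parquet into load_hf_parquet and shuffle_dataset means callers that do not need shuffling can keep the polars pipeline fully lazy. A brief sketch (file name illustrative):

from lalamo.data import load_hf_parquet, shuffle_dataset

lazy = load_hf_parquet("conversations.parquet")  # pl.LazyFrame, no I/O yet
first_ten = lazy.limit(10).collect()  # run a lazy query, collect explicitly
shuffled = shuffle_dataset(lazy)  # pl.DataFrame, deterministic with seed=1337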