lalamo 0.6.4__py3-none-any.whl → 0.6.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lalamo/main.py CHANGED
@@ -10,6 +10,7 @@ from pathlib import Path
 from typing import Annotated

 import jax.profiler
+import requests
 import thefuzz.process
 from click import Context as ClickContext
 from click import Parameter as ClickParameter
@@ -35,24 +36,31 @@ from lalamo.commands import (
     CollectTracesCallbacks,
     ConversionCallbacks,
     EstimateBatchsizeCallbacks,
+    GenerateRepliesCallbacks,
     Precision,
+    PullCallbacks,
     TraceCallbacks,
     TrainCallbacks,
+    _suggest_similar_models,
 )
 from lalamo.commands import collect_traces as _collect_traces
 from lalamo.commands import convert as _convert
 from lalamo.commands import estimate_batchsize as _estimate_batchsize
+from lalamo.commands import generate_replies as _generate_replies
+from lalamo.commands import pull as _pull
 from lalamo.commands import trace as _trace
 from lalamo.commands import train as _train
+from lalamo.common import (
+    get_default_device_bytes,
+    get_usable_memory_from_bytes,
+)
 from lalamo.data.lalamo_completions import LalamoCompletion
 from lalamo.message_processor import UserMessage
 from lalamo.model_import import REPO_TO_MODEL, ModelSpec
 from lalamo.model_import.common import FileSpec
+from lalamo.model_import.remote_registry import RegistryModel, RegistryModelFile, fetch_available_models
 from lalamo.models import ClassifierModelConfig, LanguageModelConfig
-from lalamo.speculator.estimator import (
-    get_default_device_bytes,
-    get_usable_memory_from_bytes,
-)
+from lalamo.models.common import BatchSizesComputedEvent
 from lalamo.speculator.ngram import NGramSpeculator
 from lalamo.speculator.utils import test_speculator

@@ -76,7 +84,7 @@ class ModelParser(ParamType):
     def convert(self, value: str, param: ClickParameter | None, ctx: ClickContext | None) -> ModelSpec:
         result = REPO_TO_MODEL.get(value)
         if result is None:
-            closest_repo = _closest_repo(value)
+            closest_repo = _closest_repo(value, list(REPO_TO_MODEL))
             error_message_parts = [
                 f'"{value}".',
             ]
@@ -92,10 +100,37 @@ class ModelParser(ParamType):
         return result


-def _closest_repo(query: str, min_score: float = 80) -> str | None:
-    if not REPO_TO_MODEL:
+class RemoteModelParser(ParamType):
+    name: str = "Pre-converted Model"
+
+    def convert(self, value: str, param: ClickParameter | None, ctx: ClickContext | None) -> "RegistryModel":
+        try:
+            available_models = fetch_available_models()
+        except (requests.RequestException, ValueError) as e:
+            error_message = f"Failed to fetch model list from SDK. Check your internet connection.\n\nError: {e}"
+            return self.fail(error_message, param, ctx)
+
+        repo_to_model = {m.repo_id: m for m in available_models}
+        model_spec = repo_to_model.get(value)
+        if model_spec is None:
+            closest_repo = _closest_repo(value, list(repo_to_model))
+            if closest_repo:
+                model_spec = repo_to_model[closest_repo]
+
+        if model_spec is None:
+            suggestions = _suggest_similar_models(value, available_models)
+            error_message = f'Model "{value}" not found.'
+            if suggestions:
+                error_message += "\n\nDid you mean one of these?\n" + "\n".join(f" - {s}" for s in suggestions)
+            return self.fail(error_message, param, ctx)
+
+        return model_spec
+
+
+def _closest_repo(query: str, repo_ids: list[str], min_score: float = 80) -> str | None:
+    if not repo_ids:
         return None
-    (closest_match, score), *_ = thefuzz.process.extract(query, list(REPO_TO_MODEL))
+    (closest_match, score), *_ = thefuzz.process.extract(query, repo_ids)
     if closest_match and score >= min_score:
         return closest_match
     return None
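The fuzzy-matching helper now takes its candidate list explicitly, so the same `_closest_repo` serves both the local `REPO_TO_MODEL` catalog and the remote registry IDs used by `RemoteModelParser`. A minimal standalone sketch of the score-threshold behaviour (the repo IDs and the misspelled query below are made up for illustration):

```python
# Sketch of the thefuzz-based cutoff used by _closest_repo; repo IDs are illustrative only.
import thefuzz.process

repo_ids = [
    "meta-llama/Llama-3.2-1B-Instruct",
    "mistralai/Codestral-22B-v0.1",
]

query = "lama-3.2-1b-instruct"  # deliberate typo
(closest_match, score), *_ = thefuzz.process.extract(query, repo_ids)

# Accept the best candidate only when it clears the min_score threshold (80 by default).
suggestion = closest_match if score >= 80 else None
print(suggestion, score)
```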
@@ -266,6 +301,49 @@ class CliConversionCallbacks(ConversionCallbacks):
         console.print(f"🧑‍🍳 Model successfully cooked and saved to [cyan]`{self.output_dir}`[/cyan]!")


+@dataclass
+class CliPullCallbacks(PullCallbacks):
+    stack: ExitStack = field(default_factory=ExitStack)
+    progress: Progress | None = None
+    downloading_tasks: dict[RegistryModelFile, TaskID] = field(default_factory=dict)
+
+    def started(self) -> None:
+        console.print(f"📦 Pulling [cyan]{self.model_spec.name}[/cyan] by [cyan]{self.model_spec.vendor}[/cyan]")
+
+        self.progress = self.stack.enter_context(
+            Progress(
+                SpinnerColumn(),
+                TextColumn("[progress.description]{task.description}"),
+                transient=True,
+            ),
+        )
+
+    def output_dir_exists(self) -> None:
+        if not self.overwrite and not Confirm().ask(
+            rf"⚠️ Output directory [cyan]{self.output_dir}[/cyan] already exists."
+            r" Do you want to overwrite it?",
+        ):
+            raise Exit
+
+        shutil.rmtree(self.output_dir)
+
+    def downloading(self, file_spec: RegistryModelFile) -> None:
+        assert self.progress is not None
+
+        self.downloading_tasks[file_spec] = self.progress.add_task(f"⬇️ Downloading {file_spec.name}...")
+
+    def finished_downloading(self, file_spec: RegistryModelFile) -> None:
+        assert self.progress is not None
+
+        self.progress.remove_task(self.downloading_tasks[file_spec])
+
+    def finished(self) -> None:
+        assert self.progress is not None
+
+        self.stack.close()
+        console.print(f"🎉 Model successfully pulled to [cyan]{self.output_dir}[/cyan]!")
+
+
 @app.command(help="Convert the model for use with the Uzu inference engine.")
 def convert(
     model_repo: Annotated[
@@ -322,6 +400,46 @@ def convert(
     )


+@app.command(help="Pull a pre-converted model from the SDK repository.")
+def pull(
+    model_spec: Annotated[
+        RegistryModel,
+        Argument(
+            help=(
+                "Model repository ID from the pre-converted catalog. "
+                "Example: [cyan]'meta-llama/Llama-3.2-1B-Instruct'[/cyan]. "
+                "Fuzzy matching is supported for typos and partial names."
+            ),
+            click_type=RemoteModelParser(),
+            show_default=False,
+            metavar="MODEL_IDENTIFIER",
+        ),
+    ],
+    output_dir: Annotated[
+        Path | None,
+        Option(
+            help="Directory to save the pulled model to.",
+            show_default="Saves the pulled model in the `models/<model_name>` directory",
+        ),
+    ] = None,
+    overwrite: Annotated[
+        bool,
+        Option(
+            help="Overwrite existing model files without prompting.",
+        ),
+    ] = False,
+) -> None:
+    if output_dir is None:
+        output_dir = DEFAULT_OUTPUT_DIR / model_spec.name
+
+    _pull(
+        model_spec,
+        output_dir,
+        partial(CliPullCallbacks),
+        overwrite=overwrite,
+    )
+
+
 @dataclass
 class CliTraceCallbacks(TraceCallbacks):
     overwrite: bool = False
@@ -492,6 +610,151 @@ def list_models(
     console.print(table)


+@dataclass
+class CliGenerateRepliesCallbacks(GenerateRepliesCallbacks):
+    stack: ExitStack = field(default_factory=ExitStack)
+    progress: Progress | None = None
+    loading_task: TaskID | None = None
+    estimating_task: TaskID | None = None
+    generation_task: TaskID | None = None
+
+    def loading_model(self) -> None:
+        self.progress = self.stack.enter_context(
+            Progress(
+                SpinnerColumn(),
+                TextColumn("[progress.description]{task.description}"),
+                MofNCompleteColumn(),
+                TimeElapsedColumn(),
+                transient=True,
+            ),
+        )
+        self.loading_task = self.progress.add_task("🧠 [cyan]Loading model...[/cyan]", total=None)
+
+    def finished_loading_model(self) -> None:
+        assert self.progress is not None
+        assert self.loading_task is not None
+        self.progress.remove_task(self.loading_task)
+
+    def loading_dataset(self) -> None:
+        assert self.progress is not None
+        self.loading_task = self.progress.add_task("🗂️ [cyan]Loading dataset...[/cyan]", total=None)
+
+    def finished_loading_dataset(self) -> None:
+        assert self.progress is not None
+        assert self.loading_task is not None
+        self.progress.remove_task(self.loading_task)
+
+    def estimating_batchsize(self, sequence_length: int, lo: int, hi: int | None) -> None:
+        assert self.progress is not None
+        hi_str = str(hi) if hi is not None else "?"
+        description = (
+            f"📐 [cyan]Computing batch size for the prompt length of {sequence_length}... ({lo}..{hi_str})[/cyan]"
+        )
+        if self.estimating_task is None:
+            self.estimating_task = self.progress.add_task(description)
+        else:
+            self.progress.update(self.estimating_task, description=description)
+
+    def batch_sizes_estimated(self) -> None:
+        assert self.progress is not None
+        if self.estimating_task is None:
+            self.estimating_task = self.progress.add_task(
+                "📐 [cyan]Estimating the best batch sizes...[/cyan]",
+                total=None,
+            )
+
+    def batch_sizes_computed(self, event: BatchSizesComputedEvent) -> None:
+        assert self.progress is not None
+        if self.estimating_task is not None:
+            self.progress.remove_task(self.estimating_task)
+            self.estimating_task = None
+        output_console = self.progress.console if self.progress is not None else console
+        for info in event.batch_sizes:
+            output_console.print(
+                f"Prefix length {info.prefix_length} has {info.num_elements} elements, "
+                f"with batchsize of {info.batch_size}",
+            )
+        self.generation_task = self.progress.add_task(
+            "🔮 [cyan]Generating replies...[/cyan]",
+            total=self.total_rows,
+        )
+
+    def generation_progress(self, rows_processed: int) -> None:
+        assert self.progress is not None
+        assert self.generation_task is not None
+        self.progress.update(self.generation_task, completed=rows_processed + 1)
+
+    def finished_generation(self) -> None:
+        assert self.progress is not None
+        assert self.generation_task is not None
+        self.progress.update(self.generation_task, description="✅ Completed")
+        self.stack.close()
+        console.print(f"💾 Replies saved to [cyan]{self.output_path}[/cyan]")
+
+
+@app.command(help="Generate replies for conversations in a parquet file.")
+def generate_replies(
+    model_path: Annotated[
+        Path,
+        Argument(
+            help="Path to the model directory.",
+            metavar="MODEL_PATH",
+        ),
+    ],
+    dataset_path: Annotated[
+        Path,
+        Argument(
+            help="Path to the input parquet file with conversations.",
+            metavar="DATASET_PATH",
+        ),
+    ],
+    output_path: Annotated[
+        Path,
+        Option(
+            help="Path to save the output parquet file.",
+        ),
+    ],
+    vram_gb: Annotated[
+        int | None,
+        Option(
+            help="Maximum VRAM in GB. Batch sizes are estimated automatically.",
+            show_default="max on default device",
+        ),
+    ] = None,
+    max_output_length: Annotated[
+        int,
+        Option(help="Maximum number of tokens to generate per reply."),
+    ] = 8192,
+    batch_size: Annotated[
+        int | None,
+        Option(help="Fixed batch size to use, skipping automatic estimation."),
+    ] = None,
+) -> None:
+    if batch_size is not None and vram_gb is not None:
+        err_console.print("Cannot use both --batch-size and --vram-gb")
+        raise Exit(1)
+
+    max_vram: int | None = None
+    if batch_size is None:
+        if vram_gb is not None:
+            mem_bytes = vram_gb * 1000 * 1000 * 1000
+        elif (mem_bytes := get_default_device_bytes()) is None:
+            err_console.print("Cannot get the default device's memory stats, use --vram-gb or --batch-size")
+            raise Exit(1)
+
+        max_vram = mem_bytes
+
+    _generate_replies(
+        model_path,
+        dataset_path,
+        output_path,
+        max_vram,
+        max_output_length,
+        batch_size,
+        CliGenerateRepliesCallbacks,
+    )
+
+
 speculator_app = Typer()
 app.add_typer(speculator_app, name="speculator", help="Train a speculator for a model.")

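The new `generate_replies` command resolves its VRAM budget with a simple precedence: an explicit `--batch-size` skips estimation entirely, otherwise `--vram-gb` (decimal gigabytes) or the default device's reported memory sets the limit. A hedged sketch of that precedence as a standalone function; `get_memory_bytes` is a hypothetical stand-in for `get_default_device_bytes`, injected only to keep the snippet runnable:

```python
# Sketch of the VRAM-budget precedence in generate_replies; get_memory_bytes is a
# hypothetical stand-in for get_default_device_bytes.
from collections.abc import Callable


def resolve_max_vram(
    batch_size: int | None,
    vram_gb: int | None,
    get_memory_bytes: Callable[[], int | None],
) -> int | None:
    if batch_size is not None:
        return None  # fixed batch size: no automatic estimation, no budget needed
    if vram_gb is not None:
        return vram_gb * 1000 * 1000 * 1000  # decimal GB, matching the CLI above
    mem_bytes = get_memory_bytes()
    if mem_bytes is None:
        raise RuntimeError("Cannot get the default device's memory stats")
    return mem_bytes


print(resolve_max_vram(batch_size=None, vram_gb=8, get_memory_bytes=lambda: None))
```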
@@ -727,10 +990,12 @@ def view_traces(
     table.add_column("Prefix")
     table.add_column("Completion")

+    from rich.text import Text
+
     for completion in islice(traces, num_completions):
         detokenized_prefix = model.message_processor.detokenize(completion.prefix_token_ids)
         detokenized_completion = model.message_processor.detokenize(completion.completion_token_ids)
-        table.add_row(detokenized_prefix, detokenized_completion)
+        table.add_row(Text(detokenized_prefix), Text(detokenized_completion))

     console.print(table)

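Wrapping the detokenized strings in `rich.text.Text` likely serves to render bracketed token sequences in model output (for example chat-template markers) literally, instead of having Rich parse them as console markup. A small self-contained illustration of the difference:

```python
# Illustration of why the cells are wrapped in rich.text.Text: a bare string containing
# bracketed tokens would be parsed for console markup and may error or render incorrectly.
from rich.console import Console
from rich.table import Table
from rich.text import Text

console = Console()
table = Table()
table.add_column("Completion")

raw = "[INST] hello [/INST]"  # bracket markers typical of detokenized chat templates
table.add_row(Text(raw))  # Text renders the brackets verbatim

console.print(table)
```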
@@ -39,6 +39,7 @@ class HuggingFaceMessage(TypedDict):
 class HuggingFaceRequest(TypedDict):
     add_generation_prompt: bool
     bos_token: str | None
+    eos_token: str | None
     messages: list[HuggingFaceMessage]
     enable_thinking: NotRequired[bool]
     tools: NotRequired[dict]
@@ -75,6 +76,7 @@ class MessageProcessorConfig:
     system_role_name: str
     user_role_name: str
     assistant_role_name: str
+    eos_token: str | None
     bos_token: str | None

     def init(self, tokenizer: Tokenizer) -> "MessageProcessor":
@@ -115,6 +117,10 @@ class MessageProcessor:
     def bos_token(self) -> str | None:
         return self.config.bos_token

+    @property
+    def eos_token(self) -> str | None:
+        return self.config.eos_token
+
     def message_to_dict(self, message: Message) -> HuggingFaceMessage:
         match message:
             case UserMessage(content=content):
@@ -137,7 +143,12 @@ class MessageProcessor:
         enable_thinking: bool | None = None,
     ) -> HuggingFaceRequest:
         converted_messages = [self.message_to_dict(message) for message in messages]
-        result = HuggingFaceRequest(add_generation_prompt=True, messages=converted_messages, bos_token=self.bos_token)
+        result = HuggingFaceRequest(
+            add_generation_prompt=True,
+            messages=converted_messages,
+            bos_token=self.bos_token,
+            eos_token=self.eos_token,
+        )
         if enable_thinking is not None:
             result["enable_thinking"] = enable_thinking
         if tools is not None:
@@ -163,9 +174,16 @@ class MessageProcessor:
         rendered = self.render_request(messages)
         return self.tokenize_text(rendered)

+    def tokenize_requests(self, dataset: Iterable[Iterable[Message]]) -> list[list[int]]:
+        return [self.tokenize_request(messages) for messages in dataset]
+
     def detokenize(self, tokens: list[int]) -> str:
         return self.tokenizer.decode(tokens, skip_special_tokens=False)

+    def parse_tokenized_response(self, tokens: list[int]) -> AssistantMessage:
+        detokenized = self.detokenize(tokens)
+        return self.parse_response(detokenized)
+
     def __post_init__(self) -> None:
         if self.output_parser_regex is not None:
             all_fields = AssistantMessage.__dataclass_fields__
@@ -138,6 +138,7 @@ def import_message_processor(
         progress_callback,
     )
     tokenizer_config = HFTokenizerConfig.from_json(tokenizer_config_file)
+
     if tokenizer_config.chat_template is None:
         match model_spec.configs.chat_template:
             case JSONFieldSpec(file_spec, field_name):
@@ -165,13 +166,28 @@ def import_message_processor(
     tokenizer.add_special_tokens(added_special_tokens)
     tokenizer.add_tokens(added_not_special_tokens)

+    bos_token = tokenizer_config.bos_token
+    eos_token = tokenizer_config.eos_token
+
+    # If we were not able to identify bos/eos - they are probably somewhere else, so we check config.json
+    if eos_token is None or bos_token is None:
+        foreign_decoder_config_file = download_config_file(model_spec, output_dir, progress_callback)
+        with open(foreign_decoder_config_file) as foreign_decoder_file:
+            foreign_decoder_json = json.load(foreign_decoder_file)
+
+        if bos_token is None:
+            bos_token = foreign_decoder_json.get("bos_token_id")
+        if eos_token is None:
+            eos_token = foreign_decoder_json.get("eos_token_id")
+
     message_processor_config = MessageProcessorConfig(
         prompt_template=prompt_template,
         output_parser_regex=model_spec.output_parser_regex,
         system_role_name=model_spec.system_role_name,
         user_role_name=model_spec.user_role_name,
         assistant_role_name=model_spec.assistant_role_name,
-        bos_token=tokenizer_config.bos_token,
+        bos_token=bos_token,
+        eos_token=eos_token,
     )
     return MessageProcessor(config=message_processor_config, tokenizer=tokenizer)

@@ -10,6 +10,8 @@ from .common import (

 __all__ = ["MISTRAL_MODELS"]

+CODESTRAL_TOKENIZER_REPO = "mistralai/Codestral-22B-v0.1"
+
 CODESTRAL = [
     ModelSpec(
         vendor="Mistral",
@@ -21,6 +23,9 @@ CODESTRAL = [
         config_type=HFMistralConfig,
         weights_type=WeightsType.SAFETENSORS,
         use_cases=(UseCase.CODE,),
+        configs=ConfigMap(
+            tokenizer_config=FileSpec(repo=CODESTRAL_TOKENIZER_REPO, filename="tokenizer_config.json"),
+        ),
     ),
 ]

@@ -0,0 +1,44 @@
+from dataclasses import dataclass
+from typing import Any, ClassVar
+
+import cattrs
+import requests
+
+
+@dataclass(frozen=True)
+class RegistryModelFile:
+    name: str
+    url: str
+    size: int
+    crc32c: str
+
+
+@dataclass(frozen=True)
+class RegistryModel:
+    _converter: ClassVar[cattrs.Converter] = cattrs.Converter()
+
+    id: str
+    vendor: str
+    name: str
+    family: str
+    size: str
+    repo_id: str
+    quantization: str | None
+    files: list[RegistryModelFile]
+
+    @classmethod
+    def from_dict(cls, data: dict[str, Any]) -> "RegistryModel":
+        if "repoId" in data:
+            data = {**data, "repo_id": data.pop("repoId")}
+        return cls._converter.structure(data, cls)
+
+
+def fetch_available_models() -> list[RegistryModel]:
+    api_url = "https://sdk.trymirai.com/api/v1/models/list/lalamo"
+    response = requests.get(api_url, timeout=30)
+    response.raise_for_status()
+
+    data = response.json()
+    models_data = data.get("models", [])
+
+    return [RegistryModel.from_dict(model_data) for model_data in models_data]
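The registry endpoint returns camelCase keys such as `repoId`, which `RegistryModel.from_dict` remaps before handing the payload to `cattrs`. A self-contained sketch of the same remap-then-structure pattern on a toy dataclass (the payload below is illustrative, not the real registry response):

```python
# Toy version of the key-remap + cattrs.structure pattern used by RegistryModel.from_dict.
from dataclasses import dataclass
from typing import Any

import cattrs

_converter = cattrs.Converter()


@dataclass(frozen=True)
class ToyRegistryModel:
    repo_id: str
    quantization: str | None


def toy_from_dict(data: dict[str, Any]) -> ToyRegistryModel:
    # The API uses camelCase; rename the key to match the snake_case field before structuring.
    if "repoId" in data:
        data = {**data, "repo_id": data.pop("repoId")}
    return _converter.structure(data, ToyRegistryModel)


print(toy_from_dict({"repoId": "meta-llama/Llama-3.2-1B-Instruct", "quantization": None}))
```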
lalamo/models/__init__.py CHANGED
@@ -1,7 +1,10 @@
 from .classifier import ClassifierModel, ClassifierModelConfig
+from .common import BatchSizeInfo, BatchSizesComputedEvent
 from .language_model import GenerationConfig, LanguageModel, LanguageModelConfig

 __all__ = [
+    "BatchSizeInfo",
+    "BatchSizesComputedEvent",
     "ClassifierModel",
     "ClassifierModelConfig",
     "GenerationConfig",
lalamo/models/common.py CHANGED
@@ -18,11 +18,33 @@ from lalamo.modules.decoder import DecoderConfig, DecoderResult
 from lalamo.safetensors import safe_read

 __all__ = [
+    "BatchSizeInfo",
+    "BatchSizesComputedEvent",
     "TextModel",
     "TextModelConfig",
 ]


+@dataclass(frozen=True)
+class InferenceConfig:
+    max_output_length: int = 8192
+    padded_length: int = 8192
+    num_top_logits_to_return: int | None = None
+    batch_size: int | None = None
+
+
+@dataclass(frozen=True)
+class BatchSizeInfo:
+    prefix_length: int
+    num_elements: int
+    batch_size: int
+
+
+@dataclass(frozen=True)
+class BatchSizesComputedEvent:
+    batch_sizes: tuple[BatchSizeInfo, ...]
+
+
 @dataclass(frozen=True)
 class TextModelConfig[ConfigT: ClassifierConfig | DecoderConfig](ABC):
     model_config: ConfigT
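These new config and event classes are frozen, which makes their instances hashable; the compile cache added in the new module below keys on an `InferenceConfig` (together with the other configs) directly. A tiny sketch of why `frozen=True` matters there, using toy names only:

```python
# Sketch: frozen dataclasses are hashable, so equal configs map to the same cache entry.
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyInferenceConfig:
    max_output_length: int = 8192
    batch_size: int | None = None


cache: dict[ToyInferenceConfig, str] = {}
cache[ToyInferenceConfig(batch_size=4)] = "compiled artifact"

# A freshly constructed but equal config hashes to the same key and hits the cache.
assert cache[ToyInferenceConfig(batch_size=4)] == "compiled artifact"
```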
@@ -0,0 +1,58 @@
+from __future__ import annotations
+
+import functools
+from typing import TYPE_CHECKING
+
+import jax
+import jax.numpy as jnp
+
+from .common import InferenceConfig
+
+if TYPE_CHECKING:
+    from jax._src.stages import Compiled
+
+    from .language_model import ForwardPassConfig, GenerationConfig, LanguageModel
+
+_compile_cache: dict[
+    tuple[int, GenerationConfig | None, InferenceConfig | None, ForwardPassConfig | None],
+    Compiled,
+] = {}
+
+
+def compile_generate_tokens(
+    model: LanguageModel,
+    generation_config: GenerationConfig | None = None,
+    inference_config: InferenceConfig = InferenceConfig(),  # noqa: B008
+    *,
+    forward_pass_config: ForwardPassConfig | None = None,
+) -> Compiled:
+    from .language_model import LanguageModel
+
+    key = (id(model), generation_config, inference_config, forward_pass_config)
+    if key not in _compile_cache:
+        generate_tokens_fn = functools.partial(
+            LanguageModel.generate_tokens,
+            generation_config=generation_config,
+            max_output_length=inference_config.max_output_length,
+            num_top_logits_to_return=inference_config.num_top_logits_to_return,
+            forward_pass_config=forward_pass_config,
+        )
+        _compile_cache[key] = (
+            jax.jit(generate_tokens_fn)
+            .lower(
+                model,
+                prompt_token_ids=jax.ShapeDtypeStruct(
+                    (inference_config.batch_size, inference_config.padded_length),
+                    jnp.int32,
+                ),
+                prompt_lengths_without_padding=jax.ShapeDtypeStruct((inference_config.batch_size,), jnp.int32),
+                keys=jax.ShapeDtypeStruct((inference_config.batch_size,), jax.random.key(0).dtype),
+            )
+            # the autotune levels are (according to https://guides.lw1.at/all-xla-options/#--xla_gpu_autotune_level)
+            # 0 - no autotune, gpu shouldn't be touched
+            # 1 - basic level, gpu should be touched veeery little
+            # 2,3 - gpu touched more and more
+            # 4 (default) - gpu might allocate more memory than the run would require!
+            .compile(compiler_options={"xla_gpu_autotune_level": "0"})
+        )
+    return _compile_cache[key]
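The cache above avoids re-lowering and re-compiling the generation function for a model/config combination that has already been built; the `.lower(...).compile(...)` chain is JAX's ahead-of-time API, driven by `jax.ShapeDtypeStruct` placeholders instead of real arrays. A standalone sketch of the same pattern on a toy function (names and shapes are illustrative, not the package's real generation signature):

```python
# Standalone sketch of the jax.jit(...).lower(...).compile(...) AOT pattern used above,
# applied to a toy function rather than LanguageModel.generate_tokens.
import jax
import jax.numpy as jnp


def scale(tokens: jnp.ndarray, factor: jnp.ndarray) -> jnp.ndarray:
    return tokens * factor


# Describe argument shapes/dtypes without allocating real device buffers.
tokens_spec = jax.ShapeDtypeStruct((2, 8), jnp.int32)
factor_spec = jax.ShapeDtypeStruct((), jnp.int32)

compiled = (
    jax.jit(scale)
    .lower(tokens_spec, factor_spec)
    .compile()  # compiler_options could be passed here, as with xla_gpu_autotune_level above
)

# The compiled object only accepts arguments matching the specs it was lowered with.
print(compiled(jnp.ones((2, 8), jnp.int32), jnp.asarray(3, dtype=jnp.int32)))
```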