lalamo 0.6.5__py3-none-any.whl → 0.6.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,17 +1,12 @@
- import functools
  from collections.abc import Callable, Iterable
- from itertools import batched, chain, islice
+ from itertools import chain
  from typing import NamedTuple
 
- import jax
- import jax.numpy as jnp
- from jax._src.stages import Compiled
-
- from lalamo.common import decrease_batchsize_on_oom
  from lalamo.data.lalamo_completions import LalamoCompletion
  from lalamo.data.utils import get_prefixes_ending_in_user_message
  from lalamo.message_processor import Message
  from lalamo.models import LanguageModel
+ from lalamo.models.common import InferenceConfig
 
 
  class CollectTracesEvent(NamedTuple):
@@ -29,84 +24,52 @@ def inference_collect_traces(
      tokens_to_generate: int | None = None,
      progress_callback: Callable[[CollectTracesEvent], None] | None = None,
  ) -> Iterable[LalamoCompletion]:
-     def make_generate_tokens_compiled(batch_size: int) -> Compiled:
-         return (
-             jax.jit(
-                 functools.partial(
-                     LanguageModel.generate_tokens,
-                     max_output_length=max_output_length,
-                     num_top_logits_to_return=num_top_logits_to_collect,
-                 ),
-             )
-             .lower(
-                 model,
-                 prompt_token_ids=jax.ShapeDtypeStruct((batch_size, max_input_length), jnp.int32),
-                 prompt_lengths_without_padding=jax.ShapeDtypeStruct((batch_size,), jnp.int32),
-             )
-             # the autotune levels are (according to https://guides.lw1.at/all-xla-options/#--xla_gpu_autotune_level)
-             # 0 - no autotune, gpu shouldn't be touched
-             # 1 - basic level, gpu should be touched veeery little
-             # 2,3 - gpu touched more and more
-             # 4 (default) - gpu might allocate more memory than the run would require!
-             .compile(compiler_options={"xla_gpu_autotune_level": "0"})
-         )
-
      prefixes = chain.from_iterable(map(get_prefixes_ending_in_user_message, conversations))
-
      tokenized_prefixes = map(model.message_processor.tokenize_request, prefixes)
      filtered_prefixes = filter(lambda conv: len(conv) <= max_input_length, tokenized_prefixes)
+     filtered_prefixes = list(filtered_prefixes)  # eagerly materialize the prompts into RAM
+
+     config = InferenceConfig(
+         max_output_length=max_output_length,
+         num_top_logits_to_return=num_top_logits_to_collect,
+         padded_length=max_input_length,
+         batch_size=batch_size,
+     )
+
+     tokens_generated = 0
+
+     for idx, generated in enumerate(
+         model.generate_tokens_many(
+             filtered_prefixes,
+             inference_config=config,
+         ),
+     ):
+         token_ids = generated.token_ids.tolist()
+         seqlen = next(
+             (i + 1 for i, t in enumerate(token_ids) if t in model.stop_token_ids),
+             len(token_ids),
+         )
 
-     test_batch = list(islice(filtered_prefixes, batch_size))
-
-     def collect_traces_body(batch_size: int) -> Iterable[LalamoCompletion]:
-         tokens_generated, sequences_processed = 0, 0
-         generate_tokens_compiled = make_generate_tokens_compiled(batch_size)
-         for real_batch in batched(chain(test_batch, filtered_prefixes), n=batch_size):
-             batch_padding = batch_size - len(real_batch)
-             batch = (*real_batch, *(([0],) * batch_padding))
-
-             length_without_padding = jnp.array(list(map(len, batch)))
-
-             padded = jnp.array(
-                 [
-                     jnp.pad(jnp.array(tokens), (0, max_input_length - len(tokens)), constant_values=0)
-                     for tokens in batch
-                 ],
-             )
-
-             generated = generate_tokens_compiled(
-                 model,
-                 prompt_token_ids=padded,
-                 prompt_lengths_without_padding=length_without_padding,
-             )
-
-             assert generated.top_k_token_ids is not None and generated.top_k_token_logits is not None
-
-             for conv_idx in range(len(real_batch)):
-                 token_ids = generated.token_ids[conv_idx].tolist()
-                 seqlen = next((i + 1 for i, t in enumerate(token_ids) if t in model.stop_token_ids), len(token_ids))
-                 if tokens_to_generate is not None:
-                     seqlen = min(seqlen, tokens_to_generate - tokens_generated)
-                 tokens_generated += seqlen
-                 sequences_processed += 1
+         if tokens_to_generate is not None:
+             seqlen = min(seqlen, tokens_to_generate - tokens_generated)
 
-                 token_ids = token_ids[:seqlen]
-                 token_logits_ids = generated.top_k_token_ids[conv_idx, : len(token_ids)].tolist()
-                 token_logits_values = generated.top_k_token_logits[conv_idx, : len(token_ids)].tolist()
-                 token_logits = [
-                     dict(zip(keys, values, strict=True))
-                     for keys, values in zip(token_logits_ids, token_logits_values, strict=True)
-                 ]
+         tokens_generated += seqlen
 
-                 yield LalamoCompletion(batch[conv_idx], token_ids, token_logits)
+         token_ids = token_ids[:seqlen]
 
-                 if tokens_to_generate is not None and tokens_generated >= tokens_to_generate:
-                     break
+         assert generated.top_k_token_ids is not None and generated.top_k_token_logits is not None
+         token_logits_ids = generated.top_k_token_ids[:seqlen].tolist()
+         token_logits_values = generated.top_k_token_logits[:seqlen].tolist()
+         token_logits = [
+             dict(zip(keys, values, strict=True))
+             for keys, values in zip(token_logits_ids, token_logits_values, strict=True)
+         ]
 
-             if progress_callback is not None:
-                 progress_callback(CollectTracesEvent(sequences_processed, tokens_generated))
+         # We need the original prompt tokens, so look them up in filtered_prefixes
+         yield LalamoCompletion(filtered_prefixes[idx], token_ids, token_logits)
 
-             if tokens_to_generate is not None and tokens_generated >= tokens_to_generate:
-                 break
+         if progress_callback is not None:
+             progress_callback(CollectTracesEvent(idx + 1, tokens_generated))
 
-     yield from decrease_batchsize_on_oom(collect_traces_body, batch_size)
+         if tokens_to_generate is not None and tokens_generated >= tokens_to_generate:
+             break
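The refactor above moves the hand-rolled padding, batching, and XLA compilation out of `inference_collect_traces` and into `model.generate_tokens_many`, which consumes prompts in order and yields one generated sequence per prompt. A minimal sketch of the new call shape, assuming a loaded `LanguageModel` and pre-tokenized prompts; `InferenceConfig`, `generate_tokens_many`, and the config fields appear in this diff, while the driver function itself is illustrative:

```python
from collections.abc import Iterable

from lalamo.models import LanguageModel
from lalamo.models.common import InferenceConfig


def stream_completions(model: LanguageModel, prompts: list[list[int]]) -> Iterable[tuple[list[int], list[int]]]:
    # Static shapes live in the config, so the model can compile once and
    # reuse the executable across internal batches. Values are examples.
    config = InferenceConfig(
        max_output_length=256,        # tokens to generate per sequence
        num_top_logits_to_return=16,  # top-k logits collected per token
        padded_length=1024,           # prompts are padded to this length
        batch_size=8,                 # batching is handled inside the model
    )
    # generate_tokens_many yields one result per prompt, in input order,
    # so the caller can index back into its own inputs.
    for idx, generated in enumerate(model.generate_tokens_many(prompts, inference_config=config)):
        yield prompts[idx], generated.token_ids.tolist()
```

Because results stream back in input order, the caller recovers each original prompt by index, which is exactly how the new code pairs `filtered_prefixes[idx]` with its completion.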
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: lalamo
- Version: 0.6.5
+ Version: 0.6.6
  Summary: JAX library for optimization and export of models for use with the UZU inference engine.
  Requires-Python: <4,>=3.12
  Description-Content-Type: text/markdown
@@ -1,21 +1,22 @@
- lalamo/__init__.py,sha256=RpKc5sKIQHI8tPVwzH7lIJJWE7tJy6FZauEhabEp2Hg,1532
- lalamo/commands.py,sha256=zXyyrLTHhP9wouwtpX4RUZeEF6No-_9ee-y_GWGhw7k,10972
- lalamo/common.py,sha256=ddGIPlFCgo6Q683v8uP8G2dh8nsCJe9woZL8A_7_Rt4,6124
- lalamo/main.py,sha256=f1zHYQpX_OndAguOE0wqIOkzjzUolUC7w3_1ndtMC4Y,27655
- lalamo/message_processor.py,sha256=PMKte9YijT3h9N7DjTNp8H4V45A_qlDqJaubqFevLX8,5924
+ lalamo/__init__.py,sha256=FFdoG3pwkVS9Xi1X_aCiByG9QyzSP_NufH2lIhj60EY,1532
+ lalamo/commands.py,sha256=3zUZE2bg39XFFGh5PQ-L2oIs73GPsp0EkFinMZQNHro,18097
+ lalamo/common.py,sha256=odAWGfzRRfsKJPtVekOLTlN3SrHcMerPzMk4BfVjn9I,5262
+ lalamo/main.py,sha256=V6VwWlo7QO5RH5HpsbpxQgDeOGFv2Y4abhJOjHFw2MA,37153
+ lalamo/message_processor.py,sha256=gf-CiidoRp1XmLdy8jkv06Gg0Nqe_DAlpYObOF9JfpA,6490
  lalamo/quantization.py,sha256=8o6ryIZLzzDYQuvBTboPfaVVdfijAKGpTxOcg3GKVD8,2752
  lalamo/registry_abc.py,sha256=qTikqviqqeseNzkjqoyQvL4dEWJYWzN0rI05T-JNTmo,2187
  lalamo/safetensors.py,sha256=kUiTSgx2zhfD1hxV_AA1DOLaKAKzjRd_vOYZCFf0em0,3048
  lalamo/sampling.py,sha256=GE6Av7zS-pr5Bg7FtOivRce7I0JIYuNYqfqsRe-yjQk,3867
  lalamo/utils.py,sha256=c88IP110gHZJ6hYDq7p36A9u-vLRM_YdavFom56gsNQ,4111
- lalamo/data/__init__.py,sha256=exfhBLxHrg7BWutM0tAln5QuIWlNQmOhaG2noFYxfPI,189
- lalamo/data/huggingface_message.py,sha256=-7lN9eIcETQzt1Pnx3d4d8p3_I7WYMNf4mp1P91N7fI,1115
+ lalamo/data/__init__.py,sha256=QH23n37CWLcY69oLVE8gNNr6aJ57G0D_ZsO8eYs7-Jk,225
+ lalamo/data/huggingface_message.py,sha256=8oTCxL_IOHRCVgQneRv52sgJNprsiAIPvps5nBu6LWo,1037
  lalamo/data/lalamo_completions.py,sha256=U_m3UNSJASUFz3rJq_taZOtL_U4B8Oj-ndkTF-JH-v4,1509
  lalamo/data/utils.py,sha256=B96gLaULyStKYuR8wjFdTpFc6YIDC8EEvGh1eiMe_Ec,338
  lalamo/model_import/__init__.py,sha256=Z8pS9rbKKx1QgUy7KZtHxiNWlZhII3mdovT9d37vAxg,168
- lalamo/model_import/common.py,sha256=MIbvK3mxgrDSXea6jujvCOu9Jjyip6MXeTsJjNTBJAU,12325
+ lalamo/model_import/common.py,sha256=evZeeizFev2i7Whd9X7sgUQV8v5apjfmy_BhshnFbyo,13011
  lalamo/model_import/huggingface_generation_config.py,sha256=xicv_kJOfIGlz4gi5fRFIkiAZ9_QRDLRtW8nKMm5tVU,2022
  lalamo/model_import/huggingface_tokenizer_config.py,sha256=xvwdmio7b9nhn2H3uMBVligiYj58JaCFCvHY3-8dBvM,2502
+ lalamo/model_import/remote_registry.py,sha256=4VjZSwlYMqflMfSSPi7-GSb9tmTLMZELzXoJJ3Tsx5s,1045
  lalamo/model_import/decoder_configs/__init__.py,sha256=YvlSsJqNEQPCNKcUzCw0MLjt8H3vcfjc4sz1OK7qdIQ,679
  lalamo/model_import/decoder_configs/common.py,sha256=L8PCgF5fIt3RqPlmLiJpBzDguKk9iTjk4XSItxwVG4c,3260
  lalamo/model_import/decoder_configs/executorch.py,sha256=fTEG_j-7d8riR3Fu_H5tHDjOTrWevfyw7QbWF1mUdOQ,5924
@@ -47,20 +48,22 @@ lalamo/model_import/model_specs/lfm2.py,sha256=wg4Ggt6BbMO4ScJ6h8tjvBc3IVSrMudES
  lalamo/model_import/model_specs/llama.py,sha256=TxhKbIBFmGV2NopOg_k3ltsKlJccbxKyu-GQ7hYWCyw,3140
  lalamo/model_import/model_specs/llamba.py,sha256=Ic3sWTv34FLJ4fG6OR_Mc5goGJQR6fa5b2WbVXbn9FA,1471
  lalamo/model_import/model_specs/mirai.py,sha256=eifYVV5-fABiLH6rr82_DiVFtDyqpW0vbvXCYsQQzto,617
- lalamo/model_import/model_specs/mistral.py,sha256=HAojorjOqsJn2DoMBzYRw8A70qCslhFEsE9AF5xumlg,1278
+ lalamo/model_import/model_specs/mistral.py,sha256=i616AQg876PP9GHqHwaFIc29rUlsuhs0Z8_p0wv9eYg,1479
  lalamo/model_import/model_specs/pleias.py,sha256=5sRpZGYwLdsav6bLiW-459y1Cs9iJKgKkBIuGsOxtsQ,368
  lalamo/model_import/model_specs/polaris.py,sha256=Mw1-6bByjDmPIKlIUIV46CsmV5xUp_laI5Qquo5DmAQ,520
  lalamo/model_import/model_specs/qwen.py,sha256=HvN080ILpOwkqJbRLMqCa8Z8ImlLfTwiEIhWxUdTRfo,7563
  lalamo/model_import/model_specs/reka.py,sha256=dOUYbEMMvovQdzQuBO_DCsjGI39syhoKCvnxLkNEDCw,423
- lalamo/models/__init__.py,sha256=Vn5PcvSqKppIchkSZwQVTn_GpRvOOzZVxo5PUeDl6N8,283
+ lalamo/models/__init__.py,sha256=XMYuKSsiiIQUOq-ZtjIJcaIjTeCMYaO9bKJ9kvvLq98,394
  lalamo/models/classifier.py,sha256=LvL54crCVi4HVSIXuoaSLB_5jtcx74GL7kgdy2Y16Zc,2094
- lalamo/models/common.py,sha256=uU6eCHtIqMeC_aRGVo09NdpAtvQ6RKSbm6pumVvL8pc,2943
- lalamo/models/language_model.py,sha256=HtFS-R4Uqr7SohFstoAZFVrJI293N9cG_LVkXhZxgFI,13546
+ lalamo/models/common.py,sha256=8gMDvu0JXNejRslhzdurrPAS3ZymcR2Grq1RRpddc4M,3402
+ lalamo/models/compile_helpers.py,sha256=t_rGCznSAQC2W4ioGZUg4Oc7lpTL6VfutKtOZ06qfXo,2227
+ lalamo/models/language_model.py,sha256=YL86--CwI-T7h4ymCk3DXZ5Cswq3OCn_7wJGfYI6swk,26113
+ lalamo/models/lm_helpers.py,sha256=rocQ184MCF5gnFwLbzWR7mDrV6b-0VxvOkqbhPxsCKE,6590
  lalamo/modules/__init__.py,sha256=OHIQn08jx2c3L2KIQA-7SJ4yVb2E5m6T6FqTHFJTDdM,4006
  lalamo/modules/activations.py,sha256=25F4XytJMIwPPmUbxiDUrcrdUi4c-O9SUbwv9lnZbuU,992
  lalamo/modules/classifier.py,sha256=Q5eNzJ68to6JGk8IDZiKv6Rmwh15UyT2xC52tP5njoQ,11767
  lalamo/modules/common.py,sha256=Rc9zenrUMntDKZydI1tzt1ZIY8ggfyk3ZDB-xi81ibw,3406
- lalamo/modules/decoder.py,sha256=I30fptNifcdw9OOCU50aZnEqsJ2X4VM9YXdtRkxbqGc,7014
+ lalamo/modules/decoder.py,sha256=zC4IlSzBeEbHiAlGCl8TGCBqGLVtXb_FrJuC9cPwYqo,7103
  lalamo/modules/embedding.py,sha256=PdNy4tGt9F1zve4X73WKNS0DXL-nHUFOlZmGFUAarkQ,27727
  lalamo/modules/linear.py,sha256=4xIhmeouD7R10lt8KJBLxgypVXYhpGmXdHUc-96Upfk,42871
  lalamo/modules/mlp.py,sha256=ogxi9q8J38FnuBkAtC7_KTMc7JZG4BRdsAHYprHZNvM,17690
@@ -74,22 +77,21 @@ lalamo/modules/utils.py,sha256=t_TayWT6g5LtYKhJaod-u_COWaI_VbNd3eYek9Nj0lc,441
  lalamo/modules/token_mixers/__init__.py,sha256=lwxUl0eG5IvuVc_HOsINP2vtbv9F0cUmSNHFHaEmPGk,1109
  lalamo/modules/token_mixers/attention.py,sha256=ielw1-KWBfCPCPmzSHgM0TaSUcmSkWKTxrN3N_FsGm4,16144
  lalamo/modules/token_mixers/common.py,sha256=CcrbXXvGU27uxGLh5L-G8VDtcOiW5Wpm13uBEOd6lVg,1986
- lalamo/modules/token_mixers/mamba.py,sha256=zV5CnhEbAtJ32V32a2VZGsbjZ-sohMqRbR5kW9XH1AI,19087
+ lalamo/modules/token_mixers/mamba.py,sha256=EFyuAEjp6pNwOriesFnulOafyGRHYgqotou6UE55axE,28945
  lalamo/modules/token_mixers/short_conv.py,sha256=k1z9UwcJGag2NHWad7cYiAnhxULtmva9RrdhqVbir18,5085
  lalamo/modules/token_mixers/state/__init__.py,sha256=OKWPmiwszMWgwamewoVHd28owanHAO2j2e30Iivtv-4,384
  lalamo/modules/token_mixers/state/common.py,sha256=dcwBevAdeJpBjf7_YRk7TKrJHsCnpljhfzZy-3h9898,661
  lalamo/modules/token_mixers/state/kv_cache.py,sha256=QfnS3XgSmyDI9MBUbeLI4ABHLxiMcXDbZsqe0fd3KQo,8788
  lalamo/modules/token_mixers/state/mamba_state.py,sha256=LHzJvNE6MkB7nrsZSNto6pxbnMJCl--JOoe9Fkcc9Mg,1642
  lalamo/modules/token_mixers/state/short_conv_state.py,sha256=osjcDHoeFWQaUoOROzeJe8F1qC8rvqunimGD4CuIDHo,895
- lalamo/speculator/__init__.py,sha256=9-tmZcbCom_lIGpJYn6xLlnEahFLFidpqmgkafmu--k,456
+ lalamo/speculator/__init__.py,sha256=Ye3gMhrtNxaWPMzWbXFqKX7Rv32LGlT2k9eX2uvifKg,364
  lalamo/speculator/common.py,sha256=PudF_gkpe5_nQ-57sAC-foE1xCy_H2Axh5KwRoA86lo,587
- lalamo/speculator/estimator.py,sha256=WPG3rxKq4iLro8QwcePF766ageexHc17ANiF5rKAlKU,3833
- lalamo/speculator/inference.py,sha256=47TUiLV0Dkk3dbf1-IkdlWbHCICFw6IDwKZ73FYQUQo,4802
+ lalamo/speculator/inference.py,sha256=-RfgtdwMU4-EGnzc7oT8zJEhiA_Md03rPMyyi_BF26k,2792
  lalamo/speculator/ngram.py,sha256=2eqInIieJPaQHCvIfnCIDtwMa8PGEtiND_NkG7plE34,5899
  lalamo/speculator/utils.py,sha256=0wZoMMIzzk0Q-3zq5H5f-JBplePNHxywndkrNtOJOyo,1697
- lalamo-0.6.5.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
- lalamo-0.6.5.dist-info/METADATA,sha256=EWI8eHaPSj7tXrW7xW9BPNpeDRjboNNvtGbq3hRELzU,3112
- lalamo-0.6.5.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
- lalamo-0.6.5.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
- lalamo-0.6.5.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
- lalamo-0.6.5.dist-info/RECORD,,
+ lalamo-0.6.6.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
+ lalamo-0.6.6.dist-info/METADATA,sha256=Yc0I-RS-xekkjmN-Y5LQf-0R3gKuIeuFmZ3hPXJh4tY,3112
+ lalamo-0.6.6.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+ lalamo-0.6.6.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
+ lalamo-0.6.6.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
+ lalamo-0.6.6.dist-info/RECORD,,
@@ -1,127 +0,0 @@
- import functools
- import itertools
- import os
- from collections.abc import Callable
- from typing import NamedTuple
-
- import jax
- import jax.numpy as jnp
-
- from lalamo.models import LanguageModel
-
-
- def get_default_device_bytes() -> int | None:
-     dynamic_allocate = False
-
-     preallocate = os.getenv("XLA_PYTHON_CLIENT_PREALLOCATE", "")
-     dynamic_allocate |= preallocate.strip().lower() in {"0", "false", "no", "off"}
-
-     allocator = os.getenv("XLA_PYTHON_CLIENT_ALLOCATOR", "")
-     dynamic_allocate |= allocator.strip().lower() in {"platform", "cuda_malloc_async"}
-
-     if dynamic_allocate:
-         return None
-
-     memory_stats = jax.local_devices()[0].memory_stats()
-     if memory_stats is None or "bytes_limit" not in memory_stats:
-         return None
-
-     mem_fraction_raw = os.getenv("XLA_PYTHON_CLIENT_MEM_FRACTION", "")
-     try:
-         mem_fraction = float(mem_fraction_raw)
-     except ValueError:
-         mem_fraction = 0.75  # jax default https://docs.jax.dev/en/latest/gpu_memory_allocation.html
-
-     # 500mb is seemingly the usually observed overhead; this tries to match the actual capacity of the gpu
-     # so it should correspond to something you'd see in nvidia-smi
-     memory_limit = memory_stats["bytes_limit"] / min(mem_fraction, 1.0) + (500 * 1000 * 1000)
-
-     return get_usable_memory_from_bytes(memory_limit)
-
-
- def get_usable_memory_from_bytes(limit_bytes: int) -> int:
-     # JAX allocates a bit more than it needs, so we discount it by some safety factor
-     return int(limit_bytes * 0.93)
-
-
- def estimate_memory_from_batchsize(
-     model: LanguageModel,
-     max_input_length: int,
-     max_output_length: int,
-     num_logits_per_token: int,
-     batch_size: int,
- ) -> int:
-     memory_analysis = (
-         jax.jit(
-             functools.partial(
-                 LanguageModel.generate_tokens,
-                 max_output_length=max_output_length,
-                 num_top_logits_to_return=num_logits_per_token,
-             ),
-         )
-         .lower(
-             model,
-             prompt_token_ids=jax.ShapeDtypeStruct((batch_size, max_input_length), jnp.int32),
-             prompt_lengths_without_padding=jax.ShapeDtypeStruct((batch_size,), jnp.int32),
-         )
-         # disables autotune, see https://guides.lw1.at/all-xla-options/#--xla_gpu_autotune_level
-         .compile(compiler_options={"xla_gpu_autotune_level": "0"})
-         .memory_analysis()
-     )
-
-     assert hasattr(memory_analysis, "argument_size_in_bytes")
-     assert hasattr(memory_analysis, "output_size_in_bytes")
-     assert hasattr(memory_analysis, "temp_size_in_bytes")
-
-     return (
-         memory_analysis.argument_size_in_bytes
-         + memory_analysis.output_size_in_bytes
-         + memory_analysis.temp_size_in_bytes
-     )
-
-
- class EstimateBatchsizeFromMemoryEvent(NamedTuple):
-     lo: int
-     hi: int | None
-
-
- def estimate_batchsize_from_memory(
-     model: LanguageModel,
-     max_input_length: int,
-     max_output_length: int,
-     num_logits_per_token: int,
-     target_mem: int,
-     progress: Callable[[EstimateBatchsizeFromMemoryEvent], None] | None = None,
- ) -> int:
-     mem_for_bs = functools.cache(
-         functools.partial(
-             estimate_memory_from_batchsize,
-             model,
-             max_input_length,
-             max_output_length,
-             num_logits_per_token,
-         ),
-     )
-
-     lo = 0
-     hi = 0
-     for candidate_exp in itertools.count():
-         lo = hi
-         hi = 2**candidate_exp
-
-         if progress is not None:
-             progress(EstimateBatchsizeFromMemoryEvent(lo, None))
-         if target_mem < mem_for_bs(hi):
-             break
-
-     while hi - lo > 1:
-         mid = (lo + hi) // 2
-
-         if progress is not None:
-             progress(EstimateBatchsizeFromMemoryEvent(lo, hi))
-         if target_mem < mem_for_bs(mid):
-             hi = mid
-         else:
-             lo = mid
-
-     return lo
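The deleted `estimator.py` implemented batch-size autotuning in two steps: estimate device memory for a candidate batch size ahead of time via `jax.jit(...).lower(...).compile(...).memory_analysis()`, then grow the candidate by powers of two until the estimate exceeds the budget and bisect between the last size that fit and the first that overflowed. A self-contained sketch of that doubling-then-bisection search, with an illustrative stand-in `mem_for_bs` replacing the removed `estimate_memory_from_batchsize`:

```python
import itertools
from collections.abc import Callable


def largest_batch_under_budget(mem_for_bs: Callable[[int], int], target_mem: int) -> int:
    # Phase 1: exponential growth. Probe batch sizes 1, 2, 4, 8, ... until
    # the (assumed monotone) memory estimate first exceeds the budget.
    lo, hi = 0, 0
    for exp in itertools.count():
        lo, hi = hi, 2**exp
        if target_mem < mem_for_bs(hi):
            break  # hi overflows the budget; lo was the last size that fit

    # Phase 2: binary search in (lo, hi) for the largest size that still fits.
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if target_mem < mem_for_bs(mid):
            hi = mid
        else:
            lo = mid

    return lo


# Toy cost model: 1 MB per batch element against a 10 MB budget.
assert largest_batch_under_budget(lambda bs: bs * 1_000_000, 10_000_000) == 10
```

The doubling phase probes only O(log n) candidates, and because the original wrapped `mem_for_bs` in `functools.cache`, each batch size was lowered, compiled, and measured at most once.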