lalamo 0.6.4__py3-none-any.whl → 0.6.6__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
- lalamo/__init__.py +1 -1
- lalamo/commands.py +247 -14
- lalamo/common.py +33 -0
- lalamo/data/__init__.py +3 -2
- lalamo/data/huggingface_message.py +4 -5
- lalamo/main.py +274 -9
- lalamo/message_processor.py +19 -1
- lalamo/model_import/common.py +17 -1
- lalamo/model_import/model_specs/mistral.py +5 -0
- lalamo/model_import/remote_registry.py +44 -0
- lalamo/models/__init__.py +3 -0
- lalamo/models/common.py +22 -0
- lalamo/models/compile_helpers.py +58 -0
- lalamo/models/language_model.py +342 -56
- lalamo/models/lm_helpers.py +198 -0
- lalamo/modules/decoder.py +4 -0
- lalamo/modules/token_mixers/mamba.py +345 -105
- lalamo/speculator/__init__.py +0 -2
- lalamo/speculator/inference.py +35 -61
- {lalamo-0.6.4.dist-info → lalamo-0.6.6.dist-info}/METADATA +1 -1
- {lalamo-0.6.4.dist-info → lalamo-0.6.6.dist-info}/RECORD +25 -23
- lalamo/speculator/estimator.py +0 -127
- {lalamo-0.6.4.dist-info → lalamo-0.6.6.dist-info}/WHEEL +0 -0
- {lalamo-0.6.4.dist-info → lalamo-0.6.6.dist-info}/entry_points.txt +0 -0
- {lalamo-0.6.4.dist-info → lalamo-0.6.6.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.6.4.dist-info → lalamo-0.6.6.dist-info}/top_level.txt +0 -0
lalamo/speculator/inference.py
CHANGED
@@ -1,15 +1,12 @@
-import functools
 from collections.abc import Callable, Iterable
-from itertools import
+from itertools import chain
 from typing import NamedTuple
 
-import jax
-import jax.numpy as jnp
-
 from lalamo.data.lalamo_completions import LalamoCompletion
 from lalamo.data.utils import get_prefixes_ending_in_user_message
 from lalamo.message_processor import Message
 from lalamo.models import LanguageModel
+from lalamo.models.common import InferenceConfig
 
 
 class CollectTracesEvent(NamedTuple):
@@ -27,75 +24,52 @@ def inference_collect_traces(
     tokens_to_generate: int | None = None,
     progress_callback: Callable[[CollectTracesEvent], None] | None = None,
 ) -> Iterable[LalamoCompletion]:
-    generate_tokens_compiled = (
-        jax.jit(
-            functools.partial(
-                LanguageModel.generate_tokens,
-                max_output_length=max_output_length,
-                num_top_logits_to_return=num_top_logits_to_collect,
-            ),
-        )
-        .lower(
-            model,
-            prompt_token_ids=jax.ShapeDtypeStruct((batch_size, max_input_length), jnp.int32),
-            prompt_lengths_without_padding=jax.ShapeDtypeStruct((batch_size,), jnp.int32),
-        )
-        # the autotune levels are (according to https://guides.lw1.at/all-xla-options/#--xla_gpu_autotune_level)
-        # 0 - no autotune, gpu shouldn't be touched
-        # 1 - basic level, gpu should be touched veeery little
-        # 2,3 - gpu touched more and more
-        # 4 (default) - gpu might allocate more memory than the run would require!
-        .compile(compiler_options={"xla_gpu_autotune_level": "2"})
-    )
-
     prefixes = chain.from_iterable(map(get_prefixes_ending_in_user_message, conversations))
-
     tokenized_prefixes = map(model.message_processor.tokenize_request, prefixes)
     filtered_prefixes = filter(lambda conv: len(conv) <= max_input_length, tokenized_prefixes)
+    filtered_prefixes = list(filtered_prefixes)  # eagerly materialize the prompts into RAM
 
-
-
-
-
-
-
-        length_without_padding = jnp.array(list(map(len, batch)))
-
-        padded = jnp.array(
-            [jnp.pad(jnp.array(tokens), (0, max_input_length - len(tokens)), constant_values=0) for tokens in batch],
-        )
+    config = InferenceConfig(
+        max_output_length=max_output_length,
+        num_top_logits_to_return=num_top_logits_to_collect,
+        padded_length=max_input_length,
+        batch_size=batch_size,
+    )
 
-
-
-
-
+    tokens_generated = 0
+
+    for idx, generated in enumerate(
+        model.generate_tokens_many(
+            filtered_prefixes,
+            inference_config=config,
+        ),
+    ):
+        token_ids = generated.token_ids.tolist()
+        seqlen = next(
+            (i + 1 for i, t in enumerate(token_ids) if t in model.stop_token_ids),
+            len(token_ids),
         )
 
-
+        if tokens_to_generate is not None:
+            seqlen = min(seqlen, tokens_to_generate - tokens_generated)
 
-
-        token_ids = generated.token_ids[conv_idx].tolist()
-        seqlen = next((i + 1 for i, t in enumerate(token_ids) if t in model.stop_token_ids), len(token_ids))
-        if tokens_to_generate is not None:
-            seqlen = min(seqlen, tokens_to_generate - tokens_generated)
-        tokens_generated += seqlen
-        sequences_processed += 1
+        tokens_generated += seqlen
 
-
-        token_logits_ids = generated.top_k_token_ids[conv_idx, : len(token_ids)].tolist()
-        token_logits_values = generated.top_k_token_logits[conv_idx, : len(token_ids)].tolist()
-        token_logits = [
-            dict(zip(keys, values, strict=True))
-            for keys, values in zip(token_logits_ids, token_logits_values, strict=True)
-        ]
+        token_ids = token_ids[:seqlen]
 
-
+        assert generated.top_k_token_ids is not None and generated.top_k_token_logits is not None
+        token_logits_ids = generated.top_k_token_ids[:seqlen].tolist()
+        token_logits_values = generated.top_k_token_logits[:seqlen].tolist()
+        token_logits = [
+            dict(zip(keys, values, strict=True))
+            for keys, values in zip(token_logits_ids, token_logits_values, strict=True)
+        ]
 
-
-
+        # We need the original prompt tokens - get from indexed_inputs
+        yield LalamoCompletion(filtered_prefixes[idx], token_ids, token_logits)
 
         if progress_callback is not None:
-            progress_callback(CollectTracesEvent(
+            progress_callback(CollectTracesEvent(idx + 1, tokens_generated))
 
         if tokens_to_generate is not None and tokens_generated >= tokens_to_generate:
            break
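The refactor above replaces the hand-rolled jax.jit(...).lower(...).compile(...) pipeline with a batched generation API on LanguageModel. A minimal sketch of the new call pattern, using only what the diff shows (the InferenceConfig fields and the generate_tokens_many call come from the changed lines; the model value, prompt data, and all numeric values are hypothetical placeholders):

from lalamo.models import LanguageModel
from lalamo.models.common import InferenceConfig

model: LanguageModel = ...  # placeholder: a loaded lalamo model

# Field names as they appear in the diff above; values are illustrative.
config = InferenceConfig(
    max_output_length=128,       # tokens to generate per prompt
    num_top_logits_to_return=8,  # top-k logits collected per generated token
    padded_length=1024,          # prompts are padded to this length
    batch_size=4,                # prompts processed per compiled step
)

# Token-id sequences, e.g. from model.message_processor.tokenize_request(...).
prompts: list[list[int]] = ...

# generate_tokens_many yields one result per input prompt; each result
# carries token_ids plus optional top_k_token_ids / top_k_token_logits,
# which the caller trims at the first stop token (see the loop above).
for generated in model.generate_tokens_many(prompts, inference_config=config):
    print(generated.token_ids.tolist())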
{lalamo-0.6.4.dist-info → lalamo-0.6.6.dist-info}/RECORD
CHANGED
@@ -1,21 +1,22 @@
-lalamo/__init__.py,sha256=
-lalamo/commands.py,sha256=
-lalamo/common.py,sha256=
-lalamo/main.py,sha256=
-lalamo/message_processor.py,sha256=
+lalamo/__init__.py,sha256=FFdoG3pwkVS9Xi1X_aCiByG9QyzSP_NufH2lIhj60EY,1532
+lalamo/commands.py,sha256=3zUZE2bg39XFFGh5PQ-L2oIs73GPsp0EkFinMZQNHro,18097
+lalamo/common.py,sha256=odAWGfzRRfsKJPtVekOLTlN3SrHcMerPzMk4BfVjn9I,5262
+lalamo/main.py,sha256=V6VwWlo7QO5RH5HpsbpxQgDeOGFv2Y4abhJOjHFw2MA,37153
+lalamo/message_processor.py,sha256=gf-CiidoRp1XmLdy8jkv06Gg0Nqe_DAlpYObOF9JfpA,6490
 lalamo/quantization.py,sha256=8o6ryIZLzzDYQuvBTboPfaVVdfijAKGpTxOcg3GKVD8,2752
 lalamo/registry_abc.py,sha256=qTikqviqqeseNzkjqoyQvL4dEWJYWzN0rI05T-JNTmo,2187
 lalamo/safetensors.py,sha256=kUiTSgx2zhfD1hxV_AA1DOLaKAKzjRd_vOYZCFf0em0,3048
 lalamo/sampling.py,sha256=GE6Av7zS-pr5Bg7FtOivRce7I0JIYuNYqfqsRe-yjQk,3867
 lalamo/utils.py,sha256=c88IP110gHZJ6hYDq7p36A9u-vLRM_YdavFom56gsNQ,4111
-lalamo/data/__init__.py,sha256=
-lalamo/data/huggingface_message.py,sha256
+lalamo/data/__init__.py,sha256=QH23n37CWLcY69oLVE8gNNr6aJ57G0D_ZsO8eYs7-Jk,225
+lalamo/data/huggingface_message.py,sha256=8oTCxL_IOHRCVgQneRv52sgJNprsiAIPvps5nBu6LWo,1037
 lalamo/data/lalamo_completions.py,sha256=U_m3UNSJASUFz3rJq_taZOtL_U4B8Oj-ndkTF-JH-v4,1509
 lalamo/data/utils.py,sha256=B96gLaULyStKYuR8wjFdTpFc6YIDC8EEvGh1eiMe_Ec,338
 lalamo/model_import/__init__.py,sha256=Z8pS9rbKKx1QgUy7KZtHxiNWlZhII3mdovT9d37vAxg,168
-lalamo/model_import/common.py,sha256=
+lalamo/model_import/common.py,sha256=evZeeizFev2i7Whd9X7sgUQV8v5apjfmy_BhshnFbyo,13011
 lalamo/model_import/huggingface_generation_config.py,sha256=xicv_kJOfIGlz4gi5fRFIkiAZ9_QRDLRtW8nKMm5tVU,2022
 lalamo/model_import/huggingface_tokenizer_config.py,sha256=xvwdmio7b9nhn2H3uMBVligiYj58JaCFCvHY3-8dBvM,2502
+lalamo/model_import/remote_registry.py,sha256=4VjZSwlYMqflMfSSPi7-GSb9tmTLMZELzXoJJ3Tsx5s,1045
 lalamo/model_import/decoder_configs/__init__.py,sha256=YvlSsJqNEQPCNKcUzCw0MLjt8H3vcfjc4sz1OK7qdIQ,679
 lalamo/model_import/decoder_configs/common.py,sha256=L8PCgF5fIt3RqPlmLiJpBzDguKk9iTjk4XSItxwVG4c,3260
 lalamo/model_import/decoder_configs/executorch.py,sha256=fTEG_j-7d8riR3Fu_H5tHDjOTrWevfyw7QbWF1mUdOQ,5924
@@ -47,20 +48,22 @@ lalamo/model_import/model_specs/lfm2.py,sha256=wg4Ggt6BbMO4ScJ6h8tjvBc3IVSrMudES
 lalamo/model_import/model_specs/llama.py,sha256=TxhKbIBFmGV2NopOg_k3ltsKlJccbxKyu-GQ7hYWCyw,3140
 lalamo/model_import/model_specs/llamba.py,sha256=Ic3sWTv34FLJ4fG6OR_Mc5goGJQR6fa5b2WbVXbn9FA,1471
 lalamo/model_import/model_specs/mirai.py,sha256=eifYVV5-fABiLH6rr82_DiVFtDyqpW0vbvXCYsQQzto,617
-lalamo/model_import/model_specs/mistral.py,sha256=
+lalamo/model_import/model_specs/mistral.py,sha256=i616AQg876PP9GHqHwaFIc29rUlsuhs0Z8_p0wv9eYg,1479
 lalamo/model_import/model_specs/pleias.py,sha256=5sRpZGYwLdsav6bLiW-459y1Cs9iJKgKkBIuGsOxtsQ,368
 lalamo/model_import/model_specs/polaris.py,sha256=Mw1-6bByjDmPIKlIUIV46CsmV5xUp_laI5Qquo5DmAQ,520
 lalamo/model_import/model_specs/qwen.py,sha256=HvN080ILpOwkqJbRLMqCa8Z8ImlLfTwiEIhWxUdTRfo,7563
 lalamo/model_import/model_specs/reka.py,sha256=dOUYbEMMvovQdzQuBO_DCsjGI39syhoKCvnxLkNEDCw,423
-lalamo/models/__init__.py,sha256=
+lalamo/models/__init__.py,sha256=XMYuKSsiiIQUOq-ZtjIJcaIjTeCMYaO9bKJ9kvvLq98,394
 lalamo/models/classifier.py,sha256=LvL54crCVi4HVSIXuoaSLB_5jtcx74GL7kgdy2Y16Zc,2094
-lalamo/models/common.py,sha256=
-lalamo/models/
+lalamo/models/common.py,sha256=8gMDvu0JXNejRslhzdurrPAS3ZymcR2Grq1RRpddc4M,3402
+lalamo/models/compile_helpers.py,sha256=t_rGCznSAQC2W4ioGZUg4Oc7lpTL6VfutKtOZ06qfXo,2227
+lalamo/models/language_model.py,sha256=YL86--CwI-T7h4ymCk3DXZ5Cswq3OCn_7wJGfYI6swk,26113
+lalamo/models/lm_helpers.py,sha256=rocQ184MCF5gnFwLbzWR7mDrV6b-0VxvOkqbhPxsCKE,6590
 lalamo/modules/__init__.py,sha256=OHIQn08jx2c3L2KIQA-7SJ4yVb2E5m6T6FqTHFJTDdM,4006
 lalamo/modules/activations.py,sha256=25F4XytJMIwPPmUbxiDUrcrdUi4c-O9SUbwv9lnZbuU,992
 lalamo/modules/classifier.py,sha256=Q5eNzJ68to6JGk8IDZiKv6Rmwh15UyT2xC52tP5njoQ,11767
 lalamo/modules/common.py,sha256=Rc9zenrUMntDKZydI1tzt1ZIY8ggfyk3ZDB-xi81ibw,3406
-lalamo/modules/decoder.py,sha256=
+lalamo/modules/decoder.py,sha256=zC4IlSzBeEbHiAlGCl8TGCBqGLVtXb_FrJuC9cPwYqo,7103
 lalamo/modules/embedding.py,sha256=PdNy4tGt9F1zve4X73WKNS0DXL-nHUFOlZmGFUAarkQ,27727
 lalamo/modules/linear.py,sha256=4xIhmeouD7R10lt8KJBLxgypVXYhpGmXdHUc-96Upfk,42871
 lalamo/modules/mlp.py,sha256=ogxi9q8J38FnuBkAtC7_KTMc7JZG4BRdsAHYprHZNvM,17690
@@ -74,22 +77,21 @@ lalamo/modules/utils.py,sha256=t_TayWT6g5LtYKhJaod-u_COWaI_VbNd3eYek9Nj0lc,441
 lalamo/modules/token_mixers/__init__.py,sha256=lwxUl0eG5IvuVc_HOsINP2vtbv9F0cUmSNHFHaEmPGk,1109
 lalamo/modules/token_mixers/attention.py,sha256=ielw1-KWBfCPCPmzSHgM0TaSUcmSkWKTxrN3N_FsGm4,16144
 lalamo/modules/token_mixers/common.py,sha256=CcrbXXvGU27uxGLh5L-G8VDtcOiW5Wpm13uBEOd6lVg,1986
-lalamo/modules/token_mixers/mamba.py,sha256=
+lalamo/modules/token_mixers/mamba.py,sha256=EFyuAEjp6pNwOriesFnulOafyGRHYgqotou6UE55axE,28945
 lalamo/modules/token_mixers/short_conv.py,sha256=k1z9UwcJGag2NHWad7cYiAnhxULtmva9RrdhqVbir18,5085
 lalamo/modules/token_mixers/state/__init__.py,sha256=OKWPmiwszMWgwamewoVHd28owanHAO2j2e30Iivtv-4,384
 lalamo/modules/token_mixers/state/common.py,sha256=dcwBevAdeJpBjf7_YRk7TKrJHsCnpljhfzZy-3h9898,661
 lalamo/modules/token_mixers/state/kv_cache.py,sha256=QfnS3XgSmyDI9MBUbeLI4ABHLxiMcXDbZsqe0fd3KQo,8788
 lalamo/modules/token_mixers/state/mamba_state.py,sha256=LHzJvNE6MkB7nrsZSNto6pxbnMJCl--JOoe9Fkcc9Mg,1642
 lalamo/modules/token_mixers/state/short_conv_state.py,sha256=osjcDHoeFWQaUoOROzeJe8F1qC8rvqunimGD4CuIDHo,895
-lalamo/speculator/__init__.py,sha256=
+lalamo/speculator/__init__.py,sha256=Ye3gMhrtNxaWPMzWbXFqKX7Rv32LGlT2k9eX2uvifKg,364
 lalamo/speculator/common.py,sha256=PudF_gkpe5_nQ-57sAC-foE1xCy_H2Axh5KwRoA86lo,587
-lalamo/speculator/
-lalamo/speculator/inference.py,sha256=uEv33Qqcpa2xqEKdIzmPzkAzRsZOlb8TPeEG6TP6fjo,4071
+lalamo/speculator/inference.py,sha256=-RfgtdwMU4-EGnzc7oT8zJEhiA_Md03rPMyyi_BF26k,2792
 lalamo/speculator/ngram.py,sha256=2eqInIieJPaQHCvIfnCIDtwMa8PGEtiND_NkG7plE34,5899
 lalamo/speculator/utils.py,sha256=0wZoMMIzzk0Q-3zq5H5f-JBplePNHxywndkrNtOJOyo,1697
-lalamo-0.6.
-lalamo-0.6.
-lalamo-0.6.
-lalamo-0.6.
-lalamo-0.6.
-lalamo-0.6.
+lalamo-0.6.6.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
+lalamo-0.6.6.dist-info/METADATA,sha256=Yc0I-RS-xekkjmN-Y5LQf-0R3gKuIeuFmZ3hPXJh4tY,3112
+lalamo-0.6.6.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+lalamo-0.6.6.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
+lalamo-0.6.6.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
+lalamo-0.6.6.dist-info/RECORD,,
lalamo/speculator/estimator.py
DELETED
@@ -1,127 +0,0 @@
-import functools
-import itertools
-import os
-from collections.abc import Callable
-from typing import NamedTuple
-
-import jax
-import jax.numpy as jnp
-
-from lalamo.models import LanguageModel
-
-
-def get_default_device_bytes() -> int | None:
-    dynamic_allocate = False
-
-    preallocate = os.getenv("XLA_PYTHON_CLIENT_PREALLOCATE", "")
-    dynamic_allocate |= preallocate.strip().lower() in {"0", "false", "no", "off"}
-
-    allocator = os.getenv("XLA_PYTHON_CLIENT_ALLOCATOR", "")
-    dynamic_allocate |= allocator.strip().lower() in {"platform", "cuda_malloc_async"}
-
-    if dynamic_allocate:
-        return None
-
-    memory_stats = jax.local_devices()[0].memory_stats()
-    if memory_stats is None or "bytes_limit" not in memory_stats:
-        return None
-
-    mem_fraction_raw = os.getenv("XLA_PYTHON_CLIENT_MEM_FRACTION", "")
-    try:
-        mem_fraction = float(mem_fraction_raw)
-    except ValueError:
-        mem_fraction = 0.75  # jax default https://docs.jax.dev/en/latest/gpu_memory_allocation.html
-
-    # 500mb is seemingly the usually observed overhead; this tries to match the actual capacity of the gpu
-    # so it should correspond to something you'd see in nvidia-smi
-    memory_limit = memory_stats["bytes_limit"] / min(mem_fraction, 1.0) + (500 * 1000 * 1000)
-
-    return get_usable_memory_from_bytes(memory_limit)
-
-
-def get_usable_memory_from_bytes(limit_bytes: int) -> int:
-    # JAX allocates a bit more than it needs, so we discount it by some safety factor
-    return int(limit_bytes * 0.95)
-
-
-def estimate_memory_from_batchsize(
-    model: LanguageModel,
-    max_input_length: int,
-    max_output_length: int,
-    num_logits_per_token: int,
-    batch_size: int,
-) -> int:
-    memory_analysis = (
-        jax.jit(
-            functools.partial(
-                LanguageModel.generate_tokens,
-                max_output_length=max_output_length,
-                num_top_logits_to_return=num_logits_per_token,
-            ),
-        )
-        .lower(
-            model,
-            prompt_token_ids=jax.ShapeDtypeStruct((batch_size, max_input_length), jnp.int32),
-            prompt_lengths_without_padding=jax.ShapeDtypeStruct((batch_size,), jnp.int32),
-        )
-        # disables autotune, see https://guides.lw1.at/all-xla-options/#--xla_gpu_autotune_level
-        .compile(compiler_options={"xla_gpu_autotune_level": "0"})
-        .memory_analysis()
-    )
-
-    assert hasattr(memory_analysis, "argument_size_in_bytes")
-    assert hasattr(memory_analysis, "output_size_in_bytes")
-    assert hasattr(memory_analysis, "temp_size_in_bytes")
-
-    return (
-        memory_analysis.argument_size_in_bytes
-        + memory_analysis.output_size_in_bytes
-        + memory_analysis.temp_size_in_bytes
-    )
-
-
-class EstimateBatchsizeFromMemoryEvent(NamedTuple):
-    lo: int
-    hi: int | None
-
-
-def estimate_batchsize_from_memory(
-    model: LanguageModel,
-    max_input_length: int,
-    max_output_length: int,
-    num_logits_per_token: int,
-    target_mem: int,
-    progress: Callable[[EstimateBatchsizeFromMemoryEvent], None] | None = None,
-) -> int:
-    mem_for_bs = functools.cache(
-        functools.partial(
-            estimate_memory_from_batchsize,
-            model,
-            max_input_length,
-            max_output_length,
-            num_logits_per_token,
-        ),
-    )
-
-    lo = 0
-    hi = 0
-    for candidate_exp in itertools.count():
-        lo = hi
-        hi = 2**candidate_exp
-
-        if progress is not None:
-            progress(EstimateBatchsizeFromMemoryEvent(lo, None))
-        if target_mem < mem_for_bs(hi):
-            break
-
-    while hi - lo > 1:
-        mid = (lo + hi) // 2
-
-        if progress is not None:
-            progress(EstimateBatchsizeFromMemoryEvent(lo, hi))
-        if target_mem < mem_for_bs(mid):
-            hi = mid
-        else:
-            lo = mid
-
-    return lo
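The estimator deleted above sized batches by exponential search followed by bisection over a monotone memory-cost function (estimate_memory_from_batchsize, backed by XLA's compiled-executable memory analysis). A self-contained sketch of that search pattern, with a generic cost_fn standing in for the memory estimate:

import itertools
from collections.abc import Callable


def largest_fitting(cost_fn: Callable[[int], int], budget: int) -> int:
    # Largest n with cost_fn(n) <= budget, assuming cost_fn is monotone.
    # Mirrors the deleted estimate_batchsize_from_memory: grow the upper
    # bound in powers of two until it overflows the budget, then bisect
    # between the last fitting value and the first overflowing one.
    lo, hi = 0, 0
    for exp in itertools.count():
        lo, hi = hi, 2**exp
        if budget < cost_fn(hi):
            break
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if budget < cost_fn(mid):
            hi = mid
        else:
            lo = mid
    return lo


# Worked example: at 100 bytes per batch element and a 1050-byte budget,
# the largest batch size that fits is 10.
assert largest_fitting(lambda n: 100 * n, 1050) == 10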