lalamo-0.6.5-py3-none-any.whl → lalamo-0.6.6-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- lalamo/__init__.py +1 -1
- lalamo/commands.py +247 -14
- lalamo/common.py +27 -48
- lalamo/data/__init__.py +3 -2
- lalamo/data/huggingface_message.py +4 -5
- lalamo/main.py +274 -9
- lalamo/message_processor.py +19 -1
- lalamo/model_import/common.py +17 -1
- lalamo/model_import/model_specs/mistral.py +5 -0
- lalamo/model_import/remote_registry.py +44 -0
- lalamo/models/__init__.py +3 -0
- lalamo/models/common.py +22 -0
- lalamo/models/compile_helpers.py +58 -0
- lalamo/models/language_model.py +342 -56
- lalamo/models/lm_helpers.py +198 -0
- lalamo/modules/decoder.py +4 -0
- lalamo/modules/token_mixers/mamba.py +345 -105
- lalamo/speculator/__init__.py +0 -2
- lalamo/speculator/inference.py +41 -78
- {lalamo-0.6.5.dist-info → lalamo-0.6.6.dist-info}/METADATA +1 -1
- {lalamo-0.6.5.dist-info → lalamo-0.6.6.dist-info}/RECORD +25 -23
- lalamo/speculator/estimator.py +0 -127
- {lalamo-0.6.5.dist-info → lalamo-0.6.6.dist-info}/WHEEL +0 -0
- {lalamo-0.6.5.dist-info → lalamo-0.6.6.dist-info}/entry_points.txt +0 -0
- {lalamo-0.6.5.dist-info → lalamo-0.6.6.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.6.5.dist-info → lalamo-0.6.6.dist-info}/top_level.txt +0 -0
lalamo/speculator/inference.py
CHANGED
```diff
@@ -1,17 +1,12 @@
-import functools
 from collections.abc import Callable, Iterable
-from itertools import batched, chain
+from itertools import chain
 from typing import NamedTuple
 
-import jax
-import jax.numpy as jnp
-from jax._src.stages import Compiled
-
-from lalamo.common import decrease_batchsize_on_oom
 from lalamo.data.lalamo_completions import LalamoCompletion
 from lalamo.data.utils import get_prefixes_ending_in_user_message
 from lalamo.message_processor import Message
 from lalamo.models import LanguageModel
+from lalamo.models.common import InferenceConfig
 
 
 class CollectTracesEvent(NamedTuple):
@@ -29,84 +24,52 @@ def inference_collect_traces(
     tokens_to_generate: int | None = None,
     progress_callback: Callable[[CollectTracesEvent], None] | None = None,
 ) -> Iterable[LalamoCompletion]:
-    def make_generate_tokens_compiled(batch_size: int) -> Compiled:
-        return (
-            jax.jit(
-                functools.partial(
-                    LanguageModel.generate_tokens,
-                    max_output_length=max_output_length,
-                    num_top_logits_to_return=num_top_logits_to_collect,
-                ),
-            )
-            .lower(
-                model,
-                prompt_token_ids=jax.ShapeDtypeStruct((batch_size, max_input_length), jnp.int32),
-                prompt_lengths_without_padding=jax.ShapeDtypeStruct((batch_size,), jnp.int32),
-            )
-            # the autotune levels are (according to https://guides.lw1.at/all-xla-options/#--xla_gpu_autotune_level)
-            # 0 - no autotune, gpu shouldn't be touched
-            # 1 - basic level, gpu should be touched veeery little
-            # 2,3 - gpu touched more and more
-            # 4 (default) - gpu might allocate more memory than the run would require!
-            .compile(compiler_options={"xla_gpu_autotune_level": "0"})
-        )
-
     prefixes = chain.from_iterable(map(get_prefixes_ending_in_user_message, conversations))
-
     tokenized_prefixes = map(model.message_processor.tokenize_request, prefixes)
     filtered_prefixes = filter(lambda conv: len(conv) <= max_input_length, tokenized_prefixes)
+    filtered_prefixes = list(filtered_prefixes)  # eagerly materialize the prompts into RAM
+
+    config = InferenceConfig(
+        max_output_length=max_output_length,
+        num_top_logits_to_return=num_top_logits_to_collect,
+        padded_length=max_input_length,
+        batch_size=batch_size,
+    )
+
+    tokens_generated = 0
+
+    for idx, generated in enumerate(
+        model.generate_tokens_many(
+            filtered_prefixes,
+            inference_config=config,
+        ),
+    ):
+        token_ids = generated.token_ids.tolist()
+        seqlen = next(
+            (i + 1 for i, t in enumerate(token_ids) if t in model.stop_token_ids),
+            len(token_ids),
+        )
 
-
-
-    def collect_traces_body(batch_size: int) -> Iterable[LalamoCompletion]:
-        tokens_generated, sequences_processed = 0, 0
-        generate_tokens_compiled = make_generate_tokens_compiled(batch_size)
-        for real_batch in batched(chain(test_batch, filtered_prefixes), n=batch_size):
-            batch_padding = batch_size - len(real_batch)
-            batch = (*real_batch, *(([0],) * batch_padding))
-
-            length_without_padding = jnp.array(list(map(len, batch)))
-
-            padded = jnp.array(
-                [
-                    jnp.pad(jnp.array(tokens), (0, max_input_length - len(tokens)), constant_values=0)
-                    for tokens in batch
-                ],
-            )
-
-            generated = generate_tokens_compiled(
-                model,
-                prompt_token_ids=padded,
-                prompt_lengths_without_padding=length_without_padding,
-            )
-
-            assert generated.top_k_token_ids is not None and generated.top_k_token_logits is not None
-
-            for conv_idx in range(len(real_batch)):
-                token_ids = generated.token_ids[conv_idx].tolist()
-                seqlen = next((i + 1 for i, t in enumerate(token_ids) if t in model.stop_token_ids), len(token_ids))
-                if tokens_to_generate is not None:
-                    seqlen = min(seqlen, tokens_to_generate - tokens_generated)
-                tokens_generated += seqlen
-                sequences_processed += 1
+        if tokens_to_generate is not None:
+            seqlen = min(seqlen, tokens_to_generate - tokens_generated)
 
-
-                token_logits_ids = generated.top_k_token_ids[conv_idx, : len(token_ids)].tolist()
-                token_logits_values = generated.top_k_token_logits[conv_idx, : len(token_ids)].tolist()
-                token_logits = [
-                    dict(zip(keys, values, strict=True))
-                    for keys, values in zip(token_logits_ids, token_logits_values, strict=True)
-                ]
+        tokens_generated += seqlen
 
-
+        token_ids = token_ids[:seqlen]
 
-
-
+        assert generated.top_k_token_ids is not None and generated.top_k_token_logits is not None
+        token_logits_ids = generated.top_k_token_ids[:seqlen].tolist()
+        token_logits_values = generated.top_k_token_logits[:seqlen].tolist()
+        token_logits = [
+            dict(zip(keys, values, strict=True))
+            for keys, values in zip(token_logits_ids, token_logits_values, strict=True)
+        ]
 
-
-
+        # We need the original prompt tokens - get from indexed_inputs
+        yield LalamoCompletion(filtered_prefixes[idx], token_ids, token_logits)
 
-
-
+        if progress_callback is not None:
+            progress_callback(CollectTracesEvent(idx + 1, tokens_generated))
 
-
+        if tokens_to_generate is not None and tokens_generated >= tokens_to_generate:
+            break
```
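The refactor above replaces a hand-rolled batching loop built on JAX's ahead-of-time (AOT) compilation API with a single `InferenceConfig` handed to `model.generate_tokens_many`, and allows an early `break` once `tokens_to_generate` is reached. For reference, here is a minimal self-contained sketch of the `jit → lower → compile` pattern the removed helper used; the toy function and shapes are illustrative, not lalamo's:

```python
import jax
import jax.numpy as jnp


def double(x: jnp.ndarray) -> jnp.ndarray:
    return x * 2


# Describe inputs by shape/dtype only, lower for exactly that signature,
# then compile once; the resulting Compiled object is reusable for any
# concrete array of that shape and dtype.
spec = jax.ShapeDtypeStruct((8, 128), jnp.int32)
compiled = jax.jit(double).lower(spec).compile()
# On a GPU backend, .compile() also accepts XLA flags, e.g.
# compiler_options={"xla_gpu_autotune_level": "0"} as the removed code
# passed, trading autotuned kernels for lower transient memory use.

print(compiled(jnp.ones((8, 128), jnp.int32)).shape)  # (8, 128)
```

Pinning the batch shape is why the old code had to pad every batch up to `batch_size` and `max_input_length`; presumably `generate_tokens_many` now owns that padding internally.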
{lalamo-0.6.5.dist-info → lalamo-0.6.6.dist-info}/RECORD
CHANGED

```diff
@@ -1,21 +1,22 @@
-lalamo/__init__.py,sha256=
-lalamo/commands.py,sha256=
-lalamo/common.py,sha256=
-lalamo/main.py,sha256=
-lalamo/message_processor.py,sha256=
+lalamo/__init__.py,sha256=FFdoG3pwkVS9Xi1X_aCiByG9QyzSP_NufH2lIhj60EY,1532
+lalamo/commands.py,sha256=3zUZE2bg39XFFGh5PQ-L2oIs73GPsp0EkFinMZQNHro,18097
+lalamo/common.py,sha256=odAWGfzRRfsKJPtVekOLTlN3SrHcMerPzMk4BfVjn9I,5262
+lalamo/main.py,sha256=V6VwWlo7QO5RH5HpsbpxQgDeOGFv2Y4abhJOjHFw2MA,37153
+lalamo/message_processor.py,sha256=gf-CiidoRp1XmLdy8jkv06Gg0Nqe_DAlpYObOF9JfpA,6490
 lalamo/quantization.py,sha256=8o6ryIZLzzDYQuvBTboPfaVVdfijAKGpTxOcg3GKVD8,2752
 lalamo/registry_abc.py,sha256=qTikqviqqeseNzkjqoyQvL4dEWJYWzN0rI05T-JNTmo,2187
 lalamo/safetensors.py,sha256=kUiTSgx2zhfD1hxV_AA1DOLaKAKzjRd_vOYZCFf0em0,3048
 lalamo/sampling.py,sha256=GE6Av7zS-pr5Bg7FtOivRce7I0JIYuNYqfqsRe-yjQk,3867
 lalamo/utils.py,sha256=c88IP110gHZJ6hYDq7p36A9u-vLRM_YdavFom56gsNQ,4111
-lalamo/data/__init__.py,sha256=
-lalamo/data/huggingface_message.py,sha256
+lalamo/data/__init__.py,sha256=QH23n37CWLcY69oLVE8gNNr6aJ57G0D_ZsO8eYs7-Jk,225
+lalamo/data/huggingface_message.py,sha256=8oTCxL_IOHRCVgQneRv52sgJNprsiAIPvps5nBu6LWo,1037
 lalamo/data/lalamo_completions.py,sha256=U_m3UNSJASUFz3rJq_taZOtL_U4B8Oj-ndkTF-JH-v4,1509
 lalamo/data/utils.py,sha256=B96gLaULyStKYuR8wjFdTpFc6YIDC8EEvGh1eiMe_Ec,338
 lalamo/model_import/__init__.py,sha256=Z8pS9rbKKx1QgUy7KZtHxiNWlZhII3mdovT9d37vAxg,168
-lalamo/model_import/common.py,sha256=
+lalamo/model_import/common.py,sha256=evZeeizFev2i7Whd9X7sgUQV8v5apjfmy_BhshnFbyo,13011
 lalamo/model_import/huggingface_generation_config.py,sha256=xicv_kJOfIGlz4gi5fRFIkiAZ9_QRDLRtW8nKMm5tVU,2022
 lalamo/model_import/huggingface_tokenizer_config.py,sha256=xvwdmio7b9nhn2H3uMBVligiYj58JaCFCvHY3-8dBvM,2502
+lalamo/model_import/remote_registry.py,sha256=4VjZSwlYMqflMfSSPi7-GSb9tmTLMZELzXoJJ3Tsx5s,1045
 lalamo/model_import/decoder_configs/__init__.py,sha256=YvlSsJqNEQPCNKcUzCw0MLjt8H3vcfjc4sz1OK7qdIQ,679
 lalamo/model_import/decoder_configs/common.py,sha256=L8PCgF5fIt3RqPlmLiJpBzDguKk9iTjk4XSItxwVG4c,3260
 lalamo/model_import/decoder_configs/executorch.py,sha256=fTEG_j-7d8riR3Fu_H5tHDjOTrWevfyw7QbWF1mUdOQ,5924
@@ -47,20 +48,22 @@ lalamo/model_import/model_specs/lfm2.py,sha256=wg4Ggt6BbMO4ScJ6h8tjvBc3IVSrMudES
 lalamo/model_import/model_specs/llama.py,sha256=TxhKbIBFmGV2NopOg_k3ltsKlJccbxKyu-GQ7hYWCyw,3140
 lalamo/model_import/model_specs/llamba.py,sha256=Ic3sWTv34FLJ4fG6OR_Mc5goGJQR6fa5b2WbVXbn9FA,1471
 lalamo/model_import/model_specs/mirai.py,sha256=eifYVV5-fABiLH6rr82_DiVFtDyqpW0vbvXCYsQQzto,617
-lalamo/model_import/model_specs/mistral.py,sha256=
+lalamo/model_import/model_specs/mistral.py,sha256=i616AQg876PP9GHqHwaFIc29rUlsuhs0Z8_p0wv9eYg,1479
 lalamo/model_import/model_specs/pleias.py,sha256=5sRpZGYwLdsav6bLiW-459y1Cs9iJKgKkBIuGsOxtsQ,368
 lalamo/model_import/model_specs/polaris.py,sha256=Mw1-6bByjDmPIKlIUIV46CsmV5xUp_laI5Qquo5DmAQ,520
 lalamo/model_import/model_specs/qwen.py,sha256=HvN080ILpOwkqJbRLMqCa8Z8ImlLfTwiEIhWxUdTRfo,7563
 lalamo/model_import/model_specs/reka.py,sha256=dOUYbEMMvovQdzQuBO_DCsjGI39syhoKCvnxLkNEDCw,423
-lalamo/models/__init__.py,sha256=
+lalamo/models/__init__.py,sha256=XMYuKSsiiIQUOq-ZtjIJcaIjTeCMYaO9bKJ9kvvLq98,394
 lalamo/models/classifier.py,sha256=LvL54crCVi4HVSIXuoaSLB_5jtcx74GL7kgdy2Y16Zc,2094
-lalamo/models/common.py,sha256=
-lalamo/models/
+lalamo/models/common.py,sha256=8gMDvu0JXNejRslhzdurrPAS3ZymcR2Grq1RRpddc4M,3402
+lalamo/models/compile_helpers.py,sha256=t_rGCznSAQC2W4ioGZUg4Oc7lpTL6VfutKtOZ06qfXo,2227
+lalamo/models/language_model.py,sha256=YL86--CwI-T7h4ymCk3DXZ5Cswq3OCn_7wJGfYI6swk,26113
+lalamo/models/lm_helpers.py,sha256=rocQ184MCF5gnFwLbzWR7mDrV6b-0VxvOkqbhPxsCKE,6590
 lalamo/modules/__init__.py,sha256=OHIQn08jx2c3L2KIQA-7SJ4yVb2E5m6T6FqTHFJTDdM,4006
 lalamo/modules/activations.py,sha256=25F4XytJMIwPPmUbxiDUrcrdUi4c-O9SUbwv9lnZbuU,992
 lalamo/modules/classifier.py,sha256=Q5eNzJ68to6JGk8IDZiKv6Rmwh15UyT2xC52tP5njoQ,11767
 lalamo/modules/common.py,sha256=Rc9zenrUMntDKZydI1tzt1ZIY8ggfyk3ZDB-xi81ibw,3406
-lalamo/modules/decoder.py,sha256=
+lalamo/modules/decoder.py,sha256=zC4IlSzBeEbHiAlGCl8TGCBqGLVtXb_FrJuC9cPwYqo,7103
 lalamo/modules/embedding.py,sha256=PdNy4tGt9F1zve4X73WKNS0DXL-nHUFOlZmGFUAarkQ,27727
 lalamo/modules/linear.py,sha256=4xIhmeouD7R10lt8KJBLxgypVXYhpGmXdHUc-96Upfk,42871
 lalamo/modules/mlp.py,sha256=ogxi9q8J38FnuBkAtC7_KTMc7JZG4BRdsAHYprHZNvM,17690
@@ -74,22 +77,21 @@ lalamo/modules/utils.py,sha256=t_TayWT6g5LtYKhJaod-u_COWaI_VbNd3eYek9Nj0lc,441
 lalamo/modules/token_mixers/__init__.py,sha256=lwxUl0eG5IvuVc_HOsINP2vtbv9F0cUmSNHFHaEmPGk,1109
 lalamo/modules/token_mixers/attention.py,sha256=ielw1-KWBfCPCPmzSHgM0TaSUcmSkWKTxrN3N_FsGm4,16144
 lalamo/modules/token_mixers/common.py,sha256=CcrbXXvGU27uxGLh5L-G8VDtcOiW5Wpm13uBEOd6lVg,1986
-lalamo/modules/token_mixers/mamba.py,sha256=
+lalamo/modules/token_mixers/mamba.py,sha256=EFyuAEjp6pNwOriesFnulOafyGRHYgqotou6UE55axE,28945
 lalamo/modules/token_mixers/short_conv.py,sha256=k1z9UwcJGag2NHWad7cYiAnhxULtmva9RrdhqVbir18,5085
 lalamo/modules/token_mixers/state/__init__.py,sha256=OKWPmiwszMWgwamewoVHd28owanHAO2j2e30Iivtv-4,384
 lalamo/modules/token_mixers/state/common.py,sha256=dcwBevAdeJpBjf7_YRk7TKrJHsCnpljhfzZy-3h9898,661
 lalamo/modules/token_mixers/state/kv_cache.py,sha256=QfnS3XgSmyDI9MBUbeLI4ABHLxiMcXDbZsqe0fd3KQo,8788
 lalamo/modules/token_mixers/state/mamba_state.py,sha256=LHzJvNE6MkB7nrsZSNto6pxbnMJCl--JOoe9Fkcc9Mg,1642
 lalamo/modules/token_mixers/state/short_conv_state.py,sha256=osjcDHoeFWQaUoOROzeJe8F1qC8rvqunimGD4CuIDHo,895
-lalamo/speculator/__init__.py,sha256=
+lalamo/speculator/__init__.py,sha256=Ye3gMhrtNxaWPMzWbXFqKX7Rv32LGlT2k9eX2uvifKg,364
 lalamo/speculator/common.py,sha256=PudF_gkpe5_nQ-57sAC-foE1xCy_H2Axh5KwRoA86lo,587
-lalamo/speculator/
-lalamo/speculator/inference.py,sha256=47TUiLV0Dkk3dbf1-IkdlWbHCICFw6IDwKZ73FYQUQo,4802
+lalamo/speculator/inference.py,sha256=-RfgtdwMU4-EGnzc7oT8zJEhiA_Md03rPMyyi_BF26k,2792
 lalamo/speculator/ngram.py,sha256=2eqInIieJPaQHCvIfnCIDtwMa8PGEtiND_NkG7plE34,5899
 lalamo/speculator/utils.py,sha256=0wZoMMIzzk0Q-3zq5H5f-JBplePNHxywndkrNtOJOyo,1697
-lalamo-0.6.
-lalamo-0.6.
-lalamo-0.6.
-lalamo-0.6.
-lalamo-0.6.
-lalamo-0.6.
+lalamo-0.6.6.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
+lalamo-0.6.6.dist-info/METADATA,sha256=Yc0I-RS-xekkjmN-Y5LQf-0R3gKuIeuFmZ3hPXJh4tY,3112
+lalamo-0.6.6.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+lalamo-0.6.6.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
+lalamo-0.6.6.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
+lalamo-0.6.6.dist-info/RECORD,,
```
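For reference, each RECORD row has the form `path,sha256=<digest>,<size>`, where the digest is the urlsafe-base64 encoding of the file's SHA-256 with the trailing `=` padding stripped (per the wheel spec). A minimal sketch that recomputes one of the unchanged entries above, assuming the 0.6.6 wheel has been unpacked into the current directory:

```python
import base64
import hashlib
from pathlib import Path


def record_digest(path: Path) -> str:
    # RECORD digests are urlsafe base64 of the raw sha256 digest,
    # with '=' padding stripped, prefixed by the hash name.
    raw = hashlib.sha256(path.read_bytes()).digest()
    return "sha256=" + base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")


# Entry taken verbatim from the RECORD above:
path, expected, size = "lalamo/quantization.py", "sha256=8o6ryIZLzzDYQuvBTboPfaVVdfijAKGpTxOcg3GKVD8", 2752
assert record_digest(Path(path)) == expected
assert Path(path).stat().st_size == size
```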
lalamo/speculator/estimator.py
DELETED
```diff
@@ -1,127 +0,0 @@
-import functools
-import itertools
-import os
-from collections.abc import Callable
-from typing import NamedTuple
-
-import jax
-import jax.numpy as jnp
-
-from lalamo.models import LanguageModel
-
-
-def get_default_device_bytes() -> int | None:
-    dynamic_allocate = False
-
-    preallocate = os.getenv("XLA_PYTHON_CLIENT_PREALLOCATE", "")
-    dynamic_allocate |= preallocate.strip().lower() in {"0", "false", "no", "off"}
-
-    allocator = os.getenv("XLA_PYTHON_CLIENT_ALLOCATOR", "")
-    dynamic_allocate |= allocator.strip().lower() in {"platform", "cuda_malloc_async"}
-
-    if dynamic_allocate:
-        return None
-
-    memory_stats = jax.local_devices()[0].memory_stats()
-    if memory_stats is None or "bytes_limit" not in memory_stats:
-        return None
-
-    mem_fraction_raw = os.getenv("XLA_PYTHON_CLIENT_MEM_FRACTION", "")
-    try:
-        mem_fraction = float(mem_fraction_raw)
-    except ValueError:
-        mem_fraction = 0.75  # jax default https://docs.jax.dev/en/latest/gpu_memory_allocation.html
-
-    # 500mb is seemingly the usually observed overhead; this tries to match the actual capacity of the gpu
-    # so it should correspond to something you'd see in nvidia-smi
-    memory_limit = memory_stats["bytes_limit"] / min(mem_fraction, 1.0) + (500 * 1000 * 1000)
-
-    return get_usable_memory_from_bytes(memory_limit)
-
-
-def get_usable_memory_from_bytes(limit_bytes: int) -> int:
-    # JAX allocates a bit more than it needs, so we discount it by some safety factor
-    return int(limit_bytes * 0.93)
-
-
-def estimate_memory_from_batchsize(
-    model: LanguageModel,
-    max_input_length: int,
-    max_output_length: int,
-    num_logits_per_token: int,
-    batch_size: int,
-) -> int:
-    memory_analysis = (
-        jax.jit(
-            functools.partial(
-                LanguageModel.generate_tokens,
-                max_output_length=max_output_length,
-                num_top_logits_to_return=num_logits_per_token,
-            ),
-        )
-        .lower(
-            model,
-            prompt_token_ids=jax.ShapeDtypeStruct((batch_size, max_input_length), jnp.int32),
-            prompt_lengths_without_padding=jax.ShapeDtypeStruct((batch_size,), jnp.int32),
-        )
-        # disables autotune, see https://guides.lw1.at/all-xla-options/#--xla_gpu_autotune_level
-        .compile(compiler_options={"xla_gpu_autotune_level": "0"})
-        .memory_analysis()
-    )
-
-    assert hasattr(memory_analysis, "argument_size_in_bytes")
-    assert hasattr(memory_analysis, "output_size_in_bytes")
-    assert hasattr(memory_analysis, "temp_size_in_bytes")
-
-    return (
-        memory_analysis.argument_size_in_bytes
-        + memory_analysis.output_size_in_bytes
-        + memory_analysis.temp_size_in_bytes
-    )
-
-
-class EstimateBatchsizeFromMemoryEvent(NamedTuple):
-    lo: int
-    hi: int | None
-
-
-def estimate_batchsize_from_memory(
-    model: LanguageModel,
-    max_input_length: int,
-    max_output_length: int,
-    num_logits_per_token: int,
-    target_mem: int,
-    progress: Callable[[EstimateBatchsizeFromMemoryEvent], None] | None = None,
-) -> int:
-    mem_for_bs = functools.cache(
-        functools.partial(
-            estimate_memory_from_batchsize,
-            model,
-            max_input_length,
-            max_output_length,
-            num_logits_per_token,
-        ),
-    )
-
-    lo = 0
-    hi = 0
-    for candidate_exp in itertools.count():
-        lo = hi
-        hi = 2**candidate_exp
-
-        if progress is not None:
-            progress(EstimateBatchsizeFromMemoryEvent(lo, None))
-        if target_mem < mem_for_bs(hi):
-            break
-
-    while hi - lo > 1:
-        mid = (lo + hi) // 2
-
-        if progress is not None:
-            progress(EstimateBatchsizeFromMemoryEvent(lo, hi))
-        if target_mem < mem_for_bs(mid):
-            hi = mid
-        else:
-            lo = mid
-
-    return lo
```
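The deleted module combined two standard pieces: XLA's `memory_analysis()` on the compiled generation step as a monotonically increasing cost function of batch size, and an exponential-growth-then-bisection search for the largest batch that fits a memory budget. A self-contained sketch of that search; the linear `cost` below is a stand-in for `estimate_memory_from_batchsize`, not lalamo's actual cost model:

```python
import itertools
from collections.abc import Callable


def largest_fitting(cost: Callable[[int], int], budget: int) -> int:
    """Largest n with cost(n) <= budget, for monotonically increasing cost."""
    lo, hi = 0, 0
    # Exponential search: double the upper bound until it stops fitting.
    for exp in itertools.count():
        lo, hi = hi, 2**exp
        if budget < cost(hi):
            break
    # Bisect the bracket (lo fits, hi does not) down to adjacent values.
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if budget < cost(mid):
            hi = mid
        else:
            lo = mid
    return lo


def cost(batch_size: int) -> int:
    # Stand-in: fixed weight memory plus per-sequence activation memory.
    return 2_000 + 37 * batch_size


print(largest_fitting(cost, budget=4_000))  # 54: cost(54)=3998, cost(55)=4035
```

With the estimator gone, the trace-collection path in 0.6.6 takes an explicit `batch_size` through `InferenceConfig` instead.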
Files without changes: {lalamo-0.6.5.dist-info → lalamo-0.6.6.dist-info}/WHEEL, /entry_points.txt, /licenses/LICENSE, /top_level.txt (renamed with the dist-info directory; contents identical).