lalamo 0.6.4__py3-none-any.whl → 0.6.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

lalamo/speculator/inference.py
@@ -1,15 +1,12 @@
-import functools
 from collections.abc import Callable, Iterable
-from itertools import batched, chain
+from itertools import chain
 from typing import NamedTuple
 
-import jax
-import jax.numpy as jnp
-
 from lalamo.data.lalamo_completions import LalamoCompletion
 from lalamo.data.utils import get_prefixes_ending_in_user_message
 from lalamo.message_processor import Message
 from lalamo.models import LanguageModel
+from lalamo.models.common import InferenceConfig
 
 
 class CollectTracesEvent(NamedTuple):
@@ -27,75 +24,52 @@ def inference_collect_traces(
     tokens_to_generate: int | None = None,
     progress_callback: Callable[[CollectTracesEvent], None] | None = None,
 ) -> Iterable[LalamoCompletion]:
-    generate_tokens_compiled = (
-        jax.jit(
-            functools.partial(
-                LanguageModel.generate_tokens,
-                max_output_length=max_output_length,
-                num_top_logits_to_return=num_top_logits_to_collect,
-            ),
-        )
-        .lower(
-            model,
-            prompt_token_ids=jax.ShapeDtypeStruct((batch_size, max_input_length), jnp.int32),
-            prompt_lengths_without_padding=jax.ShapeDtypeStruct((batch_size,), jnp.int32),
-        )
-        # the autotune levels are (according to https://guides.lw1.at/all-xla-options/#--xla_gpu_autotune_level)
-        # 0 - no autotune, gpu shouldn't be touched
-        # 1 - basic level, gpu should be touched veeery little
-        # 2,3 - gpu touched more and more
-        # 4 (default) - gpu might allocate more memory than the run would require!
-        .compile(compiler_options={"xla_gpu_autotune_level": "2"})
-    )
-
     prefixes = chain.from_iterable(map(get_prefixes_ending_in_user_message, conversations))
-
     tokenized_prefixes = map(model.message_processor.tokenize_request, prefixes)
     filtered_prefixes = filter(lambda conv: len(conv) <= max_input_length, tokenized_prefixes)
+    filtered_prefixes = list(filtered_prefixes)  # eagerly materialize the prompts into RAM
 
-    tokens_generated, sequences_processed = 0, 0
-
-    for real_batch in batched(filtered_prefixes, n=batch_size):
-        batch_padding = batch_size - len(real_batch)
-        batch = (*real_batch, *(([0],) * batch_padding))
-
-        length_without_padding = jnp.array(list(map(len, batch)))
-
-        padded = jnp.array(
-            [jnp.pad(jnp.array(tokens), (0, max_input_length - len(tokens)), constant_values=0) for tokens in batch],
-        )
+    config = InferenceConfig(
+        max_output_length=max_output_length,
+        num_top_logits_to_return=num_top_logits_to_collect,
+        padded_length=max_input_length,
+        batch_size=batch_size,
+    )
 
-        generated = generate_tokens_compiled(
-            model,
-            prompt_token_ids=padded,
-            prompt_lengths_without_padding=length_without_padding,
+    tokens_generated = 0
+
+    for idx, generated in enumerate(
+        model.generate_tokens_many(
+            filtered_prefixes,
+            inference_config=config,
+        ),
+    ):
+        token_ids = generated.token_ids.tolist()
+        seqlen = next(
+            (i + 1 for i, t in enumerate(token_ids) if t in model.stop_token_ids),
+            len(token_ids),
         )
 
-        assert generated.top_k_token_ids is not None and generated.top_k_token_logits is not None
+        if tokens_to_generate is not None:
+            seqlen = min(seqlen, tokens_to_generate - tokens_generated)
 
-        for conv_idx in range(len(real_batch)):
-            token_ids = generated.token_ids[conv_idx].tolist()
-            seqlen = next((i + 1 for i, t in enumerate(token_ids) if t in model.stop_token_ids), len(token_ids))
-            if tokens_to_generate is not None:
-                seqlen = min(seqlen, tokens_to_generate - tokens_generated)
-            tokens_generated += seqlen
-            sequences_processed += 1
+        tokens_generated += seqlen
 
-            token_ids = token_ids[:seqlen]
-            token_logits_ids = generated.top_k_token_ids[conv_idx, : len(token_ids)].tolist()
-            token_logits_values = generated.top_k_token_logits[conv_idx, : len(token_ids)].tolist()
-            token_logits = [
-                dict(zip(keys, values, strict=True))
-                for keys, values in zip(token_logits_ids, token_logits_values, strict=True)
-            ]
+        token_ids = token_ids[:seqlen]
 
-            yield LalamoCompletion(batch[conv_idx], token_ids, token_logits)
+        assert generated.top_k_token_ids is not None and generated.top_k_token_logits is not None
+        token_logits_ids = generated.top_k_token_ids[:seqlen].tolist()
+        token_logits_values = generated.top_k_token_logits[:seqlen].tolist()
+        token_logits = [
+            dict(zip(keys, values, strict=True))
+            for keys, values in zip(token_logits_ids, token_logits_values, strict=True)
+        ]
 
-            if tokens_to_generate is not None and tokens_generated >= tokens_to_generate:
-                break
+        # We need the original prompt tokens - get from indexed_inputs
+        yield LalamoCompletion(filtered_prefixes[idx], token_ids, token_logits)
 
         if progress_callback is not None:
-            progress_callback(CollectTracesEvent(sequences_processed, tokens_generated))
+            progress_callback(CollectTracesEvent(idx + 1, tokens_generated))
 
         if tokens_to_generate is not None and tokens_generated >= tokens_to_generate:
             break
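
Reading the hunk above as a whole: 0.6.6 replaces the hand-rolled jax.jit / .lower() / .compile() pipeline and the manual batching, padding, and per-batch bookkeeping of 0.6.4 with a single InferenceConfig plus a streaming model.generate_tokens_many() generator that yields one result per prompt, in input order. A minimal sketch of how a caller might drive the new interface, assuming a loaded LanguageModel bound to `model`; InferenceConfig, generate_tokens_many, and the field names come straight from the diff, while the prompt data, parameter values, and printed summary are illustrative:

from lalamo.models.common import InferenceConfig

# Hypothetical, already-tokenized prompts (lists of token ids).
prompts = [[1, 15043, 2], [1, 3532, 9092, 2]]

config = InferenceConfig(
    max_output_length=256,  # per-sequence generation budget
    num_top_logits_to_return=8,  # top-k logits collected per generated token
    padded_length=512,  # prompts shorter than this are padded internally
    batch_size=2,  # sequences per compiled batch
)

# One result per input prompt, in order; the batching and padding that
# 0.6.4 did by hand inside inference_collect_traces now happen internally.
for prompt, generated in zip(prompts, model.generate_tokens_many(prompts, inference_config=config)):
    token_ids = generated.token_ids.tolist()
    print(f"{len(prompt)} prompt tokens -> {len(token_ids)} generated tokens")

One consequence visible in the diff: because results stream back one sequence at a time, progress is reported as idx + 1 instead of a separately tracked sequences_processed counter.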

lalamo-0.6.4.dist-info/METADATA → lalamo-0.6.6.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lalamo
-Version: 0.6.4
+Version: 0.6.6
 Summary: JAX library for optimization and export of models for use with the UZU inference engine.
 Requires-Python: <4,>=3.12
 Description-Content-Type: text/markdown

lalamo-0.6.4.dist-info/RECORD → lalamo-0.6.6.dist-info/RECORD
@@ -1,21 +1,22 @@
-lalamo/__init__.py,sha256=RDkf5Hhglc-fLZ-CmI4R-th6UgJKYmN-1hdbCzTiVx8,1532
-lalamo/commands.py,sha256=zXyyrLTHhP9wouwtpX4RUZeEF6No-_9ee-y_GWGhw7k,10972
-lalamo/common.py,sha256=WaNJx20eUX4CBF50aym9lniGAiX-SzBJzDzO5Jh6zXA,4312
-lalamo/main.py,sha256=f1zHYQpX_OndAguOE0wqIOkzjzUolUC7w3_1ndtMC4Y,27655
-lalamo/message_processor.py,sha256=PMKte9YijT3h9N7DjTNp8H4V45A_qlDqJaubqFevLX8,5924
+lalamo/__init__.py,sha256=FFdoG3pwkVS9Xi1X_aCiByG9QyzSP_NufH2lIhj60EY,1532
+lalamo/commands.py,sha256=3zUZE2bg39XFFGh5PQ-L2oIs73GPsp0EkFinMZQNHro,18097
+lalamo/common.py,sha256=odAWGfzRRfsKJPtVekOLTlN3SrHcMerPzMk4BfVjn9I,5262
+lalamo/main.py,sha256=V6VwWlo7QO5RH5HpsbpxQgDeOGFv2Y4abhJOjHFw2MA,37153
+lalamo/message_processor.py,sha256=gf-CiidoRp1XmLdy8jkv06Gg0Nqe_DAlpYObOF9JfpA,6490
 lalamo/quantization.py,sha256=8o6ryIZLzzDYQuvBTboPfaVVdfijAKGpTxOcg3GKVD8,2752
 lalamo/registry_abc.py,sha256=qTikqviqqeseNzkjqoyQvL4dEWJYWzN0rI05T-JNTmo,2187
 lalamo/safetensors.py,sha256=kUiTSgx2zhfD1hxV_AA1DOLaKAKzjRd_vOYZCFf0em0,3048
 lalamo/sampling.py,sha256=GE6Av7zS-pr5Bg7FtOivRce7I0JIYuNYqfqsRe-yjQk,3867
 lalamo/utils.py,sha256=c88IP110gHZJ6hYDq7p36A9u-vLRM_YdavFom56gsNQ,4111
-lalamo/data/__init__.py,sha256=exfhBLxHrg7BWutM0tAln5QuIWlNQmOhaG2noFYxfPI,189
-lalamo/data/huggingface_message.py,sha256=-7lN9eIcETQzt1Pnx3d4d8p3_I7WYMNf4mp1P91N7fI,1115
+lalamo/data/__init__.py,sha256=QH23n37CWLcY69oLVE8gNNr6aJ57G0D_ZsO8eYs7-Jk,225
+lalamo/data/huggingface_message.py,sha256=8oTCxL_IOHRCVgQneRv52sgJNprsiAIPvps5nBu6LWo,1037
 lalamo/data/lalamo_completions.py,sha256=U_m3UNSJASUFz3rJq_taZOtL_U4B8Oj-ndkTF-JH-v4,1509
 lalamo/data/utils.py,sha256=B96gLaULyStKYuR8wjFdTpFc6YIDC8EEvGh1eiMe_Ec,338
 lalamo/model_import/__init__.py,sha256=Z8pS9rbKKx1QgUy7KZtHxiNWlZhII3mdovT9d37vAxg,168
-lalamo/model_import/common.py,sha256=MIbvK3mxgrDSXea6jujvCOu9Jjyip6MXeTsJjNTBJAU,12325
+lalamo/model_import/common.py,sha256=evZeeizFev2i7Whd9X7sgUQV8v5apjfmy_BhshnFbyo,13011
 lalamo/model_import/huggingface_generation_config.py,sha256=xicv_kJOfIGlz4gi5fRFIkiAZ9_QRDLRtW8nKMm5tVU,2022
 lalamo/model_import/huggingface_tokenizer_config.py,sha256=xvwdmio7b9nhn2H3uMBVligiYj58JaCFCvHY3-8dBvM,2502
+lalamo/model_import/remote_registry.py,sha256=4VjZSwlYMqflMfSSPi7-GSb9tmTLMZELzXoJJ3Tsx5s,1045
 lalamo/model_import/decoder_configs/__init__.py,sha256=YvlSsJqNEQPCNKcUzCw0MLjt8H3vcfjc4sz1OK7qdIQ,679
 lalamo/model_import/decoder_configs/common.py,sha256=L8PCgF5fIt3RqPlmLiJpBzDguKk9iTjk4XSItxwVG4c,3260
 lalamo/model_import/decoder_configs/executorch.py,sha256=fTEG_j-7d8riR3Fu_H5tHDjOTrWevfyw7QbWF1mUdOQ,5924
@@ -47,20 +48,22 @@ lalamo/model_import/model_specs/lfm2.py,sha256=wg4Ggt6BbMO4ScJ6h8tjvBc3IVSrMudES
 lalamo/model_import/model_specs/llama.py,sha256=TxhKbIBFmGV2NopOg_k3ltsKlJccbxKyu-GQ7hYWCyw,3140
 lalamo/model_import/model_specs/llamba.py,sha256=Ic3sWTv34FLJ4fG6OR_Mc5goGJQR6fa5b2WbVXbn9FA,1471
 lalamo/model_import/model_specs/mirai.py,sha256=eifYVV5-fABiLH6rr82_DiVFtDyqpW0vbvXCYsQQzto,617
-lalamo/model_import/model_specs/mistral.py,sha256=HAojorjOqsJn2DoMBzYRw8A70qCslhFEsE9AF5xumlg,1278
+lalamo/model_import/model_specs/mistral.py,sha256=i616AQg876PP9GHqHwaFIc29rUlsuhs0Z8_p0wv9eYg,1479
 lalamo/model_import/model_specs/pleias.py,sha256=5sRpZGYwLdsav6bLiW-459y1Cs9iJKgKkBIuGsOxtsQ,368
 lalamo/model_import/model_specs/polaris.py,sha256=Mw1-6bByjDmPIKlIUIV46CsmV5xUp_laI5Qquo5DmAQ,520
 lalamo/model_import/model_specs/qwen.py,sha256=HvN080ILpOwkqJbRLMqCa8Z8ImlLfTwiEIhWxUdTRfo,7563
 lalamo/model_import/model_specs/reka.py,sha256=dOUYbEMMvovQdzQuBO_DCsjGI39syhoKCvnxLkNEDCw,423
-lalamo/models/__init__.py,sha256=Vn5PcvSqKppIchkSZwQVTn_GpRvOOzZVxo5PUeDl6N8,283
+lalamo/models/__init__.py,sha256=XMYuKSsiiIQUOq-ZtjIJcaIjTeCMYaO9bKJ9kvvLq98,394
 lalamo/models/classifier.py,sha256=LvL54crCVi4HVSIXuoaSLB_5jtcx74GL7kgdy2Y16Zc,2094
-lalamo/models/common.py,sha256=uU6eCHtIqMeC_aRGVo09NdpAtvQ6RKSbm6pumVvL8pc,2943
-lalamo/models/language_model.py,sha256=HtFS-R4Uqr7SohFstoAZFVrJI293N9cG_LVkXhZxgFI,13546
+lalamo/models/common.py,sha256=8gMDvu0JXNejRslhzdurrPAS3ZymcR2Grq1RRpddc4M,3402
+lalamo/models/compile_helpers.py,sha256=t_rGCznSAQC2W4ioGZUg4Oc7lpTL6VfutKtOZ06qfXo,2227
+lalamo/models/language_model.py,sha256=YL86--CwI-T7h4ymCk3DXZ5Cswq3OCn_7wJGfYI6swk,26113
+lalamo/models/lm_helpers.py,sha256=rocQ184MCF5gnFwLbzWR7mDrV6b-0VxvOkqbhPxsCKE,6590
 lalamo/modules/__init__.py,sha256=OHIQn08jx2c3L2KIQA-7SJ4yVb2E5m6T6FqTHFJTDdM,4006
 lalamo/modules/activations.py,sha256=25F4XytJMIwPPmUbxiDUrcrdUi4c-O9SUbwv9lnZbuU,992
 lalamo/modules/classifier.py,sha256=Q5eNzJ68to6JGk8IDZiKv6Rmwh15UyT2xC52tP5njoQ,11767
 lalamo/modules/common.py,sha256=Rc9zenrUMntDKZydI1tzt1ZIY8ggfyk3ZDB-xi81ibw,3406
-lalamo/modules/decoder.py,sha256=I30fptNifcdw9OOCU50aZnEqsJ2X4VM9YXdtRkxbqGc,7014
+lalamo/modules/decoder.py,sha256=zC4IlSzBeEbHiAlGCl8TGCBqGLVtXb_FrJuC9cPwYqo,7103
 lalamo/modules/embedding.py,sha256=PdNy4tGt9F1zve4X73WKNS0DXL-nHUFOlZmGFUAarkQ,27727
 lalamo/modules/linear.py,sha256=4xIhmeouD7R10lt8KJBLxgypVXYhpGmXdHUc-96Upfk,42871
 lalamo/modules/mlp.py,sha256=ogxi9q8J38FnuBkAtC7_KTMc7JZG4BRdsAHYprHZNvM,17690
@@ -74,22 +77,21 @@ lalamo/modules/utils.py,sha256=t_TayWT6g5LtYKhJaod-u_COWaI_VbNd3eYek9Nj0lc,441
 lalamo/modules/token_mixers/__init__.py,sha256=lwxUl0eG5IvuVc_HOsINP2vtbv9F0cUmSNHFHaEmPGk,1109
 lalamo/modules/token_mixers/attention.py,sha256=ielw1-KWBfCPCPmzSHgM0TaSUcmSkWKTxrN3N_FsGm4,16144
 lalamo/modules/token_mixers/common.py,sha256=CcrbXXvGU27uxGLh5L-G8VDtcOiW5Wpm13uBEOd6lVg,1986
-lalamo/modules/token_mixers/mamba.py,sha256=zV5CnhEbAtJ32V32a2VZGsbjZ-sohMqRbR5kW9XH1AI,19087
+lalamo/modules/token_mixers/mamba.py,sha256=EFyuAEjp6pNwOriesFnulOafyGRHYgqotou6UE55axE,28945
 lalamo/modules/token_mixers/short_conv.py,sha256=k1z9UwcJGag2NHWad7cYiAnhxULtmva9RrdhqVbir18,5085
 lalamo/modules/token_mixers/state/__init__.py,sha256=OKWPmiwszMWgwamewoVHd28owanHAO2j2e30Iivtv-4,384
 lalamo/modules/token_mixers/state/common.py,sha256=dcwBevAdeJpBjf7_YRk7TKrJHsCnpljhfzZy-3h9898,661
 lalamo/modules/token_mixers/state/kv_cache.py,sha256=QfnS3XgSmyDI9MBUbeLI4ABHLxiMcXDbZsqe0fd3KQo,8788
 lalamo/modules/token_mixers/state/mamba_state.py,sha256=LHzJvNE6MkB7nrsZSNto6pxbnMJCl--JOoe9Fkcc9Mg,1642
 lalamo/modules/token_mixers/state/short_conv_state.py,sha256=osjcDHoeFWQaUoOROzeJe8F1qC8rvqunimGD4CuIDHo,895
-lalamo/speculator/__init__.py,sha256=9-tmZcbCom_lIGpJYn6xLlnEahFLFidpqmgkafmu--k,456
+lalamo/speculator/__init__.py,sha256=Ye3gMhrtNxaWPMzWbXFqKX7Rv32LGlT2k9eX2uvifKg,364
 lalamo/speculator/common.py,sha256=PudF_gkpe5_nQ-57sAC-foE1xCy_H2Axh5KwRoA86lo,587
-lalamo/speculator/estimator.py,sha256=6T8NdmDdhvP0BPg7vdkB_pxAkfgpu4WktNpUHtFuyiE,3833
-lalamo/speculator/inference.py,sha256=uEv33Qqcpa2xqEKdIzmPzkAzRsZOlb8TPeEG6TP6fjo,4071
+lalamo/speculator/inference.py,sha256=-RfgtdwMU4-EGnzc7oT8zJEhiA_Md03rPMyyi_BF26k,2792
 lalamo/speculator/ngram.py,sha256=2eqInIieJPaQHCvIfnCIDtwMa8PGEtiND_NkG7plE34,5899
 lalamo/speculator/utils.py,sha256=0wZoMMIzzk0Q-3zq5H5f-JBplePNHxywndkrNtOJOyo,1697
-lalamo-0.6.4.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
-lalamo-0.6.4.dist-info/METADATA,sha256=oS1EAJBl3jBtvZU0Rd-UcjnL2Trngree7Syn2L16Rx8,3112
-lalamo-0.6.4.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
-lalamo-0.6.4.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
-lalamo-0.6.4.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
-lalamo-0.6.4.dist-info/RECORD,,
+lalamo-0.6.6.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
+lalamo-0.6.6.dist-info/METADATA,sha256=Yc0I-RS-xekkjmN-Y5LQf-0R3gKuIeuFmZ3hPXJh4tY,3112
+lalamo-0.6.6.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+lalamo-0.6.6.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
+lalamo-0.6.6.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
+lalamo-0.6.6.dist-info/RECORD,,
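
A note on reading the RECORD hunks above: each line has the form path,sha256=<digest>,<size>, where the digest is the SHA-256 of the file, urlsafe-base64-encoded with trailing "=" padding stripped, and the size is in bytes (PEP 376/427). Any touched module therefore shows up as a remove/add pair, and the size column gives a quick sense of scope: language_model.py nearly doubles (13546 -> 26113 bytes), mamba.py grows from 19087 to 28945, compile_helpers.py, lm_helpers.py, and remote_registry.py are new, and estimator.py disappears. A minimal sketch of how such an entry is computed, assuming the file on disk matches the wheel contents:

import base64
import hashlib
from pathlib import Path

def record_entry(path: str) -> str:
    # Wheel RECORD line: path, urlsafe-b64 SHA-256 without "=" padding, byte size.
    data = Path(path).read_bytes()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=").decode()
    return f"{path},sha256={digest},{len(data)}"

print(record_entry("lalamo/quantization.py"))
# Should reproduce the unchanged entry from the hunk above:
# lalamo/quantization.py,sha256=8o6ryIZLzzDYQuvBTboPfaVVdfijAKGpTxOcg3GKVD8,2752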

lalamo/speculator/estimator.py (deleted)
@@ -1,127 +0,0 @@
-import functools
-import itertools
-import os
-from collections.abc import Callable
-from typing import NamedTuple
-
-import jax
-import jax.numpy as jnp
-
-from lalamo.models import LanguageModel
-
-
-def get_default_device_bytes() -> int | None:
-    dynamic_allocate = False
-
-    preallocate = os.getenv("XLA_PYTHON_CLIENT_PREALLOCATE", "")
-    dynamic_allocate |= preallocate.strip().lower() in {"0", "false", "no", "off"}
-
-    allocator = os.getenv("XLA_PYTHON_CLIENT_ALLOCATOR", "")
-    dynamic_allocate |= allocator.strip().lower() in {"platform", "cuda_malloc_async"}
-
-    if dynamic_allocate:
-        return None
-
-    memory_stats = jax.local_devices()[0].memory_stats()
-    if memory_stats is None or "bytes_limit" not in memory_stats:
-        return None
-
-    mem_fraction_raw = os.getenv("XLA_PYTHON_CLIENT_MEM_FRACTION", "")
-    try:
-        mem_fraction = float(mem_fraction_raw)
-    except ValueError:
-        mem_fraction = 0.75  # jax default https://docs.jax.dev/en/latest/gpu_memory_allocation.html
-
-    # 500mb is seemingly the usually observed overhead; this tries to match the actual capacity of the gpu
-    # so it should correspond to something you'd see in nvidia-smi
-    memory_limit = memory_stats["bytes_limit"] / min(mem_fraction, 1.0) + (500 * 1000 * 1000)
-
-    return get_usable_memory_from_bytes(memory_limit)
-
-
-def get_usable_memory_from_bytes(limit_bytes: int) -> int:
-    # JAX allocates a bit more than it needs, so we discount it by some safety factor
-    return int(limit_bytes * 0.95)
-
-
-def estimate_memory_from_batchsize(
-    model: LanguageModel,
-    max_input_length: int,
-    max_output_length: int,
-    num_logits_per_token: int,
-    batch_size: int,
-) -> int:
-    memory_analysis = (
-        jax.jit(
-            functools.partial(
-                LanguageModel.generate_tokens,
-                max_output_length=max_output_length,
-                num_top_logits_to_return=num_logits_per_token,
-            ),
-        )
-        .lower(
-            model,
-            prompt_token_ids=jax.ShapeDtypeStruct((batch_size, max_input_length), jnp.int32),
-            prompt_lengths_without_padding=jax.ShapeDtypeStruct((batch_size,), jnp.int32),
-        )
-        # disables autotune, see https://guides.lw1.at/all-xla-options/#--xla_gpu_autotune_level
-        .compile(compiler_options={"xla_gpu_autotune_level": "0"})
-        .memory_analysis()
-    )
-
-    assert hasattr(memory_analysis, "argument_size_in_bytes")
-    assert hasattr(memory_analysis, "output_size_in_bytes")
-    assert hasattr(memory_analysis, "temp_size_in_bytes")
-
-    return (
-        memory_analysis.argument_size_in_bytes
-        + memory_analysis.output_size_in_bytes
-        + memory_analysis.temp_size_in_bytes
-    )
-
-
-class EstimateBatchsizeFromMemoryEvent(NamedTuple):
-    lo: int
-    hi: int | None
-
-
-def estimate_batchsize_from_memory(
-    model: LanguageModel,
-    max_input_length: int,
-    max_output_length: int,
-    num_logits_per_token: int,
-    target_mem: int,
-    progress: Callable[[EstimateBatchsizeFromMemoryEvent], None] | None = None,
-) -> int:
-    mem_for_bs = functools.cache(
-        functools.partial(
-            estimate_memory_from_batchsize,
-            model,
-            max_input_length,
-            max_output_length,
-            num_logits_per_token,
-        ),
-    )
-
-    lo = 0
-    hi = 0
-    for candidate_exp in itertools.count():
-        lo = hi
-        hi = 2**candidate_exp
-
-        if progress is not None:
-            progress(EstimateBatchsizeFromMemoryEvent(lo, None))
-        if target_mem < mem_for_bs(hi):
-            break
-
-    while hi - lo > 1:
-        mid = (lo + hi) // 2
-
-        if progress is not None:
-            progress(EstimateBatchsizeFromMemoryEvent(lo, hi))
-        if target_mem < mem_for_bs(mid):
-            hi = mid
-        else:
-            lo = mid
-
-    return lo
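
In outline, the deleted module did two things. First, it derived a usable memory budget from XLA's allocator statistics: for example, with a reported bytes_limit of 12 GB under the default 0.75 memory fraction, the estimated card capacity is 12 / 0.75 + 0.5 ≈ 16.5 GB, of which 95% (about 15.7 GB) is treated as usable. Second, it searched for the largest batch size whose compiled executable fits that budget, doubling until the budget is exceeded and then bisecting. That search is the reusable idea; here is a self-contained sketch of it under an assumed monotone cost function standing in for the real compile-and-measure mem_for_bs:

import itertools
from collections.abc import Callable

def largest_batchsize_within(target_mem: int, mem_for_bs: Callable[[int], int]) -> int:
    # Phase 1: double the candidate until the budget is exceeded, so that
    # afterwards lo fits the budget (or is 0) and hi does not.
    lo, hi = 0, 0
    for exp in itertools.count():
        lo, hi = hi, 2**exp
        if target_mem < mem_for_bs(hi):
            break
    # Phase 2: bisect down to the largest batch size that still fits.
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if target_mem < mem_for_bs(mid):
            hi = mid
        else:
            lo = mid
    return lo

# Toy cost model: fixed weight memory plus a per-sequence cache cost
# (illustrative numbers, not lalamo's real profile).
print(largest_batchsize_within(16_000_000_000, lambda bs: 4_000_000_000 + bs * 350_000_000))  # -> 34

Whether 0.6.6 still estimates batch sizes is not visible in this diff: the RECORD removes speculator/estimator.py outright, and the new models/compile_helpers.py and models/lm_helpers.py may or may not absorb this logic.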