ik-llama-cpp-python 0.1.0__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bin/ggml.dll ADDED
Binary file
bin/llama-quantize.exe ADDED
Binary file
bin/llama.dll ADDED
Binary file
"""ik-llama-cpp-python — Python bindings for ik_llama.cpp."""

from .llama import IkLlama
from .quantize import find_quantize_bin, quantize, quantize_from_hf

# Package version — keep in sync with the build/distribution metadata.
__version__ = "0.1.0"

# Explicit public API of the package.
__all__ = ["IkLlama", "find_quantize_bin", "quantize", "quantize_from_hf"]
"""Low-level ctypes bindings for the ik_llama.cpp C API (llama.h).

NOTE: ik_llama.cpp uses an older-style API compared to upstream llama.cpp:
- Sampling is direct (llama_sample_*) instead of sampler-chain objects
- Model free is llama_free_model (not llama_model_free)
- Timings use llama_get_timings / llama_print_timings / llama_reset_timings
- Struct layouts differ significantly
"""

from __future__ import annotations

import ctypes
import functools
from typing import Any, Callable, List, Optional, TypeVar

from ._lib_loader import load_shared_library

# ---------------------------------------------------------------------------
# Load shared library
# ---------------------------------------------------------------------------

# Loaded once at import time; every binding below resolves symbols from it.
_lib = load_shared_library()

# ---------------------------------------------------------------------------
# Decorator — binds a Python stub to a C symbol in the shared library
# ---------------------------------------------------------------------------

# Type variable for the decorated stub's callable type.
F = TypeVar("F", bound=Callable[..., Any])
31
def _cfunc(name: str, argtypes: List[Any], restype: Any):
    """Decorator factory: bind a Python stub to the C symbol *name*.

    The returned decorator looks up *name* in the loaded shared library,
    attaches the ctypes argument/return signature, copies the stub's
    metadata (name, docstring) onto the foreign function, and returns the
    foreign function in place of the stub.
    """
    def decorator(stub: F) -> F:
        c_func = getattr(_lib, name)
        c_func.argtypes = argtypes
        c_func.restype = restype
        # Carry the stub's __name__/__doc__ over for introspection.
        functools.wraps(stub)(c_func)
        return c_func  # type: ignore[return-value]
    return decorator
39
+
40
+
41
# ---------------------------------------------------------------------------
# Type aliases
# ---------------------------------------------------------------------------

# Opaque C pointers — struct contents are never accessed from Python.
llama_model_p = ctypes.c_void_p
llama_context_p = ctypes.c_void_p
llama_vocab_p = ctypes.c_void_p

llama_token = ctypes.c_int32     # token id
llama_token_p = ctypes.POINTER(llama_token)
llama_pos = ctypes.c_int32       # position within the context window
llama_seq_id = ctypes.c_int32    # sequence id (multi-sequence batches)

# Callback types
# Progress: (float progress, void *user_data) -> bool
llama_progress_callback = ctypes.CFUNCTYPE(
    ctypes.c_bool, ctypes.c_float, ctypes.c_void_p
)
# Backend-scheduler eval callback — signature mirrors the fork's ggml header.
ggml_backend_sched_eval_callback = ctypes.CFUNCTYPE(
    ctypes.c_bool, ctypes.c_void_p, ctypes.c_bool, ctypes.c_void_p
)
# Abort: (void *user_data) -> bool.
# NOTE(review): presumably returning true aborts computation, as in upstream
# ggml — confirm against the fork's header.
ggml_abort_callback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_void_p)
62
+
63
# ---------------------------------------------------------------------------
# Structures (matching ik_llama.cpp fork layouts)
# ---------------------------------------------------------------------------


class llama_model_kv_override_value(ctypes.Union):
    """Value payload of a KV-metadata override; active member selected by
    the enclosing struct's ``tag`` field."""
    _fields_ = [
        ("val_i64", ctypes.c_int64),
        ("val_f64", ctypes.c_double),
        ("val_bool", ctypes.c_bool),
        ("val_str", ctypes.c_char * 128),   # fixed-size inline string
    ]


class llama_model_kv_override(ctypes.Structure):
    """One GGUF key/value metadata override passed at model-load time."""
    _fields_ = [
        ("tag", ctypes.c_int),              # C enum: which union member is valid
        ("key", ctypes.c_char * 128),
        ("value", llama_model_kv_override_value),
    ]
83
+
84
+
85
class llama_model_params(ctypes.Structure):
    """ik_llama.cpp llama_model_params — differs from upstream.

    Field order and types must match the fork's C struct layout exactly;
    never reorder, insert, or retype fields without changing the C side in
    lockstep. Semantics of fork-specific knobs (mla, amb, fit, mtp, ...)
    come from the fork's llama.h — verify there before documenting further.
    """
    _fields_ = [
        ("devices", ctypes.c_char_p),
        ("n_gpu_layers", ctypes.c_int32),       # layers to offload to GPU
        ("mla", ctypes.c_int32),
        ("split_mode", ctypes.c_int),           # C enum
        ("main_gpu", ctypes.c_int32),
        ("max_gpu", ctypes.c_int32),
        ("ncmoe", ctypes.c_int32),
        ("type_k", ctypes.c_int),               # C enum: KV cache types
        ("type_v", ctypes.c_int),
        ("max_ctx_size", ctypes.c_uint32),
        ("n_seq_max", ctypes.c_int32),
        ("n_ubatch", ctypes.c_int32),
        ("amb", ctypes.c_int32),
        ("fit_margin", ctypes.c_int32),
        ("fit", ctypes.c_bool),
        ("worst_graph_tokens", ctypes.c_int32),
        # Per-layer-range cache type overrides (first/last layer windows).
        ("type_k_first", ctypes.c_int),
        ("type_k_last", ctypes.c_int),
        ("type_v_first", ctypes.c_int),
        ("type_v_last", ctypes.c_int),
        ("n_k_first", ctypes.c_int32),
        ("n_k_last", ctypes.c_int32),
        ("n_v_first", ctypes.c_int32),
        ("n_v_last", ctypes.c_int32),
        ("tensor_split", ctypes.POINTER(ctypes.c_float)),
        ("rpc_servers", ctypes.c_char_p),
        ("progress_callback", llama_progress_callback),
        ("progress_callback_user_data", ctypes.c_void_p),
        ("kv_overrides", ctypes.POINTER(llama_model_kv_override)),
        ("tensor_buft_overrides", ctypes.c_void_p),
        ("vocab_only", ctypes.c_bool),
        ("use_mmap", ctypes.c_bool),            # memory-map the model file
        ("use_mlock", ctypes.c_bool),           # lock model pages in RAM
        ("check_tensors", ctypes.c_bool),
        ("repack_tensors", ctypes.c_bool),
        ("use_thp", ctypes.c_bool),
        ("validate_quants", ctypes.c_bool),
        ("merge_qkv", ctypes.c_bool),
        ("merge_up_gate_exps", ctypes.c_bool),
        ("mtp", ctypes.c_bool),
        ("dry_run", ctypes.c_bool),
        ("flash_attn", ctypes.c_bool),
    ]
131
+
132
+
133
class llama_context_params(ctypes.Structure):
    """ik_llama.cpp llama_context_params — differs from upstream.

    Field order and types must mirror the fork's C struct exactly; any
    mismatch silently corrupts every field after the divergence point.
    """
    _fields_ = [
        ("seed", ctypes.c_uint32),
        ("n_ctx", ctypes.c_uint32),             # context window size (tokens)
        ("n_batch", ctypes.c_uint32),
        ("n_ubatch", ctypes.c_uint32),
        ("n_seq_max", ctypes.c_uint32),
        ("n_threads", ctypes.c_uint32),
        ("n_threads_batch", ctypes.c_uint32),
        ("max_extra_alloc", ctypes.c_int32),
        ("worst_case_tokens", ctypes.c_int32),
        ("rope_scaling_type", ctypes.c_int),    # C enum
        ("pooling_type", ctypes.c_int),         # C enum
        ("attention_type", ctypes.c_int),       # C enum
        # RoPE / YaRN frequency-scaling parameters.
        ("rope_freq_base", ctypes.c_float),
        ("rope_freq_scale", ctypes.c_float),
        ("yarn_ext_factor", ctypes.c_float),
        ("yarn_attn_factor", ctypes.c_float),
        ("yarn_beta_fast", ctypes.c_float),
        ("yarn_beta_slow", ctypes.c_float),
        ("yarn_orig_ctx", ctypes.c_uint32),
        ("defrag_thold", ctypes.c_float),
        ("cb_eval", ggml_backend_sched_eval_callback),
        ("cb_eval_user_data", ctypes.c_void_p),
        # KV cache element types, plus per-layer-range overrides.
        ("type_k", ctypes.c_int),
        ("type_v", ctypes.c_int),
        ("type_reduce", ctypes.c_int),
        ("type_k_first", ctypes.c_int),
        ("type_k_last", ctypes.c_int),
        ("type_v_first", ctypes.c_int),
        ("type_v_last", ctypes.c_int),
        ("n_k_first", ctypes.c_int32),
        ("n_k_last", ctypes.c_int32),
        ("n_v_first", ctypes.c_int32),
        ("n_v_last", ctypes.c_int32),
        ("logits_all", ctypes.c_bool),
        ("embeddings", ctypes.c_bool),
        ("offload_kqv", ctypes.c_bool),
        ("flash_attn", ctypes.c_bool),
        # Fork-specific tuning knobs — semantics per the fork's llama.h.
        ("mla_attn", ctypes.c_int),
        ("attn_max_batch", ctypes.c_int),
        ("fused_moe_up_gate", ctypes.c_bool),
        ("grouped_expert_routing", ctypes.c_bool),
        ("fused_up_gate", ctypes.c_bool),
        ("fused_mmad", ctypes.c_bool),
        ("rope_cache", ctypes.c_bool),
        ("graph_reuse", ctypes.c_bool),
        ("min_experts", ctypes.c_int),
        ("thresh_experts", ctypes.c_float),
        ("only_active_experts", ctypes.c_bool),
        ("k_cache_hadamard", ctypes.c_bool),
        ("v_cache_hadamard", ctypes.c_bool),
        ("split_mode_graph_scheduling", ctypes.c_bool),
        ("scheduler_async", ctypes.c_bool),
        ("mtp", ctypes.c_bool),
        ("mtp_op_type", ctypes.c_int),
        ("abort_callback", ggml_abort_callback),
        ("abort_callback_data", ctypes.c_void_p),
        ("offload_policy", ctypes.c_void_p),
        ("cuda_params", ctypes.c_void_p),
    ]
195
+
196
+
197
class llama_batch(ctypes.Structure):
    """Input batch for llama_decode — parallel arrays of length n_tokens."""
    _fields_ = [
        ("n_tokens", ctypes.c_int32),
        ("token", ctypes.POINTER(llama_token)),           # token ids
        ("embd", ctypes.POINTER(ctypes.c_float)),         # or raw embeddings
        ("pos", ctypes.POINTER(llama_pos)),               # position per token
        ("n_seq_id", ctypes.POINTER(ctypes.c_int32)),     # seq count per token
        ("seq_id", ctypes.POINTER(ctypes.POINTER(llama_seq_id))),
        # Per-token 0/1 flags; the batch helpers in this package set 1 to
        # request logits for that token.
        ("logits", ctypes.POINTER(ctypes.c_int8)),
        # Legacy whole-batch defaults (older-style API fields).
        ("all_pos_0", llama_pos),
        ("all_pos_1", llama_pos),
        ("all_seq_id", llama_seq_id),
    ]
210
+
211
+
212
class llama_token_data(ctypes.Structure):
    """One sampling candidate: token id, its logit, and its probability."""
    _fields_ = [
        ("id", llama_token),
        ("logit", ctypes.c_float),
        ("p", ctypes.c_float),
    ]


class llama_token_data_array(ctypes.Structure):
    """Mutable candidate list passed to the llama_sample_* functions."""
    _fields_ = [
        ("data", ctypes.POINTER(llama_token_data)),
        ("size", ctypes.c_size_t),
        # NOTE(review): 'selected'/'sorted' must match the fork's struct —
        # verify the fork actually has the c_int64 'selected' member.
        ("selected", ctypes.c_int64),
        ("sorted", ctypes.c_bool),
    ]
227
+
228
+
229
class llama_timings(ctypes.Structure):
    """ik_llama.cpp uses llama_timings instead of llama_perf_context_data.

    Millisecond wall-clock timings plus call counts for sampling,
    prompt eval (p_eval) and generation eval.
    """
    _fields_ = [
        ("t_start_ms", ctypes.c_double),
        ("t_end_ms", ctypes.c_double),
        ("t_load_ms", ctypes.c_double),
        ("t_sample_ms", ctypes.c_double),
        ("t_p_eval_ms", ctypes.c_double),   # prompt-processing time
        ("t_eval_ms", ctypes.c_double),     # token-generation time
        ("n_sample", ctypes.c_int32),
        ("n_p_eval", ctypes.c_int32),
        ("n_eval", ctypes.c_int32),
    ]
242
+
243
+
244
# ---------------------------------------------------------------------------
# Function bindings
# ---------------------------------------------------------------------------
# Each stub below is replaced by the ctypes foreign function via @_cfunc;
# the docstrings are copied onto the C function by functools.wraps.

# -- Backend lifecycle --

@_cfunc("llama_backend_init", [], None)
def llama_backend_init() -> None:
    """Initialize the llama backend (call before other API use)."""
    ...


@_cfunc("llama_backend_free", [], None)
def llama_backend_free() -> None:
    """Release backend-wide resources."""
    ...


# -- Default params --

@_cfunc("llama_model_default_params", [], llama_model_params)
def llama_model_default_params() -> llama_model_params:
    """Return a llama_model_params populated with the C-side defaults."""
    ...


@_cfunc("llama_context_default_params", [], llama_context_params)
def llama_context_default_params() -> llama_context_params:
    """Return a llama_context_params populated with the C-side defaults."""
    ...


# -- Model load / free --

@_cfunc("llama_model_load_from_file", [ctypes.c_char_p, llama_model_params], ctypes.c_void_p)
def llama_model_load_from_file(path: bytes, params: llama_model_params) -> Optional[int]:
    """Load a model from *path* (encoded bytes); None/NULL on failure."""
    ...


@_cfunc("llama_free_model", [ctypes.c_void_p], None)
def llama_free_model(model: int) -> None:
    """Free a model returned by llama_model_load_from_file."""
    ...


@_cfunc("llama_model_desc", [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_size_t], ctypes.c_int32)
def llama_model_desc(model: int, buf: Any, buf_size: int) -> int:
    """Write a human-readable model description into *buf*."""
    ...


# -- Context init / free --

@_cfunc("llama_init_from_model", [ctypes.c_void_p, llama_context_params], ctypes.c_void_p)
def llama_init_from_model(model: int, params: llama_context_params) -> Optional[int]:
    """Create an inference context for *model*; None/NULL on failure."""
    ...


@_cfunc("llama_free", [ctypes.c_void_p], None)
def llama_free(ctx: int) -> None:
    """Free a context returned by llama_init_from_model."""
    ...


# -- Vocab --

@_cfunc("llama_model_get_vocab", [ctypes.c_void_p], ctypes.c_void_p)
def llama_model_get_vocab(model: int) -> Optional[int]:
    """Return the model's vocab pointer (owned by the model)."""
    ...


@_cfunc("llama_vocab_n_tokens", [ctypes.c_void_p], ctypes.c_int32)
def llama_vocab_n_tokens(vocab: int) -> int:
    """Return the vocabulary size."""
    ...


# -- Tokenize / Detokenize --

@_cfunc(
    "llama_tokenize",
    [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_int32,
     llama_token_p, ctypes.c_int32, ctypes.c_bool, ctypes.c_bool],
    ctypes.c_int32,
)
def llama_tokenize(
    vocab: int, text: bytes, text_len: int,
    tokens: Any, n_tokens_max: int,
    add_special: bool, parse_special: bool,
) -> int:
    """Tokenize *text* into *tokens*.

    Per the wrappers in this package, a negative return encodes the
    required buffer size when *n_tokens_max* is too small.
    """
    ...


@_cfunc(
    "llama_token_to_piece",
    [ctypes.c_void_p, llama_token, ctypes.c_char_p,
     ctypes.c_int32, ctypes.c_int32, ctypes.c_bool],
    ctypes.c_int32,
)
def llama_token_to_piece(
    vocab: int, token: int, buf: Any,
    length: int, lstrip: int, special: bool,
) -> int:
    """Render one token into *buf* as raw bytes; returns bytes written."""
    ...


# -- Batch --

@_cfunc("llama_batch_init", [ctypes.c_int32, ctypes.c_int32, ctypes.c_int32], llama_batch)
def llama_batch_init(n_tokens: int, embd: int, n_seq_max: int) -> llama_batch:
    """Allocate a batch (embd == 0 means a token batch, not embeddings)."""
    ...


@_cfunc("llama_batch_free", [llama_batch], None)
def llama_batch_free(batch: llama_batch) -> None:
    """Free a batch allocated by llama_batch_init."""
    ...


# -- Decode --

@_cfunc("llama_decode", [ctypes.c_void_p, llama_batch], ctypes.c_int32)
def llama_decode(ctx: int, batch: llama_batch) -> int:
    """Run the model on *batch*; nonzero return indicates failure."""
    ...


# -- Logits --

@_cfunc("llama_get_logits_ith", [ctypes.c_void_p, ctypes.c_int32], ctypes.POINTER(ctypes.c_float))
def llama_get_logits_ith(ctx: int, i: int) -> Any:
    """Return a float* to the logits of batch position *i*."""
    ...


# -- EOG detection --

@_cfunc("llama_token_is_eog", [ctypes.c_void_p, llama_token], ctypes.c_bool)
def llama_token_is_eog(model: int, token: int) -> bool:
    """True if *token* is an end-of-generation token."""
    ...


# -- Direct sampling (ik_llama.cpp style — no sampler chain) --

@_cfunc(
    "llama_sample_top_k",
    [ctypes.c_void_p, ctypes.POINTER(llama_token_data_array), ctypes.c_int32, ctypes.c_size_t],
    None,
)
def llama_sample_top_k(ctx: int, candidates: Any, k: int, min_keep: int) -> None:
    """Filter *candidates* in place to the top *k* by logit."""
    ...


@_cfunc(
    "llama_sample_top_p",
    [ctypes.c_void_p, ctypes.POINTER(llama_token_data_array), ctypes.c_float, ctypes.c_size_t],
    None,
)
def llama_sample_top_p(ctx: int, candidates: Any, p: float, min_keep: int) -> None:
    """Nucleus filtering: keep candidates with cumulative probability *p*."""
    ...


@_cfunc(
    "llama_sample_temp",
    [ctypes.c_void_p, ctypes.POINTER(llama_token_data_array), ctypes.c_float],
    None,
)
def llama_sample_temp(ctx: int, candidates: Any, temp: float) -> None:
    """Apply temperature scaling to the candidate logits in place."""
    ...


@_cfunc(
    "llama_sample_softmax",
    [ctypes.c_void_p, ctypes.POINTER(llama_token_data_array)],
    None,
)
def llama_sample_softmax(ctx: int, candidates: Any) -> None:
    """Compute softmax probabilities over the candidates in place."""
    ...


@_cfunc(
    "llama_sample_token_greedy",
    [ctypes.c_void_p, ctypes.POINTER(llama_token_data_array)],
    llama_token,
)
def llama_sample_token_greedy(ctx: int, candidates: Any) -> int:
    """Return the highest-logit candidate (argmax)."""
    ...


@_cfunc(
    "llama_sample_token",
    [ctypes.c_void_p, ctypes.POINTER(llama_token_data_array)],
    llama_token,
)
def llama_sample_token(ctx: int, candidates: Any) -> int:
    """Sample a token from the candidate distribution."""
    ...


# -- Timings (ik_llama.cpp style) --

@_cfunc("llama_get_timings", [ctypes.c_void_p], llama_timings)
def llama_get_timings(ctx: int) -> llama_timings:
    """Return the context's accumulated timing counters."""
    ...


@_cfunc("llama_print_timings", [ctypes.c_void_p], None)
def llama_print_timings(ctx: int) -> None:
    """Print the context's timings (C-side stderr output)."""
    ...


@_cfunc("llama_reset_timings", [ctypes.c_void_p], None)
def llama_reset_timings(ctx: int) -> None:
    """Reset the context's timing counters."""
    ...
@@ -0,0 +1,192 @@
1
+ """RAII wrappers for ik_llama.cpp C objects."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import ctypes
6
+ from typing import Optional
7
+
8
+ from . import _ctypes_api as C
9
+
10
+
11
class IkModel:
    """Wraps a llama_model pointer with automatic cleanup."""

    def __init__(self, path: str, *, use_mmap: bool = True, use_mlock: bool = False,
                 n_gpu_layers: int = 0):
        """Load a GGUF model from *path*.

        Args:
            path: Filesystem path to the model file.
            use_mmap: Memory-map the model file instead of reading it.
            use_mlock: Lock model pages in RAM.
            n_gpu_layers: Number of layers to offload to the GPU (0 = CPU).

        Raises:
            RuntimeError: If the model cannot be loaded.
        """
        # Set eagerly so close()/__del__ are safe even if a later call raises.
        self._model = None
        self._vocab = None

        C.llama_backend_init()

        params = C.llama_model_default_params()
        params.use_mmap = use_mmap
        params.use_mlock = use_mlock
        params.n_gpu_layers = n_gpu_layers

        self._model = C.llama_model_load_from_file(path.encode("utf-8"), params)
        if not self._model:
            raise RuntimeError(f"Failed to load model: {path}")

        self._vocab = C.llama_model_get_vocab(self._model)
        self.n_vocab = C.llama_vocab_n_tokens(self._vocab)

    @property
    def model(self):
        """Raw llama_model pointer (opaque)."""
        return self._model

    @property
    def vocab(self):
        """Raw llama_vocab pointer (opaque)."""
        return self._vocab

    @property
    def desc(self) -> str:
        """Model description string (e.g. 'gemma4 2B IQ4_KT - 4.0 bpw')."""
        buf = ctypes.create_string_buffer(256)
        C.llama_model_desc(self._model, buf, 256)
        return buf.value.decode("utf-8", errors="replace")

    def tokenize(self, text: str, *, add_bos: bool = True, special: bool = False) -> list[int]:
        """Tokenize *text*, returning a list of token ids."""
        text_bytes = text.encode("utf-8")
        # First call with a NULL buffer to get the required size
        # (ik_llama.cpp takes model*, not vocab*); the negative return
        # encodes the needed capacity.
        n = C.llama_tokenize(
            self._model, text_bytes, len(text_bytes),
            None, 0, add_bos, special,
        )
        n = abs(n)
        buf = (C.llama_token * n)()
        n_actual = C.llama_tokenize(
            self._model, text_bytes, len(text_bytes),
            buf, n, add_bos, special,
        )
        return list(buf[:n_actual])

    def detokenize(self, tokens: list[int], *, special: bool = False) -> str:
        """Convert *tokens* back into text.

        Bytes from all pieces are accumulated and decoded once at the end:
        a single UTF-8 code point can span multiple tokens, so decoding
        piece-by-piece would corrupt multi-byte characters.
        """
        out = bytearray()
        buf = ctypes.create_string_buffer(256)
        for tok in tokens:
            n = C.llama_token_to_piece(self._model, tok, buf, 256, 0, special)
            if n > 0:
                # buf.raw preserves embedded NUL bytes; buf.value would
                # truncate the piece at the first NUL.
                out += buf.raw[:n]
        return out.decode("utf-8", errors="replace")

    def close(self):
        """Free the underlying model. Safe to call more than once."""
        # getattr guard: __init__ may have raised before _model existed.
        if getattr(self, "_model", None):
            C.llama_free_model(self._model)
            self._model = None

    def __del__(self):
        # close() is guarded, so this is safe even on a half-built instance.
        self.close()
76
+
77
+
78
class IkContext:
    """Wraps a llama_context pointer with automatic cleanup."""

    def __init__(self, model: IkModel, *, n_ctx: int = 4096, n_threads: int = 0,
                 flash_attn: bool = True):
        """Create an inference context for *model*.

        Args:
            model: A loaded IkModel; kept referenced for the context's lifetime.
            n_ctx: Context window size in tokens (also used as n_batch).
            n_threads: Worker threads for eval; 0 keeps the C-side default.
            flash_attn: Enable flash attention.

        Raises:
            RuntimeError: If context creation fails.
        """
        # Set eagerly so close()/__del__ are safe even if a later call raises.
        self._ctx = None

        params = C.llama_context_default_params()
        params.n_ctx = n_ctx
        params.n_batch = n_ctx
        if n_threads > 0:
            params.n_threads = n_threads
            params.n_threads_batch = n_threads
        params.flash_attn = flash_attn

        self._ctx = C.llama_init_from_model(model.model, params)
        if not self._ctx:
            raise RuntimeError("Failed to create context")
        # Hold the model so it cannot be garbage-collected (and freed) while
        # this context still points into it.
        self._model = model
        self._n_ubatch = params.n_ubatch or 512

    @property
    def ctx(self):
        """Raw llama_context pointer (opaque)."""
        return self._ctx

    @property
    def model(self) -> IkModel:
        """The IkModel this context was created from."""
        return self._model

    def decode(self, batch: C.llama_batch) -> int:
        """Run llama_decode on *batch*; returns the C status (0 = success)."""
        return C.llama_decode(self._ctx, batch)

    def get_logits(self, idx: int = -1):
        """Return a float* to the logits at batch position *idx*."""
        return C.llama_get_logits_ith(self._ctx, idx)

    def perf(self) -> dict:
        """Return prompt/gen eval timings and counts as a plain dict."""
        data = C.llama_get_timings(self._ctx)
        return {
            "t_p_eval_ms": data.t_p_eval_ms,
            "t_eval_ms": data.t_eval_ms,
            "n_p_eval": data.n_p_eval,
            "n_eval": data.n_eval,
        }

    def perf_reset(self):
        """Reset the context's timing counters."""
        C.llama_reset_timings(self._ctx)

    def sample(self, idx: int, *, temperature: float = 0.0,
               top_k: int = 40, top_p: float = 0.95) -> int:
        """Sample a token from logits at position idx using direct sampling.

        temperature <= 0 selects the argmax (greedy); otherwise top-k,
        top-p and temperature are applied in that order before sampling.
        """
        n_vocab = self._model.n_vocab
        logits = C.llama_get_logits_ith(self._ctx, idx)

        # Build the candidate array, one entry per vocab token.
        candidates_data = (C.llama_token_data * n_vocab)()
        for i in range(n_vocab):
            candidates_data[i].id = i
            candidates_data[i].logit = logits[i]
            candidates_data[i].p = 0.0

        candidates = C.llama_token_data_array()
        # candidates_data stays referenced by this frame, keeping the buffer
        # alive for the duration of the C calls below.
        candidates.data = candidates_data
        candidates.size = n_vocab
        candidates.selected = -1
        candidates.sorted = False

        candidates_p = ctypes.pointer(candidates)

        if temperature <= 0:
            return C.llama_sample_token_greedy(self._ctx, candidates_p)
        C.llama_sample_top_k(self._ctx, candidates_p, top_k, 1)
        C.llama_sample_top_p(self._ctx, candidates_p, top_p, 1)
        C.llama_sample_temp(self._ctx, candidates_p, temperature)
        return C.llama_sample_token(self._ctx, candidates_p)

    def close(self):
        """Free the underlying context. Safe to call more than once."""
        # getattr guard: __init__ may have raised before _ctx existed.
        if getattr(self, "_ctx", None):
            C.llama_free(self._ctx)
            self._ctx = None

    def __del__(self):
        # close() is guarded, so this is safe even on a half-built instance.
        self.close()
159
+
160
+
161
def make_batch(tokens: list[int], *, logits_last: bool = True) -> C.llama_batch:
    """Create a llama_batch from a token list (positions start at 0).

    The returned batch must be released with ``C.llama_batch_free``.
    """
    return make_batch_range(tokens, pos_start=0, logits_last=logits_last)
164
+
165
+
166
def make_batch_range(tokens: list[int], *, pos_start: int = 0,
                     logits_last: bool = True) -> C.llama_batch:
    """Create a llama_batch from a token list with explicit position offset.

    All tokens are assigned to sequence 0; logits are requested only for
    the final token when *logits_last* is True. The returned batch must be
    released with ``C.llama_batch_free``.
    """
    n = len(tokens)
    # embd=0 → token batch; n_seq_max=1 → single sequence.
    batch = C.llama_batch_init(n, 0, 1)
    batch.n_tokens = n

    for i, tok in enumerate(tokens):
        batch.token[i] = tok
        batch.pos[i] = pos_start + i
        batch.n_seq_id[i] = 1       # each token belongs to exactly one sequence
        batch.seq_id[i][0] = 0      # ...sequence 0
        # 0/1 flag: request logits only for the final token.
        batch.logits[i] = 1 if (logits_last and i == n - 1) else 0

    return batch
181
+
182
+
183
def make_batch_single(token: int, pos: int) -> C.llama_batch:
    """Create a single-token batch for autoregressive generation.

    The token is placed at position *pos* in sequence 0 with logits
    requested. The returned batch must be released with
    ``C.llama_batch_free``.
    """
    batch = C.llama_batch_init(1, 0, 1)
    batch.n_tokens = 1
    batch.token[0] = token
    batch.pos[0] = pos
    batch.n_seq_id[0] = 1
    batch.seq_id[0][0] = 0
    batch.logits[0] = 1   # always want logits for the generated position
    return batch
@@ -0,0 +1,62 @@
1
+ """Shared library loader for ik_llama.cpp."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import ctypes
6
+ import os
7
+ import platform
8
+ import sys
9
+ from pathlib import Path
10
+
11
+
12
+ def _lib_names() -> list[str]:
13
+ system = platform.system()
14
+ if system == "Windows":
15
+ return ["llama.dll"]
16
+ elif system == "Darwin":
17
+ return ["libllama.dylib"]
18
+ else:
19
+ return ["libllama.so"]
20
+
21
+
22
def load_shared_library() -> ctypes.CDLL:
    """Find and load the ik_llama shared library.

    Search order:
    1. ``IK_LLAMA_CPP_LIB_PATH`` env var (exact path to .dll/.so/.dylib)
    2. ``ik_llama_cpp/lib/`` directory next to this file (pip-installed)

    Raises:
        FileNotFoundError: If no library can be located.
    """
    # 1. Explicit override via environment variable.
    override = os.environ.get("IK_LLAMA_CPP_LIB_PATH")
    if override:
        p = Path(override)
        if not p.is_file():
            raise FileNotFoundError(f"IK_LLAMA_CPP_LIB_PATH points to missing file: {p}")
        return _load(p)

    # 2. Package lib/ directory (source tree or site-packages).
    search_dirs = [Path(__file__).parent / "lib"]

    # Editable installs may place the DLLs under site-packages instead.
    for entry in sys.path:
        candidate_dir = Path(entry) / "ik_llama_cpp" / "lib"
        if candidate_dir != search_dirs[0] and candidate_dir.is_dir():
            search_dirs.append(candidate_dir)

    names = _lib_names()
    for lib_dir in search_dirs:
        for name in names:
            hit = lib_dir / name
            if hit.exists():
                return _load(hit)

    raise FileNotFoundError(
        f"Cannot find ik_llama shared library. Searched: {search_dirs}. "
        "Set IK_LLAMA_CPP_LIB_PATH or rebuild with: pip install -e ."
    )
56
+
57
+
58
def _load(path: Path) -> ctypes.CDLL:
    """Load the shared library at *path* via ctypes.

    On Windows the library's directory is added to the DLL search path
    first so that dependent DLLs shipped alongside it (e.g. ggml.dll)
    can be resolved.
    """
    # On Windows, add the library directory to DLL search path
    if platform.system() == "Windows":
        os.add_dll_directory(str(path.parent))
    return ctypes.CDLL(str(path))
Binary file
Binary file
Binary file