npcpy 1.3.22-py3-none-any.whl → 1.3.23-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- npcpy/data/image.py +15 -15
- npcpy/data/web.py +2 -2
- npcpy/gen/image_gen.py +113 -62
- npcpy/gen/response.py +239 -0
- npcpy/llm_funcs.py +73 -71
- npcpy/memory/command_history.py +117 -69
- npcpy/memory/kg_vis.py +74 -74
- npcpy/npc_compiler.py +261 -26
- npcpy/npc_sysenv.py +4 -1
- npcpy/serve.py +393 -91
- npcpy/work/desktop.py +31 -5
- npcpy-1.3.23.dist-info/METADATA +416 -0
- {npcpy-1.3.22.dist-info → npcpy-1.3.23.dist-info}/RECORD +16 -16
- npcpy-1.3.22.dist-info/METADATA +0 -1039
- {npcpy-1.3.22.dist-info → npcpy-1.3.23.dist-info}/WHEEL +0 -0
- {npcpy-1.3.22.dist-info → npcpy-1.3.23.dist-info}/licenses/LICENSE +0 -0
- {npcpy-1.3.22.dist-info → npcpy-1.3.23.dist-info}/top_level.txt +0 -0
npcpy/data/image.py
CHANGED

@@ -85,21 +85,21 @@ def capture_screenshot( full=False) -> Dict[str, str]:
         subprocess.run(["screencapture", file_path], capture_output=True)

     elif system == "Linux":
-
-
-
-        )
-
-
-
-
-
-
-
-
-
-
+        _took = False
+        # Try non-interactive tools first
+        for _cmd, _args in [
+            ("grim", [file_path]),  # Wayland
+            ("scrot", [file_path]),  # X11, non-interactive full
+            ("import", ["-window", "root", file_path]),  # ImageMagick X11
+            ("gnome-screenshot", ["-f", file_path]),  # GNOME (may show dialog on newer versions)
+        ]:
+            if subprocess.run(["which", _cmd], capture_output=True).returncode == 0:
+                subprocess.run([_cmd] + _args, capture_output=True, timeout=10)
+                if os.path.exists(file_path):
+                    _took = True
+                    break
+        if not _took:
+            print("No supported screenshot tool found. Install scrot, grim, or imagemagick.")

     elif system == "Windows":

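For orientation, the new Linux branch above replaces a single hard-coded screenshot command with a probe over several CLI tools, taking the first one that actually produces a file. Below is a minimal standalone sketch of that pattern; it is not the shipped function, and it uses the stdlib shutil.which lookup rather than shelling out to `which` as the diff does.

# Sketch of the tool-probing pattern from the hunk above (illustrative, not npcpy code).
import os
import shutil
import subprocess

def try_linux_screenshot(file_path):
    candidates = [
        ("grim", [file_path]),                       # Wayland
        ("scrot", [file_path]),                      # X11, non-interactive
        ("import", ["-window", "root", file_path]),  # ImageMagick X11
        ("gnome-screenshot", ["-f", file_path]),     # GNOME
    ]
    for cmd, args in candidates:
        if shutil.which(cmd):  # assumption: stdlib lookup instead of a `which` subprocess
            subprocess.run([cmd] + args, capture_output=True, timeout=10)
            if os.path.exists(file_path):
                return True
    return False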
npcpy/data/web.py
CHANGED

@@ -146,8 +146,8 @@ def search_perplexity(
 ):
     if api_key is None:
         api_key = os.environ.get("PERPLEXITY_API_KEY")
-    if api_key is None:
-        raise
+    if api_key is None:
+        raise ValueError("PERPLEXITY_API_KEY not set. Set it in your environment or ~/.npcshrc.")


     url = "https://api.perplexity.ai/chat/completions"
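The practical effect of this change: with no PERPLEXITY_API_KEY available, the old bare raise statement fails with a RuntimeError about having no active exception to re-raise, while the new code raises a ValueError with setup guidance. A hedged sketch, assuming the first positional argument of search_perplexity is the query string (the full signature is not shown in this diff):

# Illustrative only; import path taken from the file header above.
import os
from npcpy.data.web import search_perplexity

os.environ.pop("PERPLEXITY_API_KEY", None)
try:
    search_perplexity("latest npcpy release notes")
except ValueError as err:
    print(err)  # PERPLEXITY_API_KEY not set. Set it in your environment or ~/.npcshrc.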
npcpy/gen/image_gen.py
CHANGED

@@ -34,27 +34,21 @@ def generate_image_diffusers(
     if os.path.exists(checkpoint_path):
         print(f"🌋 Found model_final.pt at {checkpoint_path}.")

-        # Load checkpoint to inspect it
         checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

-        # Check if this is a custom SimpleUNet model (from your training code)
-        # vs a Stable Diffusion UNet2DConditionModel
         if 'config' in checkpoint and hasattr(checkpoint['config'], 'image_size'):
             print(f"🌋 Detected custom SimpleUNet model, using custom generation")
-            # Use your custom generate_image function from npcpy.ft.diff
             from npcpy.ft.diff import generate_image as custom_generate_image

-            # Your custom model ignores prompts and generates based on training data
             image = custom_generate_image(
                 model_path=checkpoint_path,
                 prompt=prompt,
                 num_samples=1,
-                image_size=height
+                image_size=height
             )
             return image

         else:
-            # This is a Stable Diffusion checkpoint
             print(f"🌋 Detected Stable Diffusion UNet checkpoint")
             base_model_id = "runwayml/stable-diffusion-v1-5"
             print(f"🌋 Loading base pipeline: {base_model_id}")

@@ -67,7 +61,6 @@ def generate_image_diffusers(

             print(f"🌋 Loading custom UNet weights from {checkpoint_path}")

-            # Extract the actual model state dict
             if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                 unet_state_dict = checkpoint['model_state_dict']
                 print(f"🌋 Extracted model_state_dict from checkpoint")

@@ -75,7 +68,6 @@ def generate_image_diffusers(
                 unet_state_dict = checkpoint
                 print(f"🌋 Using checkpoint directly as state_dict")

-            # Load the state dict into the UNet
             pipe.unet.load_state_dict(unet_state_dict)
             pipe = pipe.to(device)
             print(f"🌋 Successfully loaded fine-tuned UNet weights")

@@ -100,7 +92,6 @@ def generate_image_diffusers(
             variant="fp16" if torch_dtype == torch.float16 else None,
         )

-    # Common pipeline setup for Stable Diffusion models
     if hasattr(pipe, 'enable_attention_slicing'):
         pipe.enable_attention_slicing()


@@ -142,16 +133,7 @@ def generate_image_diffusers(
             raise MemoryError(f"Insufficient memory for image generation with model {model}. Try a smaller model or reduce image size.")
         else:
             raise e
-import os
-import base64
-import io
-from typing import Union, List, Optional

-import PIL
-from PIL import Image
-
-import requests
-from urllib.request import urlopen

 def openai_image_gen(
     prompt: str,

@@ -184,13 +166,13 @@ def openai_image_gen(
                 files_to_close.append(file_handle)
             elif isinstance(attachment, bytes):
                 img_byte_arr = io.BytesIO(attachment)
-                img_byte_arr.name = 'image.png'
+                img_byte_arr.name = 'image.png'
                 processed_images.append(img_byte_arr)
             elif isinstance(attachment, Image.Image):
                 img_byte_arr = io.BytesIO()
                 attachment.save(img_byte_arr, format='PNG')
                 img_byte_arr.seek(0)
-                img_byte_arr.name = 'image.png'
+                img_byte_arr.name = 'image.png'
                 processed_images.append(img_byte_arr)

         try:

@@ -202,7 +184,6 @@ def openai_image_gen(
                 size=size_str,
             )
         finally:
-            # This ensures any files we opened are properly closed
             for f in files_to_close:
                 f.close()
     else:

@@ -231,7 +212,6 @@ def openai_image_gen(
     return collected_images


-
 def gemini_image_gen(
     prompt: str,
     model: str = "gemini-2.5-flash",

@@ -305,18 +285,21 @@ def gemini_image_gen(
         response = client.models.generate_content(
             model=model,
             contents=processed_contents,
+            config=types.GenerateContentConfig(
+                response_modalities=["IMAGE", "TEXT"],
+            ),
         )
-
+
         if hasattr(response, 'candidates') and response.candidates:
             for candidate in response.candidates:
                 for part in candidate.content.parts:
                     if hasattr(part, 'inline_data') and part.inline_data:
                         image_data = part.inline_data.data
                         collected_images.append(Image.open(BytesIO(image_data)))
-
+
         if not collected_images and hasattr(response, 'text'):
             print(f"Gemini response text: {response.text}")
-
+
         return collected_images
     else:
         if 'imagen' in model:

@@ -335,6 +318,9 @@ def gemini_image_gen(
             response = client.models.generate_content(
                 model=model,
                 contents=[prompt],
+                config=types.GenerateContentConfig(
+                    response_modalities=["IMAGE", "TEXT"],
+                ),
             )

             if hasattr(response, 'candidates') and response.candidates:

@@ -351,7 +337,86 @@ def gemini_image_gen(

         else:
             raise ValueError(f"Unsupported Gemini image model or API usage for new generation: '{model}'")
-
+
+
+def ollama_image_gen(
+    prompt: str,
+    model: str = "x/z-image-turbo",
+    height: int = 512,
+    width: int = 512,
+    n_images: int = 1,
+    api_url: Optional[str] = None,
+    seed: Optional[int] = None,
+    negative_prompt: Optional[str] = None,
+    num_steps: Optional[int] = None,
+):
+    """Generate images using Ollama's image generation API.
+
+    Works with ollama image gen models like x/z-image-turbo and x/flux2-klein.
+    Uses the /api/generate endpoint with image gen specific options.
+    """
+    import requests
+
+    if api_url is None:
+        api_url = os.environ.get('OLLAMA_API_URL', 'http://localhost:11434')
+
+    endpoint = f"{api_url}/api/generate"
+
+    collected_images = []
+
+    for _ in range(n_images):
+        options = {}
+        if width:
+            options["width"] = width
+        if height:
+            options["height"] = height
+        if seed is not None:
+            options["seed"] = seed
+        if num_steps is not None:
+            options["num_steps"] = num_steps
+
+        payload = {
+            "model": model,
+            "prompt": prompt,
+            "stream": False,
+        }
+        if options:
+            payload["options"] = options
+        if negative_prompt:
+            payload["negative_prompt"] = negative_prompt
+
+        response = requests.post(endpoint, json=payload)
+
+        if not response.ok:
+            try:
+                err = response.json()
+                err_msg = err.get('error', response.text)
+            except Exception:
+                err_msg = response.text
+            raise RuntimeError(
+                f"Ollama image gen failed ({response.status_code}): {err_msg}\n"
+                f"Model: {model} — make sure it's pulled (`ollama pull {model}`)"
+            )
+
+        result = response.json()
+
+        if 'image' in result and result['image']:
+            image_bytes = base64.b64decode(result['image'])
+            image = Image.open(io.BytesIO(image_bytes))
+            collected_images.append(image)
+        elif 'images' in result and result['images']:
+            for img_b64 in result['images']:
+                image_bytes = base64.b64decode(img_b64)
+                image = Image.open(io.BytesIO(image_bytes))
+                collected_images.append(image)
+        else:
+            raise ValueError(
+                f"No images returned from Ollama. Response keys: {list(result.keys())}. "
+                f"Make sure '{model}' is an image generation model (e.g. x/z-image-turbo, x/flux2-klein)."
+            )
+
+    return collected_images
+

 def generate_image(
     prompt: str,

@@ -364,7 +429,7 @@ def generate_image(
     api_url: Optional[str] = None,
     attachments: Union[List[Union[str, bytes, Image.Image]], None] = None,
     save_path: Optional[str] = None,
-    custom_model_path: Optional[str] = None,
+    custom_model_path: Optional[str] = None,

 ):
     """

@@ -373,7 +438,7 @@ def generate_image(
     Args:
         prompt (str): The prompt for generating/editing the image.
         model (str): The model to use.
-        provider (str): The provider to use ('openai', 'diffusers', 'gemini').
+        provider (str): The provider to use ('openai', 'diffusers', 'gemini', 'ollama').
        height (int): The height of the output image.
        width (int): The width of the output image.
        n_images (int): Number of images to generate.

@@ -381,32 +446,31 @@ def generate_image(
        api_url (str): API URL for the provider.
        attachments (list): List of images for editing. Can be file paths, bytes, or PIL Images.
        save_path (str): Path to save the generated image.
-        custom_model_path (str): Path to a locally fine-tuned Diffusers model.
+        custom_model_path (str): Path to a locally fine-tuned Diffusers model.

    Returns:
        List[PIL.Image.Image]: A list of generated PIL Image objects.
    """
    from urllib.request import urlopen
-    import os
+    import os

-    if model is None and custom_model_path is None:
+    if model is None and custom_model_path is None:
        if provider == "openai":
            model = "dall-e-2"
        elif provider == "diffusers":
            model = "runwayml/stable-diffusion-v1-5"
        elif provider == "gemini":
            model = "gemini-2.5-flash-image-preview"
+        elif provider == "ollama":
+            model = "x/z-image-turbo"

    all_generated_pil_images = []

-    # <--- CRITICAL FIX: Handle custom_model_path for Diffusers here
    if provider == "diffusers":
-        # If a custom_model_path is provided and exists, use it instead of a generic model name
        if custom_model_path and os.path.isdir(custom_model_path):
            print(f"🌋 Using custom Diffusers model from path: {custom_model_path}")
            model_to_use = custom_model_path
        else:
-            # Otherwise, use the standard model name (e.g., "runwayml/stable-diffusion-v1-5")
            model_to_use = model
            print(f"🌋 Using standard Diffusers model: {model_to_use}")


@@ -414,7 +478,7 @@ def generate_image(
        try:
            image = generate_image_diffusers(
                prompt=prompt,
-                model=model_to_use,
+                model=model_to_use,
                height=height,
                width=width
            )

@@ -447,43 +511,29 @@ def generate_image(
        )
        all_generated_pil_images.extend(images)

+    elif provider == "ollama":
+        images = ollama_image_gen(
+            prompt=prompt,
+            model=model,
+            height=height,
+            width=width,
+            n_images=n_images,
+            api_url=api_url
+        )
+        all_generated_pil_images.extend(images)
+
    else:
-        # This is the fallback for other providers or if provider is not explicitly handled
        valid_sizes = ["256x256", "512x512", "1024x1024", "1024x1792", "1792x1024"]
        size = f"{width}x{height}"

        if attachments is not None:
            raise ValueError("Image editing not supported with litellm provider")

-        # The litellm.image_generation function expects the provider as part of the model string
-        # e.g., "huggingface/starcoder" or "openai/dall-e-3"
-        # Since we've already handled "diffusers", "openai", "gemini" above,
-        # this 'else' block implies a generic litellm call.
-        # We need to ensure the model string is correctly formatted for litellm.
-        # However, the error message "LLM Provider NOT provided" suggests litellm
-        # is not even getting the `provider` correctly.
-        # The fix for this is ensuring the `provider` is explicitly passed to litellm.image_generation
-        # which is already happening in `gen_image` in `llm_funcs.py`
-
-        # If we reach here, it means the provider is not 'diffusers', 'openai', or 'gemini',
-        # and litellm is the intended route. We need to pass the provider explicitly.
-        # The original code here was trying to construct `model=f"{provider}/{model}"`
-        # but the error indicates `provider` itself was missing.
-        # The `image_generation` from litellm expects `model` to be `provider/model_name`.
-        # Since the `provider` variable is available, we can construct this.
-
-        # This block is for generic litellm providers (not diffusers, openai, gemini)
-        # The error indicates `provider` itself was not making it to litellm.
-        # This `generate_image` function already receives `provider`.
-        # The issue is likely how `gen_image` in `llm_funcs.py` calls this `generate_image`.
-        # However, if this `else` branch is hit, we ensure litellm gets the provider.
-
-        # Construct the model string for litellm
        litellm_model_string = f"{provider}/{model}" if provider and model else model

        image_response = image_generation(
            prompt=prompt,
-            model=litellm_model_string,
+            model=litellm_model_string,
            n=n_images,
            size=size,
            api_key=api_key,

@@ -509,6 +559,7 @@ def generate_image(

    return all_generated_pil_images

+
 def edit_image(
     prompt: str,
     image_path: str,
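The new provider branch above wires local Ollama image models into the existing generate_image entry point. A hedged usage sketch based only on the signatures shown in this diff; the prompt and output filename are illustrative, and a local Ollama server with the model already pulled is assumed.

# Illustrative usage of the new 'ollama' provider path (not part of the package).
from npcpy.gen.image_gen import generate_image, ollama_image_gen

# High-level entry point; with provider="ollama" the model defaults to "x/z-image-turbo".
images = generate_image(
    prompt="a watercolor fox in a pine forest",
    provider="ollama",
    height=512,
    width=512,
    n_images=1,
)
images[0].save("fox.png")

# Or call the provider function directly, e.g. with an explicit host and model.
more_images = ollama_image_gen(
    prompt="a watercolor fox in a pine forest",
    model="x/flux2-klein",
    api_url="http://localhost:11434",  # same default the function reads from OLLAMA_API_URL
    num_steps=4,
)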
npcpy/gen/response.py
CHANGED

@@ -830,6 +830,234 @@ def get_llamacpp_response(
     return result


+_AIRLLM_MODEL_CACHE = {}
+_AIRLLM_MLX_PATCHED = False
+
+def _patch_airllm_mlx_bias():
+    """
+    Monkey-patch airllm's MLX Attention/FeedForward to use bias=True.
+    AirLLM hardcodes bias=False which fails for non-Llama architectures (e.g. Qwen2).
+    Using bias=True is safe: MLX nn.Linear(bias=True) accepts weight-only updates,
+    so Llama models (no bias in weights) still work correctly.
+    """
+    global _AIRLLM_MLX_PATCHED
+    if _AIRLLM_MLX_PATCHED:
+        return
+    try:
+        import airllm.airllm_llama_mlx as mlx_mod
+        import mlx.core as mx
+        from mlx import nn
+
+        class PatchedAttention(nn.Module):
+            def __init__(self, args):
+                super().__init__()
+                self.args = args
+                self.n_heads = args.n_heads
+                self.n_kv_heads = args.n_kv_heads
+                self.repeats = self.n_heads // self.n_kv_heads
+                self.scale = args.head_dim ** -0.5
+                self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=True)
+                self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=True)
+                self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=True)
+                self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=True)
+                self.rope = nn.RoPE(
+                    args.head_dim, traditional=args.rope_traditional, base=args.rope_theta
+                )
+
+            def __call__(self, x, mask=None, cache=None):
+                B, L, D = x.shape
+                queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
+                queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+                keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+                values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+                def repeat(a):
+                    a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
+                    return a.reshape([B, self.n_heads, L, -1])
+                keys, values = map(repeat, (keys, values))
+
+                if cache is not None:
+                    key_cache, value_cache = cache
+                    queries = self.rope(queries, offset=key_cache.shape[2])
+                    keys = self.rope(keys, offset=key_cache.shape[2])
+                    keys = mx.concatenate([key_cache, keys], axis=2)
+                    values = mx.concatenate([value_cache, values], axis=2)
+                else:
+                    queries = self.rope(queries)
+                    keys = self.rope(keys)
+
+                scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
+                if mask is not None:
+                    scores += mask
+                weights = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
+                output = (weights @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+                return self.wo(output), (keys, values)
+
+        class PatchedFeedForward(nn.Module):
+            def __init__(self, args):
+                super().__init__()
+                self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=True)
+                self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=True)
+                self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=True)
+
+            def __call__(self, x):
+                return self.w2(nn.silu(self.w1(x)) * self.w3(x))
+
+        mlx_mod.Attention = PatchedAttention
+        mlx_mod.FeedForward = PatchedFeedForward
+        _AIRLLM_MLX_PATCHED = True
+        logger.debug("Patched airllm MLX classes for bias support")
+    except Exception as e:
+        logger.warning(f"Failed to patch airllm MLX bias support: {e}")
+
+def get_airllm_response(
+    prompt: str = None,
+    model: str = None,
+    tools: list = None,
+    tool_map: Dict = None,
+    format: str = None,
+    messages: List[Dict[str, str]] = None,
+    auto_process_tool_calls: bool = False,
+    **kwargs,
+) -> Dict[str, Any]:
+    """
+    Generate response using AirLLM for 70B+ model inference.
+    Supports macOS (MLX backend) and Linux (CUDA backend with 4-bit compression).
+    """
+    import platform
+    is_macos = platform.system() == "Darwin"
+
+    result = {
+        "response": None,
+        "messages": messages.copy() if messages else [],
+        "raw_response": None,
+        "tool_calls": [],
+        "tool_results": []
+    }
+
+    try:
+        from airllm import AutoModel
+    except ImportError:
+        result["response"] = ""
+        result["error"] = "airllm not installed. Install with: pip install airllm"
+        return result
+
+    # Patch airllm MLX classes to support models with bias (e.g. Qwen)
+    if is_macos:
+        _patch_airllm_mlx_bias()
+
+    if prompt:
+        if result['messages'] and result['messages'][-1]["role"] == "user":
+            result['messages'][-1]["content"] = prompt
+        else:
+            result['messages'].append({"role": "user", "content": prompt})
+
+    if format == "json":
+        json_instruction = """If you are returning a json object, begin directly with the opening {.
+Do not include any additional markdown formatting or leading ```json tags in your response."""
+        if result["messages"] and result["messages"][-1]["role"] == "user":
+            result["messages"][-1]["content"] += "\n" + json_instruction
+
+    model_name = model or "meta-llama/Meta-Llama-3.1-70B-Instruct"
+    # 4-bit compression requires CUDA via bitsandbytes; skip on macOS
+    default_compression = None if is_macos else "4bit"
+    compression = kwargs.get("compression", default_compression)
+    max_tokens = kwargs.get("max_tokens", 256)
+    temperature = kwargs.get("temperature", 0.7)
+
+    # Resolve HF token for gated model access
+    hf_token = kwargs.get("hf_token")
+    if not hf_token:
+        hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
+    if not hf_token:
+        try:
+            from huggingface_hub import HfFolder
+            hf_token = HfFolder.get_token()
+        except Exception:
+            pass
+
+    # Load or retrieve cached model
+    cache_key = f"{model_name}:{compression}"
+    if cache_key not in _AIRLLM_MODEL_CACHE:
+        load_kwargs = {"pretrained_model_name_or_path": model_name}
+        if compression:
+            load_kwargs["compression"] = compression
+        if hf_token:
+            load_kwargs["hf_token"] = hf_token
+        # Pass through additional airllm kwargs
+        for k in ["delete_original", "max_seq_len", "prefetching"]:
+            if k in kwargs:
+                load_kwargs[k] = kwargs[k]
+        _AIRLLM_MODEL_CACHE[cache_key] = AutoModel.from_pretrained(**load_kwargs)
+
+    air_model = _AIRLLM_MODEL_CACHE[cache_key]
+
+    try:
+        chat_text = air_model.tokenizer.apply_chat_template(
+            result["messages"], tokenize=False, add_generation_prompt=True
+        )
+    except Exception:
+        # Fallback if chat template is not available
+        chat_text = "\n".join(
+            f"{m['role']}: {m['content']}" for m in result["messages"]
+        )
+        chat_text += "\nassistant:"
+
+    try:
+        if is_macos:
+            import mlx.core as mx
+            tokens = air_model.tokenizer(
+                chat_text, return_tensors="np", truncation=True, max_length=2048
+            )
+            output = air_model.generate(
+                mx.array(tokens['input_ids']),
+                max_new_tokens=max_tokens,
+            )
+            # MLX backend returns string directly
+            response_content = output if isinstance(output, str) else str(output)
+        else:
+            tokens = air_model.tokenizer(
+                chat_text, return_tensors="pt", truncation=True, max_length=2048
+            )
+            gen_out = air_model.generate(
+                tokens['input_ids'].cuda(),
+                max_new_tokens=max_tokens,
+            )
+            # CUDA backend returns token IDs, decode them
+            output_ids = gen_out.sequences[0] if hasattr(gen_out, 'sequences') else gen_out[0]
+            response_content = air_model.tokenizer.decode(output_ids, skip_special_tokens=True)
+            # Strip the input prompt from the output
+            input_text = air_model.tokenizer.decode(tokens['input_ids'][0], skip_special_tokens=True)
+            if response_content.startswith(input_text):
+                response_content = response_content[len(input_text):]
+
+        response_content = response_content.strip()
+        # Strip at common stop/special tokens that airllm doesn't handle
+        for stop_tok in ["<|im_end|>", "<|endoftext|>", "<|eot_id|>", "</s>"]:
+            if stop_tok in response_content:
+                response_content = response_content[:response_content.index(stop_tok)].strip()
+    except Exception as e:
+        logger.error(f"AirLLM inference error: {e}")
+        result["error"] = f"AirLLM inference error: {str(e)}"
+        result["response"] = ""
+        return result
+
+    result["response"] = response_content
+    result["raw_response"] = response_content
+    result["messages"].append({"role": "assistant", "content": response_content})
+
+    if format == "json":
+        try:
+            if response_content.startswith("```json"):
+                response_content = response_content.replace("```json", "").replace("```", "").strip()
+            parsed_response = json.loads(response_content)
+            result["response"] = parsed_response
+        except json.JSONDecodeError:
+            result["error"] = f"Invalid JSON response: {response_content}"
+
+    return result
+
+
 def get_litellm_response(
     prompt: str = None,
     model: str = None,

@@ -921,6 +1149,17 @@ def get_litellm_response(
             auto_process_tool_calls=auto_process_tool_calls,
             **kwargs
         )
+    elif provider == 'airllm':
+        return get_airllm_response(
+            prompt=prompt,
+            model=model,
+            tools=tools,
+            tool_map=tool_map,
+            format=format,
+            messages=messages,
+            auto_process_tool_calls=auto_process_tool_calls,
+            **kwargs
+        )
     elif provider == 'lmstudio' or (model and '.lmstudio' in str(model)):
         # LM Studio uses OpenAI-compatible API on port 1234
         # Also detect models with .lmstudio in path (e.g., /home/user/.lmstudio/models/...)