npcpy 1.3.21__py3-none-any.whl → 1.3.23__py3-none-any.whl

This diff compares two publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registries.
npcpy/gen/image_gen.py CHANGED
@@ -34,27 +34,21 @@ def generate_image_diffusers(
  if os.path.exists(checkpoint_path):
  print(f"🌋 Found model_final.pt at {checkpoint_path}.")
 
- # Load checkpoint to inspect it
  checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
 
- # Check if this is a custom SimpleUNet model (from your training code)
- # vs a Stable Diffusion UNet2DConditionModel
  if 'config' in checkpoint and hasattr(checkpoint['config'], 'image_size'):
  print(f"🌋 Detected custom SimpleUNet model, using custom generation")
- # Use your custom generate_image function from npcpy.ft.diff
  from npcpy.ft.diff import generate_image as custom_generate_image
 
- # Your custom model ignores prompts and generates based on training data
  image = custom_generate_image(
  model_path=checkpoint_path,
  prompt=prompt,
  num_samples=1,
- image_size=height # Use the requested height
+ image_size=height
  )
  return image
 
  else:
- # This is a Stable Diffusion checkpoint
  print(f"🌋 Detected Stable Diffusion UNet checkpoint")
  base_model_id = "runwayml/stable-diffusion-v1-5"
  print(f"🌋 Loading base pipeline: {base_model_id}")
@@ -67,7 +61,6 @@ def generate_image_diffusers(
 
  print(f"🌋 Loading custom UNet weights from {checkpoint_path}")
 
- # Extract the actual model state dict
  if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
  unet_state_dict = checkpoint['model_state_dict']
  print(f"🌋 Extracted model_state_dict from checkpoint")
@@ -75,7 +68,6 @@ def generate_image_diffusers(
  unet_state_dict = checkpoint
  print(f"🌋 Using checkpoint directly as state_dict")
 
- # Load the state dict into the UNet
  pipe.unet.load_state_dict(unet_state_dict)
  pipe = pipe.to(device)
  print(f"🌋 Successfully loaded fine-tuned UNet weights")
@@ -100,7 +92,6 @@ def generate_image_diffusers(
  variant="fp16" if torch_dtype == torch.float16 else None,
  )
 
- # Common pipeline setup for Stable Diffusion models
  if hasattr(pipe, 'enable_attention_slicing'):
  pipe.enable_attention_slicing()
 
@@ -142,16 +133,7 @@ def generate_image_diffusers(
  raise MemoryError(f"Insufficient memory for image generation with model {model}. Try a smaller model or reduce image size.")
  else:
  raise e
- import os
- import base64
- import io
- from typing import Union, List, Optional
 
- import PIL
- from PIL import Image
-
- import requests
- from urllib.request import urlopen
 
  def openai_image_gen(
  prompt: str,
@@ -184,13 +166,13 @@ def openai_image_gen(
  files_to_close.append(file_handle)
  elif isinstance(attachment, bytes):
  img_byte_arr = io.BytesIO(attachment)
- img_byte_arr.name = 'image.png' # FIX: Add filename hint
+ img_byte_arr.name = 'image.png'
  processed_images.append(img_byte_arr)
  elif isinstance(attachment, Image.Image):
  img_byte_arr = io.BytesIO()
  attachment.save(img_byte_arr, format='PNG')
  img_byte_arr.seek(0)
- img_byte_arr.name = 'image.png' # FIX: Add filename hint
+ img_byte_arr.name = 'image.png'
  processed_images.append(img_byte_arr)
 
  try:
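The two `img_byte_arr.name = 'image.png'` assignments above give in-memory buffers a filename, which the upload machinery uses to attach a filename and infer a file type when the attachment is raw bytes or a PIL image rather than a path on disk. A small sketch of that preparation step in isolation; the helper name is illustrative and not part of npcpy:

```python
import io
from PIL import Image

def to_named_png_buffer(attachment):
    """Wrap bytes or a PIL image in a BytesIO that carries a .png filename hint."""
    if isinstance(attachment, bytes):
        buf = io.BytesIO(attachment)
    elif isinstance(attachment, Image.Image):
        buf = io.BytesIO()
        attachment.save(buf, format="PNG")
        buf.seek(0)
    else:
        raise TypeError("expected bytes or PIL.Image.Image")
    # Without a .name, multipart uploads may lack a filename/MIME hint.
    buf.name = "image.png"
    return buf

buffer = to_named_png_buffer(Image.new("RGB", (64, 64), "white"))
```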
@@ -202,7 +184,6 @@ def openai_image_gen(
  size=size_str,
  )
  finally:
- # This ensures any files we opened are properly closed
  for f in files_to_close:
  f.close()
  else:
@@ -231,7 +212,6 @@ def openai_image_gen(
  return collected_images
 
 
-
  def gemini_image_gen(
  prompt: str,
  model: str = "gemini-2.5-flash",
@@ -305,18 +285,21 @@ def gemini_image_gen(
  response = client.models.generate_content(
  model=model,
  contents=processed_contents,
+ config=types.GenerateContentConfig(
+ response_modalities=["IMAGE", "TEXT"],
+ ),
  )
-
+
  if hasattr(response, 'candidates') and response.candidates:
  for candidate in response.candidates:
  for part in candidate.content.parts:
  if hasattr(part, 'inline_data') and part.inline_data:
  image_data = part.inline_data.data
  collected_images.append(Image.open(BytesIO(image_data)))
-
+
  if not collected_images and hasattr(response, 'text'):
  print(f"Gemini response text: {response.text}")
-
+
  return collected_images
  else:
  if 'imagen' in model:
@@ -335,6 +318,9 @@ def gemini_image_gen(
  response = client.models.generate_content(
  model=model,
  contents=[prompt],
+ config=types.GenerateContentConfig(
+ response_modalities=["IMAGE", "TEXT"],
+ ),
  )
 
  if hasattr(response, 'candidates') and response.candidates:
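Both Gemini branches now pass a `GenerateContentConfig` with `response_modalities=["IMAGE", "TEXT"]`, which is how the google-genai SDK is asked to return image parts alongside text. A minimal sketch outside npcpy, assuming a `GEMINI_API_KEY` in the environment; the model name is taken from the diff's defaults and the prompt is illustrative:

```python
import os
from io import BytesIO

from google import genai
from google.genai import types
from PIL import Image

client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])

response = client.models.generate_content(
    model="gemini-2.5-flash-image-preview",
    contents=["A watercolor volcano at dusk"],
    config=types.GenerateContentConfig(
        response_modalities=["IMAGE", "TEXT"],
    ),
)

# Image bytes arrive as inline_data parts; text parts may accompany them.
for part in response.candidates[0].content.parts:
    if getattr(part, "inline_data", None):
        Image.open(BytesIO(part.inline_data.data)).save("gemini_out.png")
```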
@@ -351,7 +337,86 @@ def gemini_image_gen(
 
  else:
  raise ValueError(f"Unsupported Gemini image model or API usage for new generation: '{model}'")
- # In npcpy/gen/image_gen.py, find the generate_image function and replace it with this:
+
+
+ def ollama_image_gen(
+ prompt: str,
+ model: str = "x/z-image-turbo",
+ height: int = 512,
+ width: int = 512,
+ n_images: int = 1,
+ api_url: Optional[str] = None,
+ seed: Optional[int] = None,
+ negative_prompt: Optional[str] = None,
+ num_steps: Optional[int] = None,
+ ):
+ """Generate images using Ollama's image generation API.
+
+ Works with ollama image gen models like x/z-image-turbo and x/flux2-klein.
+ Uses the /api/generate endpoint with image gen specific options.
+ """
+ import requests
+
+ if api_url is None:
+ api_url = os.environ.get('OLLAMA_API_URL', 'http://localhost:11434')
+
+ endpoint = f"{api_url}/api/generate"
+
+ collected_images = []
+
+ for _ in range(n_images):
+ options = {}
+ if width:
+ options["width"] = width
+ if height:
+ options["height"] = height
+ if seed is not None:
+ options["seed"] = seed
+ if num_steps is not None:
+ options["num_steps"] = num_steps
+
+ payload = {
+ "model": model,
+ "prompt": prompt,
+ "stream": False,
+ }
+ if options:
+ payload["options"] = options
+ if negative_prompt:
+ payload["negative_prompt"] = negative_prompt
+
+ response = requests.post(endpoint, json=payload)
+
+ if not response.ok:
+ try:
+ err = response.json()
+ err_msg = err.get('error', response.text)
+ except Exception:
+ err_msg = response.text
+ raise RuntimeError(
+ f"Ollama image gen failed ({response.status_code}): {err_msg}\n"
+ f"Model: {model} — make sure it's pulled (`ollama pull {model}`)"
+ )
+
+ result = response.json()
+
+ if 'image' in result and result['image']:
+ image_bytes = base64.b64decode(result['image'])
+ image = Image.open(io.BytesIO(image_bytes))
+ collected_images.append(image)
+ elif 'images' in result and result['images']:
+ for img_b64 in result['images']:
+ image_bytes = base64.b64decode(img_b64)
+ image = Image.open(io.BytesIO(image_bytes))
+ collected_images.append(image)
+ else:
+ raise ValueError(
+ f"No images returned from Ollama. Response keys: {list(result.keys())}. "
+ f"Make sure '{model}' is an image generation model (e.g. x/z-image-turbo, x/flux2-klein)."
+ )
+
+ return collected_images
+
 
  def generate_image(
  prompt: str,
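The new ollama_image_gen helper posts to Ollama's /api/generate endpoint and decodes base64 image data from the response. A standalone sketch of the same request shape, assuming a local Ollama server with an image-generation model already pulled; the response-key handling mirrors the defensive logic in the function above:

```python
import base64
import io

import requests
from PIL import Image

payload = {
    "model": "x/z-image-turbo",   # image-gen model named in the diff; must be pulled locally
    "prompt": "a watercolor volcano at dusk",
    "stream": False,
    "options": {"width": 512, "height": 512},
}

resp = requests.post("http://localhost:11434/api/generate", json=payload)
resp.raise_for_status()
data = resp.json()

# Depending on the model/server version, the result may come back under 'image' or 'images'.
b64_images = [data["image"]] if data.get("image") else data.get("images", [])
for i, b64 in enumerate(b64_images):
    Image.open(io.BytesIO(base64.b64decode(b64))).save(f"ollama_{i}.png")
```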
@@ -364,7 +429,7 @@ def generate_image(
  api_url: Optional[str] = None,
  attachments: Union[List[Union[str, bytes, Image.Image]], None] = None,
  save_path: Optional[str] = None,
- custom_model_path: Optional[str] = None, # <--- NEW: Accept custom_model_path,
+ custom_model_path: Optional[str] = None,
 
  ):
  """
@@ -373,7 +438,7 @@ def generate_image(
  Args:
  prompt (str): The prompt for generating/editing the image.
  model (str): The model to use.
- provider (str): The provider to use ('openai', 'diffusers', 'gemini').
+ provider (str): The provider to use ('openai', 'diffusers', 'gemini', 'ollama').
  height (int): The height of the output image.
  width (int): The width of the output image.
  n_images (int): Number of images to generate.
@@ -381,32 +446,31 @@ def generate_image(
  api_url (str): API URL for the provider.
  attachments (list): List of images for editing. Can be file paths, bytes, or PIL Images.
  save_path (str): Path to save the generated image.
- custom_model_path (str): Path to a locally fine-tuned Diffusers model. <--- NEW
+ custom_model_path (str): Path to a locally fine-tuned Diffusers model.
 
  Returns:
  List[PIL.Image.Image]: A list of generated PIL Image objects.
  """
  from urllib.request import urlopen
- import os # Ensure os is imported for path checks
+ import os
 
- if model is None and custom_model_path is None: # Only set default if no model or custom path is provided
+ if model is None and custom_model_path is None:
  if provider == "openai":
  model = "dall-e-2"
  elif provider == "diffusers":
  model = "runwayml/stable-diffusion-v1-5"
  elif provider == "gemini":
  model = "gemini-2.5-flash-image-preview"
+ elif provider == "ollama":
+ model = "x/z-image-turbo"
 
  all_generated_pil_images = []
 
- # <--- CRITICAL FIX: Handle custom_model_path for Diffusers here
  if provider == "diffusers":
- # If a custom_model_path is provided and exists, use it instead of a generic model name
  if custom_model_path and os.path.isdir(custom_model_path):
  print(f"🌋 Using custom Diffusers model from path: {custom_model_path}")
  model_to_use = custom_model_path
  else:
- # Otherwise, use the standard model name (e.g., "runwayml/stable-diffusion-v1-5")
  model_to_use = model
  print(f"🌋 Using standard Diffusers model: {model_to_use}")
 
@@ -414,7 +478,7 @@ def generate_image(
  try:
  image = generate_image_diffusers(
  prompt=prompt,
- model=model_to_use, # <--- Pass the resolved model_to_use
+ model=model_to_use,
  height=height,
  width=width
  )
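generate_image now dispatches on provider ('openai', 'diffusers', 'gemini', 'ollama', or a litellm fallback) and resolves custom_model_path before the diffusers branch runs. A hedged usage sketch based on the signature shown across the preceding hunks; paths and prompts are illustrative:

```python
from npcpy.gen.image_gen import generate_image

# Local Ollama image model (defaults to "x/z-image-turbo" when model is None).
images = generate_image(
    prompt="a watercolor volcano at dusk",
    provider="ollama",
    height=512,
    width=512,
    n_images=1,
)

# A locally fine-tuned Diffusers model directory takes precedence over the model name.
images = generate_image(
    prompt="a watercolor volcano at dusk",
    provider="diffusers",
    custom_model_path="/path/to/finetuned-diffusers-model",  # illustrative path
)

images[0].save("generated.png")
```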
@@ -447,43 +511,29 @@ def generate_image(
  )
  all_generated_pil_images.extend(images)
 
+ elif provider == "ollama":
+ images = ollama_image_gen(
+ prompt=prompt,
+ model=model,
+ height=height,
+ width=width,
+ n_images=n_images,
+ api_url=api_url
+ )
+ all_generated_pil_images.extend(images)
+
  else:
- # This is the fallback for other providers or if provider is not explicitly handled
  valid_sizes = ["256x256", "512x512", "1024x1024", "1024x1792", "1792x1024"]
  size = f"{width}x{height}"
 
  if attachments is not None:
  raise ValueError("Image editing not supported with litellm provider")
 
- # The litellm.image_generation function expects the provider as part of the model string
- # e.g., "huggingface/starcoder" or "openai/dall-e-3"
- # Since we've already handled "diffusers", "openai", "gemini" above,
- # this 'else' block implies a generic litellm call.
- # We need to ensure the model string is correctly formatted for litellm.
- # However, the error message "LLM Provider NOT provided" suggests litellm
- # is not even getting the `provider` correctly.
- # The fix for this is ensuring the `provider` is explicitly passed to litellm.image_generation
- # which is already happening in `gen_image` in `llm_funcs.py`
-
- # If we reach here, it means the provider is not 'diffusers', 'openai', or 'gemini',
- # and litellm is the intended route. We need to pass the provider explicitly.
- # The original code here was trying to construct `model=f"{provider}/{model}"`
- # but the error indicates `provider` itself was missing.
- # The `image_generation` from litellm expects `model` to be `provider/model_name`.
- # Since the `provider` variable is available, we can construct this.
-
- # This block is for generic litellm providers (not diffusers, openai, gemini)
- # The error indicates `provider` itself was not making it to litellm.
- # This `generate_image` function already receives `provider`.
- # The issue is likely how `gen_image` in `llm_funcs.py` calls this `generate_image`.
- # However, if this `else` branch is hit, we ensure litellm gets the provider.
-
- # Construct the model string for litellm
  litellm_model_string = f"{provider}/{model}" if provider and model else model
 
  image_response = image_generation(
  prompt=prompt,
- model=litellm_model_string, # <--- Ensure model string includes provider for litellm
+ model=litellm_model_string,
  n=n_images,
  size=size,
  api_key=api_key,
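For providers not handled explicitly, the fallback builds a litellm-style model string of the form "{provider}/{model}" before calling image_generation. A short sketch of that convention using litellm directly; the provider prefix here is only to show the format (in npcpy the OpenAI provider is handled earlier, so this branch would see other providers):

```python
from litellm import image_generation

provider = "openai"      # illustrative: any image-capable litellm provider
model = "dall-e-3"       # illustrative model name

response = image_generation(
    prompt="a watercolor volcano at dusk",
    model=f"{provider}/{model}",  # litellm expects provider-prefixed model strings
    n=1,
    size="1024x1024",
)
```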
@@ -509,6 +559,7 @@ def generate_image(
 
  return all_generated_pil_images
 
+
  def edit_image(
  prompt: str,
  image_path: str,
npcpy/gen/response.py CHANGED
@@ -830,6 +830,234 @@ def get_llamacpp_response(
  return result
 
 
+ _AIRLLM_MODEL_CACHE = {}
+ _AIRLLM_MLX_PATCHED = False
+
+ def _patch_airllm_mlx_bias():
+ """
+ Monkey-patch airllm's MLX Attention/FeedForward to use bias=True.
+ AirLLM hardcodes bias=False which fails for non-Llama architectures (e.g. Qwen2).
+ Using bias=True is safe: MLX nn.Linear(bias=True) accepts weight-only updates,
+ so Llama models (no bias in weights) still work correctly.
+ """
+ global _AIRLLM_MLX_PATCHED
+ if _AIRLLM_MLX_PATCHED:
+ return
+ try:
+ import airllm.airllm_llama_mlx as mlx_mod
+ import mlx.core as mx
+ from mlx import nn
+
+ class PatchedAttention(nn.Module):
+ def __init__(self, args):
+ super().__init__()
+ self.args = args
+ self.n_heads = args.n_heads
+ self.n_kv_heads = args.n_kv_heads
+ self.repeats = self.n_heads // self.n_kv_heads
+ self.scale = args.head_dim ** -0.5
+ self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=True)
+ self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=True)
+ self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=True)
+ self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=True)
+ self.rope = nn.RoPE(
+ args.head_dim, traditional=args.rope_traditional, base=args.rope_theta
+ )
+
+ def __call__(self, x, mask=None, cache=None):
+ B, L, D = x.shape
+ queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
+ queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+ keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+ values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+ def repeat(a):
+ a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
+ return a.reshape([B, self.n_heads, L, -1])
+ keys, values = map(repeat, (keys, values))
+
+ if cache is not None:
+ key_cache, value_cache = cache
+ queries = self.rope(queries, offset=key_cache.shape[2])
+ keys = self.rope(keys, offset=key_cache.shape[2])
+ keys = mx.concatenate([key_cache, keys], axis=2)
+ values = mx.concatenate([value_cache, values], axis=2)
+ else:
+ queries = self.rope(queries)
+ keys = self.rope(keys)
+
+ scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
+ if mask is not None:
+ scores += mask
+ weights = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
+ output = (weights @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+ return self.wo(output), (keys, values)
+
+ class PatchedFeedForward(nn.Module):
+ def __init__(self, args):
+ super().__init__()
+ self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=True)
+ self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=True)
+ self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=True)
+
+ def __call__(self, x):
+ return self.w2(nn.silu(self.w1(x)) * self.w3(x))
+
+ mlx_mod.Attention = PatchedAttention
+ mlx_mod.FeedForward = PatchedFeedForward
+ _AIRLLM_MLX_PATCHED = True
+ logger.debug("Patched airllm MLX classes for bias support")
+ except Exception as e:
+ logger.warning(f"Failed to patch airllm MLX bias support: {e}")
+
+ def get_airllm_response(
+ prompt: str = None,
+ model: str = None,
+ tools: list = None,
+ tool_map: Dict = None,
+ format: str = None,
+ messages: List[Dict[str, str]] = None,
+ auto_process_tool_calls: bool = False,
+ **kwargs,
+ ) -> Dict[str, Any]:
+ """
+ Generate response using AirLLM for 70B+ model inference.
+ Supports macOS (MLX backend) and Linux (CUDA backend with 4-bit compression).
+ """
+ import platform
+ is_macos = platform.system() == "Darwin"
+
+ result = {
+ "response": None,
+ "messages": messages.copy() if messages else [],
+ "raw_response": None,
+ "tool_calls": [],
+ "tool_results": []
+ }
+
+ try:
+ from airllm import AutoModel
+ except ImportError:
+ result["response"] = ""
+ result["error"] = "airllm not installed. Install with: pip install airllm"
+ return result
+
+ # Patch airllm MLX classes to support models with bias (e.g. Qwen)
+ if is_macos:
+ _patch_airllm_mlx_bias()
+
+ if prompt:
+ if result['messages'] and result['messages'][-1]["role"] == "user":
+ result['messages'][-1]["content"] = prompt
+ else:
+ result['messages'].append({"role": "user", "content": prompt})
+
+ if format == "json":
+ json_instruction = """If you are returning a json object, begin directly with the opening {.
+ Do not include any additional markdown formatting or leading ```json tags in your response."""
+ if result["messages"] and result["messages"][-1]["role"] == "user":
+ result["messages"][-1]["content"] += "\n" + json_instruction
+
+ model_name = model or "meta-llama/Meta-Llama-3.1-70B-Instruct"
+ # 4-bit compression requires CUDA via bitsandbytes; skip on macOS
+ default_compression = None if is_macos else "4bit"
+ compression = kwargs.get("compression", default_compression)
+ max_tokens = kwargs.get("max_tokens", 256)
+ temperature = kwargs.get("temperature", 0.7)
+
+ # Resolve HF token for gated model access
+ hf_token = kwargs.get("hf_token")
+ if not hf_token:
+ hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
+ if not hf_token:
+ try:
+ from huggingface_hub import HfFolder
+ hf_token = HfFolder.get_token()
+ except Exception:
+ pass
+
+ # Load or retrieve cached model
+ cache_key = f"{model_name}:{compression}"
+ if cache_key not in _AIRLLM_MODEL_CACHE:
+ load_kwargs = {"pretrained_model_name_or_path": model_name}
+ if compression:
+ load_kwargs["compression"] = compression
+ if hf_token:
+ load_kwargs["hf_token"] = hf_token
+ # Pass through additional airllm kwargs
+ for k in ["delete_original", "max_seq_len", "prefetching"]:
+ if k in kwargs:
+ load_kwargs[k] = kwargs[k]
+ _AIRLLM_MODEL_CACHE[cache_key] = AutoModel.from_pretrained(**load_kwargs)
+
+ air_model = _AIRLLM_MODEL_CACHE[cache_key]
+
+ try:
+ chat_text = air_model.tokenizer.apply_chat_template(
+ result["messages"], tokenize=False, add_generation_prompt=True
+ )
+ except Exception:
+ # Fallback if chat template is not available
+ chat_text = "\n".join(
+ f"{m['role']}: {m['content']}" for m in result["messages"]
+ )
+ chat_text += "\nassistant:"
+
+ try:
+ if is_macos:
+ import mlx.core as mx
+ tokens = air_model.tokenizer(
+ chat_text, return_tensors="np", truncation=True, max_length=2048
+ )
+ output = air_model.generate(
+ mx.array(tokens['input_ids']),
+ max_new_tokens=max_tokens,
+ )
+ # MLX backend returns string directly
+ response_content = output if isinstance(output, str) else str(output)
+ else:
+ tokens = air_model.tokenizer(
+ chat_text, return_tensors="pt", truncation=True, max_length=2048
+ )
+ gen_out = air_model.generate(
+ tokens['input_ids'].cuda(),
+ max_new_tokens=max_tokens,
+ )
+ # CUDA backend returns token IDs, decode them
+ output_ids = gen_out.sequences[0] if hasattr(gen_out, 'sequences') else gen_out[0]
+ response_content = air_model.tokenizer.decode(output_ids, skip_special_tokens=True)
+ # Strip the input prompt from the output
+ input_text = air_model.tokenizer.decode(tokens['input_ids'][0], skip_special_tokens=True)
+ if response_content.startswith(input_text):
+ response_content = response_content[len(input_text):]
+
+ response_content = response_content.strip()
+ # Strip at common stop/special tokens that airllm doesn't handle
+ for stop_tok in ["<|im_end|>", "<|endoftext|>", "<|eot_id|>", "</s>"]:
+ if stop_tok in response_content:
+ response_content = response_content[:response_content.index(stop_tok)].strip()
+ except Exception as e:
+ logger.error(f"AirLLM inference error: {e}")
+ result["error"] = f"AirLLM inference error: {str(e)}"
+ result["response"] = ""
+ return result
+
+ result["response"] = response_content
+ result["raw_response"] = response_content
+ result["messages"].append({"role": "assistant", "content": response_content})
+
+ if format == "json":
+ try:
+ if response_content.startswith("```json"):
+ response_content = response_content.replace("```json", "").replace("```", "").strip()
+ parsed_response = json.loads(response_content)
+ result["response"] = parsed_response
+ except json.JSONDecodeError:
+ result["error"] = f"Invalid JSON response: {response_content}"
+
+ return result
+
+
  def get_litellm_response(
  prompt: str = None,
  model: str = None,
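get_airllm_response keeps loaded AirLLM models in a module-level cache keyed by model name and compression, resolves a Hugging Face token from kwargs, environment variables, or the local huggingface_hub config, and selects the MLX or CUDA path by platform. A hedged usage sketch based on the signature added in this diff; the model name is the function's documented default and the extra kwargs are the ones the code reads:

```python
from npcpy.gen.response import get_airllm_response

result = get_airllm_response(
    prompt="Summarize the benefits of layered model loading in two sentences.",
    model="meta-llama/Meta-Llama-3.1-70B-Instruct",  # default used when model is None
    messages=[],
    max_tokens=128,      # read from kwargs; defaults to 256
    compression="4bit",  # CUDA default; on macOS the default is None and the MLX path is used
)

if result.get("error"):
    print("AirLLM error:", result["error"])
else:
    print(result["response"])

# A second call with the same model and compression reuses the cached AutoModel instance.
```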
@@ -921,6 +1149,17 @@ def get_litellm_response(
  auto_process_tool_calls=auto_process_tool_calls,
  **kwargs
  )
+ elif provider == 'airllm':
+ return get_airllm_response(
+ prompt=prompt,
+ model=model,
+ tools=tools,
+ tool_map=tool_map,
+ format=format,
+ messages=messages,
+ auto_process_tool_calls=auto_process_tool_calls,
+ **kwargs
+ )
  elif provider == 'lmstudio' or (model and '.lmstudio' in str(model)):
  # LM Studio uses OpenAI-compatible API on port 1234
  # Also detect models with .lmstudio in path (e.g., /home/user/.lmstudio/models/...)
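With the routing added above, the AirLLM backend is reachable through the same dispatcher as the other local providers. A minimal sketch, assuming get_litellm_response accepts a provider argument as the surrounding branches imply (the exact call site and parameter names in npcpy may differ):

```python
from npcpy.gen.response import get_litellm_response

result = get_litellm_response(
    prompt="List three uses of on-disk layered inference.",
    model="meta-llama/Meta-Llama-3.1-70B-Instruct",  # illustrative 70B model
    provider="airllm",  # routes to get_airllm_response via the new elif branch
    messages=[],
)
print(result["response"])
```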