inferencesh 0.2.31__py3-none-any.whl → 0.4.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
inferencesh/models/llm.py CHANGED
@@ -6,6 +6,7 @@ from threading import Thread
  import time
  from contextlib import contextmanager
  import base64
+ import json
 
  from .base import BaseAppInput, BaseAppOutput
  from .file import File
@@ -14,40 +15,48 @@ class ContextMessageRole(str, Enum):
      USER = "user"
      ASSISTANT = "assistant"
      SYSTEM = "system"
+     TOOL = "tool"
 
 
  class Message(BaseAppInput):
      role: ContextMessageRole
      content: str
 
-
  class ContextMessage(BaseAppInput):
      role: ContextMessageRole = Field(
-         description="The role of the message",
+         description="the role of the message. user, assistant, or system",
      )
      text: str = Field(
-         description="The text content of the message"
+         description="the text content of the message"
      )
      image: Optional[File] = Field(
-         description="The image url of the message",
+         description="the image file of the message",
+         default=None
+     )
+     images: Optional[List[File]] = Field(
+         description="the images of the message",
+         default=None
+     )
+     tool_calls: Optional[List[Dict[str, Any]]] = Field(
+         description="the tool calls of the message",
+         default=None
+     )
+     tool_call_id: Optional[str] = Field(
+         description="the tool call id for tool role messages",
          default=None
      )
 
  class BaseLLMInput(BaseAppInput):
      """Base class with common LLM fields."""
      system_prompt: str = Field(
-         description="The system prompt to use for the model",
-         default="You are a helpful assistant that can answer questions and help with tasks.",
+         description="the system prompt to use for the model",
+         default="you are a helpful assistant that can answer questions and help with tasks.",
          examples=[
-             "You are a helpful assistant that can answer questions and help with tasks.",
-             "You are a certified medical professional who can provide accurate health information.",
-             "You are a certified financial advisor who can give sound investment guidance.",
-             "You are a certified cybersecurity expert who can explain security best practices.",
-             "You are a certified environmental scientist who can discuss climate and sustainability.",
+             "you are a helpful assistant that can answer questions and help with tasks.",
          ]
      )
      context: List[ContextMessage] = Field(
-         description="The context to use for the model",
+         description="the context to use for the model",
          default=[],
          examples=[
              [
@@ -56,38 +65,50 @@ class BaseLLMInput(BaseAppInput):
              ]
          ]
      )
+     role: ContextMessageRole = Field(
+         description="the role of the input text",
+         default=ContextMessageRole.USER
+     )
      text: str = Field(
-         description="The user prompt to use for the model",
+         description="the input text to use for the model",
          examples=[
-             "What is the capital of France?",
-             "What is the weather like today?",
-             "Can you help me write a poem about spring?",
-             "Explain quantum computing in simple terms"
+             "write a haiku about artificial general intelligence"
          ]
      )
-     temperature: float = Field(default=0.7)
-     top_p: float = Field(default=0.95)
-     max_tokens: int = Field(default=4096)
+     temperature: float = Field(default=0.7, ge=0.0, le=1.0)
+     top_p: float = Field(default=0.95, ge=0.0, le=1.0)
      context_size: int = Field(default=4096)
 
  class ImageCapabilityMixin(BaseModel):
      """Mixin for models that support image inputs."""
      image: Optional[File] = Field(
-         description="The image to use for the model",
-         default=None
+         description="the image to use for the model",
+         default=None,
+         contentMediaType="image/*",
+     )
+
+ class MultipleImageCapabilityMixin(BaseModel):
+     """Mixin for models that support multiple image inputs."""
+     images: Optional[List[File]] = Field(
+         description="the images to use for the model",
+         default=None,
      )
 
  class ReasoningCapabilityMixin(BaseModel):
      """Mixin for models that support reasoning."""
      reasoning: bool = Field(
-         description="Enable step-by-step reasoning",
+         description="enable step-by-step reasoning",
          default=False
      )
 
  class ToolsCapabilityMixin(BaseModel):
      """Mixin for models that support tool/function calling."""
      tools: Optional[List[Dict[str, Any]]] = Field(
-         description="Tool definitions for function calling",
+         description="tool definitions for function calling",
+         default=None
+     )
+     tool_call_id: Optional[str] = Field(
+         description="the tool call id for tool role messages",
          default=None
      )
 
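For orientation, a minimal sketch (not part of the diff) of how the new tool-related fields on ContextMessage might be populated when replaying a function-calling turn; the tool name, call id, and arguments below are invented for illustration:

    from inferencesh.models.llm import ContextMessage, ContextMessageRole

    # Assistant turn that requested a tool call. Arguments can stay a dict here,
    # since build_messages JSON-encodes them before they reach the API.
    assistant_msg = ContextMessage(
        role=ContextMessageRole.ASSISTANT,
        text="",
        tool_calls=[{
            "id": "call_0",  # hypothetical id
            "type": "function",
            "function": {"name": "get_weather", "arguments": {"city": "Paris"}},
        }],
    )

    # Tool turn carrying the result, echoing the id it answers.
    tool_msg = ContextMessage(
        role=ContextMessageRole.TOOL,
        text='{"temperature_c": 21}',
        tool_call_id="call_0",
    )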
@@ -112,26 +133,26 @@ class LLMUsage(BaseAppOutput):
 
  class BaseLLMOutput(BaseAppOutput):
      """Base class for LLM outputs with common fields."""
-     response: str = Field(description="The generated text response")
+     response: str = Field(description="the generated text response")
 
  class LLMUsageMixin(BaseModel):
      """Mixin for models that provide token usage statistics."""
      usage: Optional[LLMUsage] = Field(
-         description="Token usage statistics",
+         description="token usage statistics",
          default=None
      )
 
  class ReasoningMixin(BaseModel):
      """Mixin for models that support reasoning."""
      reasoning: Optional[str] = Field(
-         description="The reasoning output of the model",
+         description="the reasoning output of the model",
          default=None
      )
 
  class ToolCallsMixin(BaseModel):
      """Mixin for models that support tool calls."""
      tool_calls: Optional[List[Dict[str, Any]]] = Field(
-         description="Tool calls for function calling",
+         description="tool calls for function calling",
          default=None
      )
 
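The capability mixins above are presumably meant to be composed with the base input/output classes via multiple inheritance; a hedged sketch, with app-specific class names invented for illustration:

    # Hypothetical input/output pair for a multi-image, tool-aware app.
    class MyAppInput(MultipleImageCapabilityMixin, ToolsCapabilityMixin, BaseLLMInput):
        pass

    class MyAppOutput(ToolCallsMixin, LLMUsageMixin, BaseLLMOutput):
        pass

    inp = MyAppInput(text="describe these images")
    out = MyAppOutput(response="two cats on a sofa")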
@@ -217,28 +238,75 @@ def build_messages(
          text = transform_user_message(msg.text) if transform_user_message and msg.role == ContextMessageRole.USER else msg.text
          if text:
              parts.append({"type": "text", "text": text})
+         else:
+             parts.append({"type": "text", "text": ""})
          if msg.image:
              if msg.image.path:
                  image_data_uri = image_to_base64_data_uri(msg.image.path)
                  parts.append({"type": "image_url", "image_url": {"url": image_data_uri}})
              elif msg.image.uri:
                  parts.append({"type": "image_url", "image_url": {"url": msg.image.uri}})
+         if msg.images:
+             for image in msg.images:
+                 if image.path:
+                     image_data_uri = image_to_base64_data_uri(image.path)
+                     parts.append({"type": "image_url", "image_url": {"url": image_data_uri}})
+                 elif image.uri:
+                     parts.append({"type": "image_url", "image_url": {"url": image.uri}})
          if allow_multipart:
              return parts
          if len(parts) == 1 and parts[0]["type"] == "text":
              return parts[0]["text"]
-         raise ValueError("Image content requires multipart support")
+         if len(parts) > 1:
+             if any(part["type"] == "image_url" for part in parts):
+                 raise ValueError("Image content requires multipart support")
+             return parts
+         raise ValueError("Invalid message content")
 
-     multipart = any(m.image for m in input_data.context) or input_data.image is not None
      messages = [{"role": "system", "content": input_data.system_prompt}] if input_data.system_prompt is not None and input_data.system_prompt != "" else []
 
      def merge_messages(messages: List[ContextMessage]) -> ContextMessage:
          text = "\n\n".join(msg.text for msg in messages if msg.text)
-         images = [msg.image for msg in messages if msg.image]
-         image = images[0] if images else None  # TODO: handle multiple images
-         return ContextMessage(role=messages[0].role, text=text, image=image)
+         images = []
+         # Collect single images
+         for msg in messages:
+             if msg.image:
+                 images.append(msg.image)
+         # Collect multiple images (flatten the list)
+         for msg in messages:
+             if msg.images:
+                 images.extend(msg.images)
+         # Set image to single File if there's exactly one, otherwise None
+         image = images[0] if len(images) == 1 else None
+         # Set images to the list if there are multiple, otherwise None
+         images_list = images if len(images) > 1 else None
+         return ContextMessage(role=messages[0].role, text=text, image=image, images=images_list)
+
+     def merge_tool_calls(messages: List[ContextMessage]) -> List[Dict[str, Any]]:
+         tool_calls = []
+         for msg in messages:
+             if msg.tool_calls:
+                 tool_calls.extend(msg.tool_calls)
+         return tool_calls
+
+     user_input_text = ""
+     if hasattr(input_data, "text"):
+         user_input_text = transform_user_message(input_data.text) if transform_user_message else input_data.text
+
+     user_input_image = None
+     multipart = any(m.image for m in input_data.context)
+     if hasattr(input_data, "image"):
+         user_input_image = input_data.image
+         multipart = multipart or input_data.image is not None
+
+     user_input_images = None
+     if hasattr(input_data, "images"):
+         user_input_images = input_data.images
+         multipart = multipart or input_data.images is not None
 
-     user_msg = ContextMessage(role=ContextMessageRole.USER, text=input_data.text, image=input_data.image)
+     input_role = input_data.role if hasattr(input_data, "role") else ContextMessageRole.USER
+     input_tool_call_id = input_data.tool_call_id if hasattr(input_data, "tool_call_id") else None
+     user_msg = ContextMessage(role=input_role, text=user_input_text, image=user_input_image, images=user_input_images, tool_call_id=input_tool_call_id)
 
      input_data.context.append(user_msg)
 
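build_messages now probes the input object with hasattr instead of assuming fixed fields, so the same function serves input classes with or without image, images, role, or tool_call_id. A hedged usage sketch (the full signature is truncated in this diff; the body implies at least input_data plus an optional transform_user_message callback):

    messages = build_messages(input_data)
    # Text-only input keeps string content:
    #   [{"role": "system", "content": "..."},
    #    {"role": "user", "content": "write a haiku about artificial general intelligence"}]
    # An input carrying images switches the user turn to multipart content:
    #   {"role": "user", "content": [{"type": "text", "text": "..."},
    #                                {"type": "image_url", "image_url": {"url": "data:image/..."}}]}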
@@ -250,21 +318,104 @@ def build_messages(
              current_messages.append(msg)
              current_role = msg.role
          else:
-             messages.append({
-                 "role": current_role,
-                 "content": render_message(merge_messages(current_messages), allow_multipart=multipart)
-             })
+             # Convert role enum to string for OpenAI API compatibility
+             role_str = current_role.value if hasattr(current_role, "value") else current_role
+             msg_dict = {
+                 "role": role_str,
+                 "content": render_message(merge_messages(current_messages), allow_multipart=multipart),
+             }
+
+             # Only add tool_calls if not empty
+             tool_calls = merge_tool_calls(current_messages)
+             if tool_calls:
+                 # Ensure arguments are JSON strings (OpenAI API requirement)
+                 for tc in tool_calls:
+                     if "function" in tc and "arguments" in tc["function"]:
+                         if isinstance(tc["function"]["arguments"], dict):
+                             tc["function"]["arguments"] = json.dumps(tc["function"]["arguments"])
+                 msg_dict["tool_calls"] = tool_calls
+
+             # Add tool_call_id for tool role messages (required by OpenAI API)
+             if role_str == "tool":
+                 if current_messages and current_messages[0].tool_call_id:
+                     msg_dict["tool_call_id"] = current_messages[0].tool_call_id
+                 else:
+                     # If not provided, use empty string to satisfy schema
+                     msg_dict["tool_call_id"] = ""
+
+             messages.append(msg_dict)
              current_messages = [msg]
              current_role = msg.role
+
      if len(current_messages) > 0:
-         messages.append({
-             "role": current_role,
-             "content": render_message(merge_messages(current_messages), allow_multipart=multipart)
-         })
+         # Convert role enum to string for OpenAI API compatibility
+         role_str = current_role.value if hasattr(current_role, "value") else current_role
+         msg_dict = {
+             "role": role_str,
+             "content": render_message(merge_messages(current_messages), allow_multipart=multipart),
+         }
+
+         # Only add tool_calls if not empty
+         tool_calls = merge_tool_calls(current_messages)
+         if tool_calls:
+             # Ensure arguments are JSON strings (OpenAI API requirement)
+             for tc in tool_calls:
+                 if "function" in tc and "arguments" in tc["function"]:
+                     if isinstance(tc["function"]["arguments"], dict):
+                         tc["function"]["arguments"] = json.dumps(tc["function"]["arguments"])
+             msg_dict["tool_calls"] = tool_calls
+
+         # Add tool_call_id for tool role messages (required by OpenAI API)
+         if role_str == "tool":
+             if current_messages and current_messages[0].tool_call_id:
+                 msg_dict["tool_call_id"] = current_messages[0].tool_call_id
+             else:
+                 # If not provided, use empty string to satisfy schema
+                 msg_dict["tool_call_id"] = ""
+
+         messages.append(msg_dict)
 
      return messages
 
 
+ def build_tools(tools: Optional[List[Dict[str, Any]]]) -> Optional[List[Dict[str, Any]]]:
+     """Build tools in OpenAI API format.
+
+     Ensures tools are properly formatted:
+     - Wrapped in {"type": "function", "function": {...}}
+     - Parameters is never None (OpenAI API requirement)
+     """
+     if not tools:
+         return None
+
+     result = []
+     for tool in tools:
+         # Extract function definition
+         if "type" in tool and "function" in tool:
+             func_def = tool["function"].copy()
+         else:
+             func_def = tool.copy()
+
+         # Ensure parameters is not None (OpenAI API requirement)
+         if func_def.get("parameters") is None:
+             func_def["parameters"] = {"type": "object", "properties": {}}
+         # Also ensure properties within parameters is not None
+         elif func_def["parameters"].get("properties") is None:
+             func_def["parameters"]["properties"] = {}
+         else:
+             # Remove properties with null values (OpenAI API doesn't accept them)
+             properties = func_def["parameters"].get("properties", {})
+             if properties:
+                 func_def["parameters"]["properties"] = {
+                     k: v for k, v in properties.items() if v is not None
+                 }
+
+         # Wrap in OpenAI format
+         result.append({"type": "function", "function": func_def})
+
+     return result
+
+
  class StreamResponse:
      """Holds a single chunk of streamed response."""
      def __init__(self):
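A short sketch of what the new build_tools normalization does with the two accepted input shapes (the tool names are illustrative):

    tools = build_tools([
        {"name": "ping", "parameters": None},   # bare function definition
        {"type": "function", "function": {      # already wrapped
            "name": "echo",
            "parameters": {"type": "object",
                           "properties": {"msg": {"type": "string"}}},
        }},
    ])
    # tools[0] == {"type": "function",
    #              "function": {"name": "ping",
    #                           "parameters": {"type": "object", "properties": {}}}}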
@@ -462,26 +613,27 @@ class ResponseTransformer:
              text: Cleaned text to process for reasoning
          """
          # Default implementation for <think> style reasoning
-         if "<think>" in text and not self.state.state_changes["reasoning_started"]:
+         # Check for tags in the complete buffer
+         if "<think>" in self.state.buffer and not self.state.state_changes["reasoning_started"]:
              self.state.state_changes["reasoning_started"] = True
              if self.timing:
                  self.timing.start_reasoning()
 
-         if "</think>" in text and not self.state.state_changes["reasoning_ended"]:
-             self.state.state_changes["reasoning_ended"] = True
-             if self.timing:
-                 # Estimate token count from character count (rough approximation)
-                 token_count = len(self.state.buffer.split("<think>")[1].split("</think>")[0]) // 4
-                 self.timing.end_reasoning(token_count)
-
-         if "<think>" in self.state.buffer:
-             parts = self.state.buffer.split("</think>", 1)
-             if len(parts) > 1:
-                 self.state.reasoning = parts[0].split("<think>", 1)[1].strip()
-                 self.state.response = parts[1].strip()
-             else:
-                 self.state.reasoning = self.state.buffer.split("<think>", 1)[1].strip()
-                 self.state.response = ""
+         # Extract content and handle end of reasoning
+         parts = self.state.buffer.split("<think>", 1)
+         if len(parts) > 1:
+             reasoning_text = parts[1]
+             end_parts = reasoning_text.split("</think>", 1)
+             self.state.reasoning = end_parts[0].strip()
+             self.state.response = end_parts[1].strip() if len(end_parts) > 1 else ""
+
+             # Check for end tag in complete buffer
+             if "</think>" in self.state.buffer and not self.state.state_changes["reasoning_ended"]:
+                 self.state.state_changes["reasoning_ended"] = True
+                 if self.timing:
+                     # Estimate token count from character count (rough approximation)
+                     token_count = len(self.state.reasoning) // 4
+                     self.timing.end_reasoning(token_count)
          else:
              self.state.response = self.state.buffer
 
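Because the parser now re-derives both fields from the whole accumulated buffer on every call, the split is easy to illustrate standalone:

    buffer = "<think>compare options, pick 42</think>The answer is 42."
    parts = buffer.split("<think>", 1)
    reasoning_text = parts[1]
    end_parts = reasoning_text.split("</think>", 1)
    reasoning = end_parts[0].strip()  # "compare options, pick 42"
    response = end_parts[1].strip() if len(end_parts) > 1 else ""  # "The answer is 42."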
@@ -579,13 +731,13 @@ def stream_generate(
      tool_choice: Optional[Dict[str, Any]] = None,
      temperature: float = 0.7,
      top_p: float = 0.95,
-     max_tokens: int = 4096,
      stop: Optional[List[str]] = None,
      verbose: bool = False,
      output_cls: type[BaseLLMOutput] = LLMOutput,
+     kwargs: Optional[Dict[str, Any]] = None,
  ) -> Generator[BaseLLMOutput, None, None]:
      """Stream generate from LLaMA.cpp model with timing and usage tracking."""
-
+
      # Create queues for communication between threads
      response_queue = Queue()
      error_queue = Queue()
@@ -603,9 +755,10 @@
          "stream": True,
          "temperature": temperature,
          "top_p": top_p,
-         "max_tokens": max_tokens,
-         "stop": stop
+         "stop": stop,
      }
+     if kwargs:
+         completion_kwargs.update(kwargs)
      if tools is not None:
          completion_kwargs["tools"] = tools
      if tool_choice is not None:
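With max_tokens dropped from both the signature and the fixed completion kwargs, callers presumably pass it (and any other create_chat_completion option) through the new kwargs dict. A hedged sketch; the leading positional arguments (model, messages) are assumed from context:

    for out in stream_generate(model, messages, kwargs={"max_tokens": 512}):
        ...  # each out is a BaseLLMOutput snapshot of the stream so far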
@@ -617,8 +770,6 @@
      completion = model.create_chat_completion(**completion_kwargs)
 
      for chunk in completion:
-         if verbose:
-             print(chunk)
          response_queue.put(("chunk", chunk))
          # Update keep-alive timestamp
          keep_alive_queue.put(("alive", time.time()))
@@ -627,7 +778,9 @@
          response_queue.put(("done", None))
 
      except Exception as e:
-         error_queue.put(e)
+         # Preserve the full exception with traceback
+         import sys
+         error_queue.put((e, sys.exc_info()[2]))
          response_queue.put(("error", str(e)))
 
  with timing_context() as timing:
  with timing_context() as timing:
@@ -645,6 +798,7 @@ def stream_generate(
645
798
  last_activity = time.time()
646
799
  init_timeout = 30.0 # 30 seconds for initial response
647
800
  chunk_timeout = 10.0 # 10 seconds between chunks
801
+ chunks_begun = False
648
802
 
649
803
  try:
650
804
  # Wait for initial setup
@@ -657,17 +811,25 @@
              raise RuntimeError(f"Model failed to initialize within {init_timeout} seconds")
 
          while True:
-             # Check for errors
+             # Check for errors - now with proper exception chaining
              if not error_queue.empty():
-                 raise error_queue.get()
+                 exc, tb = error_queue.get()
+                 if isinstance(exc, Exception):
+                     raise exc.with_traceback(tb)
+                 else:
+                     raise RuntimeError(f"Unknown error in worker thread: {exc}")
 
              # Check keep-alive
-             while not keep_alive_queue.empty():
-                 _, timestamp = keep_alive_queue.get_nowait()
-                 last_activity = timestamp
+             try:
+                 while not keep_alive_queue.empty():
+                     _, timestamp = keep_alive_queue.get_nowait()
+                     last_activity = timestamp
+             except Empty:
+                 # Ignore empty queue - this is expected
+                 pass
 
              # Check for timeout
-             if time.time() - last_activity > chunk_timeout:
+             if chunks_begun and time.time() - last_activity > chunk_timeout:
                  raise RuntimeError(f"No response from model for {chunk_timeout} seconds")
 
              # Get next chunk
@@ -677,16 +839,23 @@
                  continue
 
              if msg_type == "error":
+                 # If we get an error message but no exception in error_queue,
+                 # create a new error
                  raise RuntimeError(f"Generation error: {data}")
              elif msg_type == "done":
                  break
 
              chunk = data
 
+             if verbose:
+                 print(chunk)
+
              # Mark first token time
              if not timing.first_token_time:
                  timing.mark_first_token()
 
+             chunks_begun = True
+
              # Update response state from chunk
              response.update_from_chunk(chunk, timing)
 
@@ -700,12 +869,17 @@
                  break
 
          # Wait for generation thread to finish
-         generation_thread.join(timeout=5.0)  # Increased timeout to 5 seconds
          if generation_thread.is_alive():
-             # Thread didn't finish - this shouldn't happen normally
-             # but we handle it gracefully
-             raise RuntimeError("Generation thread failed to finish")
+             generation_thread.join(timeout=5.0)  # Increased timeout to 5 seconds
+             if generation_thread.is_alive():
+                 # Thread didn't finish - this shouldn't happen normally
+                 raise RuntimeError("Generation thread failed to finish")
 
      except Exception as e:
-         # Ensure any error is properly propagated
-         raise e
+         # Check if there's a thread error we should chain with
+         if not error_queue.empty():
+             thread_exc, thread_tb = error_queue.get()
+             if isinstance(thread_exc, Exception):
+                 raise e from thread_exc
+         # If no thread error, raise the original exception
+         raise
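The worker/consumer error path above is the standard pattern of shipping an (exception, traceback) pair through a queue and re-raising on the consumer side; a self-contained illustration:

    import sys
    from queue import Queue

    errors = Queue()
    try:
        1 / 0  # stand-in for work inside the generation thread
    except Exception as e:
        errors.put((e, sys.exc_info()[2]))  # keep the original traceback

    exc, tb = errors.get()
    # The consumer re-raises with the worker's traceback intact:
    # raise exc.with_traceback(tb)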
@@ -24,16 +24,24 @@ def download(url: str, directory: Union[str, Path, StorageDir]) -> str:
      dir_path = Path(directory)
      dir_path.mkdir(exist_ok=True)
 
-     # Create hash directory from URL
-     url_hash = hashlib.sha256(url.encode()).hexdigest()[:12]
-     hash_dir = dir_path / url_hash
-     hash_dir.mkdir(exist_ok=True)
+     # Parse URL components
+     parsed_url = urllib.parse.urlparse(url)
 
-     # Keep original filename
-     filename = os.path.basename(urllib.parse.urlparse(url).path)
+     # Create hash from URL path and query parameters for uniqueness
+     url_components = parsed_url.netloc + parsed_url.path
+     if parsed_url.query:
+         url_components += '?' + parsed_url.query
+     url_hash = hashlib.sha256(url_components.encode()).hexdigest()[:12]
+
+     # Keep original filename or use a default
+     filename = os.path.basename(parsed_url.path)
      if not filename:
          filename = 'download'
-
+
+     # Create hash directory and store file
+     hash_dir = dir_path / url_hash
+     hash_dir.mkdir(exist_ok=True)
+
      output_path = hash_dir / filename
 
      # If file exists in directory and it's not a temp directory, return it
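The effect of the new cache keying: URLs that differ only in their query string no longer collide in one hash directory (and the scheme no longer affects the key). A standalone sketch of the same hashing logic:

    import hashlib
    import urllib.parse

    def cache_key(url: str) -> str:
        p = urllib.parse.urlparse(url)
        components = p.netloc + p.path + (("?" + p.query) if p.query else "")
        return hashlib.sha256(components.encode()).hexdigest()[:12]

    assert cache_key("https://example.com/model.bin?rev=1") != \
           cache_key("https://example.com/model.bin?rev=2")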