jaf-py 2.3.1__py3-none-any.whl → 2.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaf/providers/model.py CHANGED
@@ -5,17 +5,106 @@ This module provides model providers that integrate with various LLM services,
  starting with LiteLLM for multi-provider support.
  """

- from typing import Any, Dict, Optional, TypeVar
+ from typing import Any, Dict, Optional, TypeVar, AsyncIterator
+ import asyncio
  import httpx
+ import time
+ import os
+ import base64

  from openai import OpenAI
  from pydantic import BaseModel

- from ..core.types import Agent, ContentRole, Message, ModelProvider, RunConfig, RunState
+ from ..core.types import (
+     Agent, ContentRole, Message, ModelProvider, RunConfig, RunState,
+     CompletionStreamChunk, ToolCallDelta, ToolCallFunctionDelta,
+     MessageContentPart, get_text_content
+ )
  from ..core.proxy import ProxyConfig
+ from ..utils.document_processor import (
+     extract_document_content, is_document_supported,
+     get_document_description, DocumentProcessingError
+ )

  Ctx = TypeVar('Ctx')

+ # Vision model caching
+ VISION_MODEL_CACHE_TTL = 5 * 60 # 5 minutes
+ VISION_API_TIMEOUT = 3.0 # 3 seconds
+ _vision_model_cache: Dict[str, Dict[str, Any]] = {}
+ MAX_IMAGE_BYTES = int(os.environ.get("JAF_MAX_IMAGE_BYTES", 8 * 1024 * 1024))
+
+ async def _is_vision_model(model: str, base_url: str) -> bool:
+     """
+     Check if a model supports vision capabilities.
+
+     Args:
+         model: Model name to check
+         base_url: Base URL of the LiteLLM server
+
+     Returns:
+         True if model supports vision, False otherwise
+     """
+     cache_key = f"{base_url}:{model}"
+     cached = _vision_model_cache.get(cache_key)
+
+     if cached and time.time() - cached['timestamp'] < VISION_MODEL_CACHE_TTL:
+         return cached['supports']
+
+     try:
+         async with httpx.AsyncClient(timeout=VISION_API_TIMEOUT) as client:
+             response = await client.get(
+                 f"{base_url}/model_group/info",
+                 headers={'accept': 'application/json'}
+             )
+
+             if response.status_code == 200:
+                 data = response.json()
+                 model_info = None
+
+                 if 'data' in data and isinstance(data['data'], list):
+                     for m in data['data']:
+                         if (m.get('model_group') == model or
+                                 model in str(m.get('model_group', ''))):
+                             model_info = m
+                             break
+
+                 if model_info and 'supports_vision' in model_info:
+                     result = model_info['supports_vision']
+                     _vision_model_cache[cache_key] = {
+                         'supports': result,
+                         'timestamp': time.time()
+                     }
+                     return result
+             else:
+                 print(f"Warning: Vision API returned status {response.status_code} for model {model}")
+
+     except Exception as e:
+         print(f"Warning: Vision API error for model {model}: {e}")
+
+     # Fallback to known vision models
+     known_vision_models = [
+         'gpt-4-vision-preview',
+         'gpt-4o',
+         'gpt-4o-mini',
+         'claude-sonnet-4',
+         'claude-sonnet-4-20250514',
+         'gemini-2.5-flash',
+         'gemini-2.5-pro'
+     ]
+
+     is_known_vision_model = any(
+         vision_model.lower() in model.lower()
+         for vision_model in known_vision_models
+     )
+
+     _vision_model_cache[cache_key] = {
+         'supports': is_known_vision_model,
+         'timestamp': time.time()
+     }
+
+     return is_known_vision_model
+
  def make_litellm_provider(
      base_url: str,
      api_key: str = "anything",
@@ -75,6 +164,23 @@ def make_litellm_provider(
              model = (config.model_override or
                       (agent.model_config.name if agent.model_config else "gpt-4o"))

+             # Check if any message contains image content or image attachments
+             has_image_content = any(
+                 (isinstance(msg.content, list) and
+                  any(part.type == 'image_url' for part in msg.content)) or
+                 (msg.attachments and
+                  any(att.kind == 'image' for att in msg.attachments))
+                 for msg in state.messages
+             )
+
+             if has_image_content:
+                 supports_vision = await _is_vision_model(model, base_url)
+                 if not supports_vision:
+                     raise ValueError(
+                         f"Model {model} does not support vision capabilities. "
+                         f"Please use a vision-capable model like gpt-4o, claude-3-5-sonnet, or gemini-1.5-pro."
+                     )
+
              # Create system message
              system_message = {
                  "role": "system",
@@ -82,9 +188,12 @@ def make_litellm_provider(
              }

              # Convert messages to OpenAI format
-             messages = [system_message] + [
-                 _convert_message(msg) for msg in state.messages
-             ]
+             converted_messages = []
+             for msg in state.messages:
+                 converted_msg = await _convert_message(msg)
+                 converted_messages.append(converted_msg)
+
+             messages = [system_message] + converted_messages

              # Convert tools to OpenAI format
              tools = None
@@ -169,19 +278,189 @@ def make_litellm_provider(
                  'prompt': messages
              }

+         async def get_completion_stream(
+             self,
+             state: RunState[Ctx],
+             agent: Agent[Ctx, Any],
+             config: RunConfig[Ctx]
+         ) -> AsyncIterator[CompletionStreamChunk]:
+             """
+             Stream completion chunks from the model provider, yielding text deltas and tool-call deltas.
+             Uses OpenAI-compatible streaming via LiteLLM endpoint.
+             """
+             # Determine model to use
+             model = (config.model_override or
+                      (agent.model_config.name if agent.model_config else "gpt-4o"))
+
+             # Create system message
+             system_message = {
+                 "role": "system",
+                 "content": agent.instructions(state)
+             }
+
+             # Convert messages to OpenAI format
+             converted_messages = []
+             for msg in state.messages:
+                 converted_msg = await _convert_message(msg)
+                 converted_messages.append(converted_msg)
+
+             messages = [system_message] + converted_messages
+
+             # Convert tools to OpenAI format
+             tools = None
+             if agent.tools:
+                 tools = [
+                     {
+                         "type": "function",
+                         "function": {
+                             "name": tool.schema.name,
+                             "description": tool.schema.description,
+                             "parameters": _pydantic_to_json_schema(tool.schema.parameters),
+                         }
+                     }
+                     for tool in agent.tools
+                 ]
+
+             # Determine tool choice behavior
+             last_message = state.messages[-1] if state.messages else None
+             is_after_tool_call = last_message and (last_message.role == ContentRole.TOOL or last_message.role == 'tool')
+
+             # Prepare request parameters
+             request_params: Dict[str, Any] = {
+                 "model": model,
+                 "messages": messages,
+             }
+
+             # Add optional parameters
+             if agent.model_config:
+                 if agent.model_config.temperature is not None:
+                     request_params["temperature"] = agent.model_config.temperature
+                 if agent.model_config.max_tokens is not None:
+                     request_params["max_tokens"] = agent.model_config.max_tokens
+
+             if tools:
+                 request_params["tools"] = tools
+                 # Set tool_choice to auto when tools are available
+                 request_params["tool_choice"] = "auto"
+
+             if agent.output_codec:
+                 request_params["response_format"] = {"type": "json_object"}
+
+             # Enable streaming
+             request_params["stream"] = True
+
+             loop = asyncio.get_running_loop()
+             queue: asyncio.Queue = asyncio.Queue(maxsize=256)
+             SENTINEL = object()
+
+             def _put(item: CompletionStreamChunk):
+                 try:
+                     asyncio.run_coroutine_threadsafe(queue.put(item), loop)
+                 except RuntimeError:
+                     # Event loop closed; drop silently
+                     pass
+
+             def _producer():
+                 try:
+                     stream = self.client.chat.completions.create(**request_params)
+                     for chunk in stream:
+                         try:
+                             # Best-effort extraction of raw for debugging
+                             try:
+                                 raw_obj = chunk.model_dump() # pydantic BaseModel
+                             except Exception:
+                                 raw_obj = None
+
+                             choice = None
+                             if getattr(chunk, "choices", None):
+                                 choice = chunk.choices[0]
+
+                             if choice is None:
+                                 continue
+
+                             delta = getattr(choice, "delta", None)
+                             finish_reason = getattr(choice, "finish_reason", None)
+
+                             # Text content delta
+                             if delta is not None:
+                                 content_delta = getattr(delta, "content", None)
+                                 if content_delta:
+                                     _put(CompletionStreamChunk(delta=content_delta, raw=raw_obj))
+
+                                 # Tool call deltas
+                                 tool_calls = getattr(delta, "tool_calls", None)
+                                 if isinstance(tool_calls, list):
+                                     for tc in tool_calls:
+                                         # Each tc is likely a pydantic model with .index/.id/.function
+                                         try:
+                                             idx = getattr(tc, "index", 0) or 0
+                                             tc_id = getattr(tc, "id", None)
+                                             fn = getattr(tc, "function", None)
+                                             fn_name = getattr(fn, "name", None) if fn is not None else None
+                                             # OpenAI streams "arguments" as incremental deltas
+                                             args_delta = getattr(fn, "arguments", None) if fn is not None else None
+
+                                             _put(CompletionStreamChunk(
+                                                 tool_call_delta=ToolCallDelta(
+                                                     index=idx,
+                                                     id=tc_id,
+                                                     type='function',
+                                                     function=ToolCallFunctionDelta(
+                                                         name=fn_name,
+                                                         arguments_delta=args_delta
+                                                     )
+                                                 ),
+                                                 raw=raw_obj
+                                             ))
+                                         except Exception:
+                                             # Skip malformed tool-call deltas
+                                             continue
+
+                             # Completion ended
+                             if finish_reason:
+                                 _put(CompletionStreamChunk(is_done=True, finish_reason=finish_reason, raw=raw_obj))
+                         except Exception:
+                             # Skip individual chunk errors, keep streaming
+                             continue
+                 except Exception:
+                     # On top-level stream error, signal done
+                     pass
+                 finally:
+                     try:
+                         asyncio.run_coroutine_threadsafe(queue.put(SENTINEL), loop)
+                     except RuntimeError:
+                         pass
+
+             # Start producer in background
+             loop.run_in_executor(None, _producer)
+
+             # Consume queue and yield
+             while True:
+                 item = await queue.get()
+                 if item is SENTINEL:
+                     break
+                 # Guarantee type for consumers
+                 if isinstance(item, CompletionStreamChunk):
+                     yield item
+
      return LiteLLMProvider()

- def _convert_message(msg: Message) -> Dict[str, Any]:
-     """Convert JAF Message to OpenAI message format."""
+ async def _convert_message(msg: Message) -> Dict[str, Any]:
+     """Convert JAF Message to OpenAI message format with attachment support."""
      if msg.role == 'user':
-         return {
-             "role": "user",
-             "content": msg.content
-         }
+         if isinstance(msg.content, list):
+             # Multi-part content
+             return {
+                 "role": "user",
+                 "content": [_convert_content_part(part) for part in msg.content]
+             }
+         else:
+             # Build message with attachments if available
+             return await _build_chat_message_with_attachments('user', msg)
      elif msg.role == 'assistant':
          result = {
              "role": "assistant",
-             "content": msg.content,
+             "content": get_text_content(msg.content),
          }
          if msg.tool_calls:
              result["tool_calls"] = [
@@ -199,12 +478,156 @@ def _convert_message(msg: Message) -> Dict[str, Any]:
      elif msg.role == ContentRole.TOOL:
          return {
              "role": "tool",
-             "content": msg.content,
+             "content": get_text_content(msg.content),
              "tool_call_id": msg.tool_call_id
          }
      else:
          raise ValueError(f"Unknown message role: {msg.role}")

+
+ def _convert_content_part(part: MessageContentPart) -> Dict[str, Any]:
+     """Convert MessageContentPart to OpenAI format."""
+     if part.type == 'text':
+         return {
+             "type": "text",
+             "text": part.text
+         }
+     elif part.type == 'image_url':
+         return {
+             "type": "image_url",
+             "image_url": part.image_url
+         }
+     elif part.type == 'file':
+         return {
+             "type": "file",
+             "file": part.file
+         }
+     else:
+         raise ValueError(f"Unknown content part type: {part.type}")
+
+
+ async def _build_chat_message_with_attachments(
+     role: str,
+     msg: Message
+ ) -> Dict[str, Any]:
+     """
+     Build multi-part content for Chat Completions if attachments exist.
+     Supports images via image_url and documents via content extraction.
+     """
+     has_attachments = msg.attachments and len(msg.attachments) > 0
+     if not has_attachments:
+         if role == 'assistant':
+             base_msg = {"role": "assistant", "content": get_text_content(msg.content)}
+             if msg.tool_calls:
+                 base_msg["tool_calls"] = [
+                     {
+                         "id": tc.id,
+                         "type": tc.type,
+                         "function": {
+                             "name": tc.function.name,
+                             "arguments": tc.function.arguments
+                         }
+                     }
+                     for tc in msg.tool_calls
+                 ]
+             return base_msg
+         return {"role": "user", "content": get_text_content(msg.content)}
+
+     parts = []
+     text_content = get_text_content(msg.content)
+     if text_content and text_content.strip():
+         parts.append({"type": "text", "text": text_content})
+
+     for att in msg.attachments:
+         if att.kind == 'image':
+             # Prefer explicit URL; otherwise construct a data URL from base64
+             url = att.url
+             if not url and att.data and att.mime_type:
+                 # Validate base64 data size before creating data URL
+                 try:
+                     # Estimate decoded size (base64 is ~4/3 of decoded size)
+                     estimated_size = len(att.data) * 3 // 4
+
+                     if estimated_size > MAX_IMAGE_BYTES:
+                         print(f"Warning: Skipping oversized image ({estimated_size} bytes > {MAX_IMAGE_BYTES}). "
+                               f"Set JAF_MAX_IMAGE_BYTES env var to adjust limit.")
+                         parts.append({
+                             "type": "text",
+                             "text": f"[IMAGE SKIPPED: Size exceeds limit of {MAX_IMAGE_BYTES//1024//1024}MB. "
+                                     f"Image name: {att.name or 'unnamed'}]"
+                         })
+                         continue
+
+                     # Create data URL for valid-sized images
+                     url = f"data:{att.mime_type};base64,{att.data}"
+                 except Exception as e:
+                     print(f"Error processing image data: {e}")
+                     parts.append({
+                         "type": "text",
+                         "text": f"[IMAGE ERROR: Failed to process image data. Image name: {att.name or 'unnamed'}]"
+                     })
+                     continue
+
+             if url:
+                 parts.append({
+                     "type": "image_url",
+                     "image_url": {"url": url}
+                 })
+
+         elif att.kind in ['document', 'file']:
+             # Check if attachment has use_litellm_format flag or is a large document
+             use_litellm_format = att.use_litellm_format is True
+
+             if use_litellm_format and (att.url or att.data):
+                 # For now, fall back to content extraction since most providers don't support native file format
+                 # TODO: Add provider-specific file format support
+                 print(f"Info: LiteLLM format requested for {att.name}, falling back to content extraction")
+                 use_litellm_format = False
+
+             if not use_litellm_format:
+                 # Extract document content if supported and we have data or URL
+                 if is_document_supported(att.mime_type) and (att.data or att.url):
+                     try:
+                         processed = await extract_document_content(att)
+                         file_name = att.name or 'document'
+                         description = get_document_description(att.mime_type)
+
+                         parts.append({
+                             "type": "text",
+                             "text": f"DOCUMENT: {file_name} ({description}):\n\n{processed.content}"
+                         })
+                     except DocumentProcessingError as e:
+                         # Fallback to filename if extraction fails
+                         label = att.name or att.format or att.mime_type or 'attachment'
+                         parts.append({
+                             "type": "text",
+                             "text": f"ERROR: Failed to process {att.kind}: {label} ({e})"
+                         })
+                 else:
+                     # Unsupported document type - show placeholder
+                     label = att.name or att.format or att.mime_type or 'attachment'
+                     url_info = f" ({att.url})" if att.url else ""
+                     parts.append({
+                         "type": "text",
+                         "text": f"ATTACHMENT: {att.kind}: {label}{url_info}"
+                     })
+
+     base_msg = {"role": role, "content": parts}
+     if role == 'assistant' and msg.tool_calls:
+         base_msg["tool_calls"] = [
+             {
+                 "id": tc.id,
+                 "type": tc.type,
+                 "function": {
+                     "name": tc.function.name,
+                     "arguments": tc.function.arguments
+                 }
+             }
+             for tc in msg.tool_calls
+         ]
+
+     return base_msg
+
  def _pydantic_to_json_schema(model_class: type[BaseModel]) -> Dict[str, Any]:
      """
      Convert a Pydantic model to JSON schema for OpenAI tools.
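
Taken together, the changes to this file add two user-facing capabilities: a streaming completion API (get_completion_stream) and attachment-aware message conversion. A minimal consumption sketch for the streaming API follows; the CompletionStreamChunk fields and the make_litellm_provider signature are taken from the diff above, while the state, agent, and config objects are assumed to be built with the usual JAF run types and the base_url is illustrative:

import asyncio
from jaf.providers.model import make_litellm_provider

async def stream_demo(state, agent, config):
    # Point at a LiteLLM proxy; api_key defaults to "anything".
    provider = make_litellm_provider("http://localhost:4000")
    async for chunk in provider.get_completion_stream(state, agent, config):
        if chunk.delta:
            # Incremental assistant text.
            print(chunk.delta, end="", flush=True)
        if chunk.tool_call_delta and chunk.tool_call_delta.function:
            # Tool-call arguments arrive as incremental fragments.
            print(chunk.tool_call_delta.function.arguments_delta or "", end="")
        if chunk.is_done:
            print(f"\n[finished: {chunk.finish_reason}]")

For attachments, a user message whose attachments list contains an image (kind == 'image') is expanded by _build_chat_message_with_attachments into multi-part Chat Completions content, roughly of this shape (the text and data URL values are illustrative):

{
    "role": "user",
    "content": [
        {"type": "text", "text": "Please describe this image."},
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,<...>"}}
    ]
}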
jaf/server/__init__.py CHANGED
@@ -10,6 +10,8 @@ __all__ = [
      "ChatResponse",
      "HealthResponse",
      "HttpMessage",
+     "HttpAttachment",
+     "HttpMessageContentPart",
      "ServerConfig",
      "create_jaf_server",
      "run_server",