donkit-llm 0.1.6__tar.gz → 0.1.8__tar.gz

This diff shows the changes between two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
Package metadata (PKG-INFO):

@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: donkit-llm
-Version: 0.1.6
+Version: 0.1.8
 Summary: Unified LLM model implementations for Donkit (OpenAI, Azure OpenAI, Claude, Vertex AI, Ollama)
 License: MIT
 Author: Donkit AI
pyproject.toml:

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "donkit-llm"
-version = "0.1.6"
+version = "0.1.8"
 description = "Unified LLM model implementations for Donkit (OpenAI, Azure OpenAI, Claude, Vertex AI, Ollama)"
 authors = ["Donkit AI <opensource@donkit.ai>"]
 license = "MIT"
@@ -19,6 +19,7 @@ donkit-ragops-api-gateway-client = "^0.1.5"
 ruff = "^0.13.3"
 pytest = "^8.4.2"
 pytest-asyncio = "^1.3.0"
+donkit-llm-gate-client = { path = "../llm-gate-client", develop = true }
 
 [build-system]
 requires = ["poetry-core>=1.9.0"]
__init__.py:

@@ -26,6 +26,13 @@ from .factory import ModelFactory
 from .gemini_model import GeminiModel, GeminiEmbeddingModel
 from .donkit_model import DonkitModel
 
+import importlib.util
+
+if importlib.util.find_spec("donkit.llm_gate.client") is not None:
+    from .llm_gate_model import LLMGateModel
+else:
+    LLMGateModel = None
+
 __all__ = [
     "ModelFactory",
     # Abstract base
@@ -58,3 +65,6 @@ __all__ = [
     "GeminiEmbeddingModel",
     "DonkitModel",
 ]
+
+if LLMGateModel is not None:
+    __all__.append("LLMGateModel")
factory.py:

@@ -4,6 +4,13 @@ from .claude_model import ClaudeModel
 from .claude_model import ClaudeVertexModel
 from .donkit_model import DonkitModel
 from .gemini_model import GeminiModel
+
+import importlib.util
+
+if importlib.util.find_spec("donkit.llm_gate.client") is not None:
+    from .llm_gate_model import LLMGateModel
+else:
+    LLMGateModel = None
 from .model_abstract import LLMModelAbstract
 from .openai_model import AzureOpenAIEmbeddingModel
 from .openai_model import AzureOpenAIModel
@@ -174,6 +181,30 @@ class ModelFactory:
             model_name=model_name,
         )
 
+    @staticmethod
+    def create_llm_gate_model(
+        model_name: str | None,
+        base_url: str,
+        provider: str = "default",
+        embedding_provider: str | None = None,
+        embedding_model_name: str | None = None,
+        user_id: str | None = None,
+        project_id: str | None = None,
+    ) -> LLMGateModel:
+        if LLMGateModel is None:
+            raise ImportError(
+                "Provider 'llm_gate' requires optional dependency 'donkit-llm-gate-client'"
+            )
+        return LLMGateModel(
+            base_url=base_url,
+            provider=provider,
+            model_name=model_name,
+            embedding_provider=embedding_provider,
+            embedding_model_name=embedding_model_name,
+            user_id=user_id,
+            project_id=project_id,
+        )
+
     @staticmethod
     def create_model(
         provider: Literal[
@@ -184,6 +215,7 @@ class ModelFactory:
             "vertex",
             "ollama",
             "donkit",
+            "llm_gate",
         ],
         model_name: str | None,
         credentials: dict,
@@ -198,6 +230,7 @@ class ModelFactory:
             "vertex": "gemini-2.5-flash",
             "ollama": "mistral",
             "donkit": None,
+            "llm_gate": None,
         }
         model_name = default_models.get(provider, "default")
         if provider == "openai":
@@ -258,5 +291,15 @@ class ModelFactory:
                 api_key=credentials["api_key"],
                 base_url=credentials["base_url"],
             )
+        elif provider == "llm_gate":
+            return ModelFactory.create_llm_gate_model(
+                model_name=model_name,
+                base_url=credentials["base_url"],
+                provider=credentials.get("provider", "default"),
+                embedding_provider=credentials.get("embedding_provider"),
+                embedding_model_name=credentials.get("embedding_model_name"),
+                user_id=credentials.get("user_id"),
+                project_id=credentials.get("project_id"),
+            )
         else:
             raise ValueError(f"Unknown provider: {provider}")
llm_gate_model.py (new file):

@@ -0,0 +1,210 @@
+from typing import Any, AsyncIterator
+
+from .model_abstract import (
+    EmbeddingRequest,
+    EmbeddingResponse,
+    FunctionCall,
+    GenerateRequest,
+    GenerateResponse,
+    LLMModelAbstract,
+    Message,
+    ModelCapability,
+    StreamChunk,
+    Tool,
+    ToolCall,
+)
+
+
+class LLMGateModel(LLMModelAbstract):
+    name = "llm_gate"
+
+    @staticmethod
+    def _get_client() -> type:
+        try:
+            from donkit.llm_gate.client import LLMGate
+
+            return LLMGate
+        except Exception as e:
+            raise ImportError(
+                "LLMGateModel requires 'donkit-llm-gate-client' to be installed"
+            ) from e
+
+    def __init__(
+        self,
+        base_url: str = "http://localhost:8002",
+        provider: str = "default",
+        model_name: str | None = None,
+        embedding_provider: str | None = None,
+        embedding_model_name: str | None = None,
+        user_id: str | None = None,
+        project_id: str | None = None,
+    ):
+        self.base_url = base_url
+        self.provider = provider
+        self._model_name = model_name
+        self.embedding_provider = embedding_provider
+        self.embedding_model_name = embedding_model_name
+        self.user_id = user_id
+        self.project_id = project_id
+        self._capabilities = self._determine_capabilities()
+
+    @property
+    def model_name(self) -> str:
+        return self._model_name or "default"
+
+    @model_name.setter
+    def model_name(self, value: str):
+        self._model_name = value
+        self._capabilities = self._determine_capabilities()
+
+    @property
+    def capabilities(self) -> ModelCapability:
+        return self._capabilities
+
+    def _determine_capabilities(self) -> ModelCapability:
+        caps = (
+            ModelCapability.TEXT_GENERATION
+            | ModelCapability.STREAMING
+            | ModelCapability.STRUCTURED_OUTPUT
+            | ModelCapability.TOOL_CALLING
+            | ModelCapability.MULTIMODAL_INPUT
+            | ModelCapability.EMBEDDINGS
+        )
+        return caps
+
+    def _convert_message(self, msg: Message) -> dict:
+        result: dict[str, Any] = {"role": msg.role}
+        if isinstance(msg.content, str):
+            result["content"] = msg.content
+        else:
+            content_parts = []
+            for part in msg.content if msg.content else []:
+                content_parts.append(part.model_dump(exclude_none=True))
+            result["content"] = content_parts
+        if msg.tool_calls:
+            result["tool_calls"] = [tc.model_dump() for tc in msg.tool_calls]
+        if msg.tool_call_id:
+            result["tool_call_id"] = msg.tool_call_id
+        if msg.name:
+            result["name"] = msg.name
+        return result
+
+    def _convert_tools(self, tools: list[Tool]) -> list[dict]:
+        return [tool.model_dump(exclude_none=True) for tool in tools]
+
+    def _prepare_generate_kwargs(self, request: GenerateRequest) -> dict:
+        messages = [self._convert_message(msg) for msg in request.messages]
+        tools_payload = self._convert_tools(request.tools) if request.tools else None
+
+        kwargs: dict[str, Any] = {
+            "provider": self.provider,
+            "model_name": self.model_name,
+            "messages": messages,
+            "user_id": self.user_id,
+            "project_id": self.project_id,
+        }
+
+        if request.temperature is not None:
+            kwargs["temperature"] = request.temperature
+        if request.max_tokens is not None:
+            kwargs["max_tokens"] = request.max_tokens
+        if request.top_p is not None:
+            kwargs["top_p"] = request.top_p
+        if request.stop:
+            kwargs["stop"] = request.stop
+        if tools_payload:
+            kwargs["tools"] = tools_payload
+        if request.tool_choice:
+            if isinstance(request.tool_choice, (str, dict)):
+                kwargs["tool_choice"] = request.tool_choice
+            else:
+                kwargs["tool_choice"] = "auto"
+        if request.response_format:
+            kwargs["response_format"] = request.response_format
+
+        return kwargs
+
+    async def generate(self, request: GenerateRequest) -> GenerateResponse:
+        await self.validate_request(request)
+
+        kwargs = self._prepare_generate_kwargs(request)
+
+        llm_gate = self._get_client()
+
+        async with llm_gate(base_url=self.base_url) as client:
+            response = await client.generate(**kwargs)
+
+        tool_calls = None
+        if response.tool_calls:
+            tool_calls = [
+                ToolCall(
+                    id=tc.id,
+                    type=tc.type,
+                    function=FunctionCall(
+                        name=tc.function.name,
+                        arguments=tc.function.arguments,
+                    ),
+                )
+                for tc in response.tool_calls
+            ]
+
+        return GenerateResponse(
+            content=response.content,
+            tool_calls=tool_calls,
+            finish_reason=response.finish_reason,
+            usage=response.usage,
+            metadata=response.metadata,
+        )
+
+    async def generate_stream(
+        self, request: GenerateRequest
+    ) -> AsyncIterator[StreamChunk]:
+        await self.validate_request(request)
+
+        kwargs = self._prepare_generate_kwargs(request)
+
+        llm_gate = self._get_client()
+
+        async with llm_gate(base_url=self.base_url) as client:
+            async for chunk in client.generate_stream(**kwargs):
+                tool_calls = None
+                if chunk.tool_calls:
+                    tool_calls = [
+                        ToolCall(
+                            id=tc.id,
+                            type=tc.type,
+                            function=FunctionCall(
+                                name=tc.function.name,
+                                arguments=tc.function.arguments,
+                            ),
+                        )
+                        for tc in chunk.tool_calls
+                    ]
+
+                yield StreamChunk(
+                    content=chunk.content,
+                    tool_calls=tool_calls,
+                    finish_reason=chunk.finish_reason,
+                )
+
+    async def embed(self, request: EmbeddingRequest) -> EmbeddingResponse:
+        provider = self.embedding_provider or self.provider
+        model_name = self.embedding_model_name
+
+        llm_gate = self._get_client()
+
+        async with llm_gate(base_url=self.base_url) as client:
+            response = await client.embeddings(
+                provider=provider,
+                input=request.input,
+                model_name=model_name,
+                dimensions=request.dimensions,
+                user_id=self.user_id,
+                project_id=self.project_id,
+            )
+
+        return EmbeddingResponse(
+            embeddings=response.embeddings,
+            usage=response.usage,
+            metadata=response.metadata,
+        )
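
A minimal sketch of exercising the new LLMGateModel directly, mirroring the request shapes used elsewhere in this package; the module paths, the running gateway, and the provider name are assumptions:

    # Hypothetical sketch; import paths are assumed and a gateway must be listening on base_url.
    import asyncio

    from donkit_llm.llm_gate_model import LLMGateModel               # assumed import path
    from donkit_llm.model_abstract import GenerateRequest, Message   # assumed import path


    async def main() -> None:
        model = LLMGateModel(base_url="http://localhost:8002", provider="openai")
        request = GenerateRequest(
            messages=[Message(role="user", content="Say hello through the gate.")],
            temperature=0.2,
            max_tokens=64,
        )
        response = await model.generate(request)
        print(response.content)


    asyncio.run(main())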
openai_model.py:

@@ -45,27 +45,29 @@ class OpenAIModel(LLMModelAbstract):
 
     def _clean_schema_for_openai(self, schema: dict, is_gpt5: bool = False) -> dict:
         """Clean JSON Schema for OpenAI strict mode.
-
+
         Args:
             schema: JSON Schema to clean
             is_gpt5: Currently unused - both GPT-4o and GPT-5 support title/description
-
+
         OpenAI Structured Outputs supports:
         - title and description (useful metadata for the model)
         - Automatically adds additionalProperties: false for all objects
         - Recursively processes nested objects and arrays
-
+
         Note: The main requirement is additionalProperties: false for all objects,
         which is automatically added by this method.
         """
         cleaned = {}
-
+
         # Copy all fields, recursively processing nested structures
         for key, value in schema.items():
             if key == "properties" and isinstance(value, dict):
                 # Recursively clean properties
                 cleaned["properties"] = {
-                    k: self._clean_schema_for_openai(v, is_gpt5) if isinstance(v, dict) else v
+                    k: self._clean_schema_for_openai(v, is_gpt5)
+                    if isinstance(v, dict)
+                    else v
                     for k, v in value.items()
                 }
             elif key == "items" and isinstance(value, dict):
@@ -73,11 +75,11 @@ class OpenAIModel(LLMModelAbstract):
                 cleaned["items"] = self._clean_schema_for_openai(value, is_gpt5)
             else:
                 cleaned[key] = value
-
+
         # Ensure additionalProperties is false for objects (required by OpenAI)
         if cleaned.get("type") == "object" and "additionalProperties" not in cleaned:
             cleaned["additionalProperties"] = False
-
+
         return cleaned
 
     def _get_base_model_name(self) -> str:
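
To make the docstring above concrete: the cleaner copies the schema, recurses into "properties" and "items", and forces additionalProperties: false onto every object. A simplified standalone re-implementation (for illustration only, not the package's own code) behaves like this:

    # Simplified sketch of the cleaning rule described above.
    def clean_schema(schema: dict) -> dict:
        cleaned: dict = {}
        for key, value in schema.items():
            if key == "properties" and isinstance(value, dict):
                cleaned["properties"] = {
                    k: clean_schema(v) if isinstance(v, dict) else v
                    for k, v in value.items()
                }
            elif key == "items" and isinstance(value, dict):
                cleaned["items"] = clean_schema(value)
            else:
                cleaned[key] = value
        if cleaned.get("type") == "object" and "additionalProperties" not in cleaned:
            cleaned["additionalProperties"] = False
        return cleaned

    # {"type": "object", "properties": {"name": {"type": "string"}}}
    # becomes
    # {"type": "object", "properties": {"name": {"type": "string"}}, "additionalProperties": False}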
@@ -93,9 +95,9 @@ class OpenAIModel(LLMModelAbstract):
         Reasoning models don't support temperature, top_p, presence_penalty, frequency_penalty.
         They only support max_completion_tokens (not max_tokens).
         """
-        model_lower = self._get_base_model_name().lower()
+        model_lower = self.model_name.lower()
         # Check for reasoning model prefixes
-        reasoning_prefixes = ("gpt-5", "o1", "o3", "o4")
+        reasoning_prefixes = ("gpt-5", "o1", "o3", "o4", "deepseek")
         return any(model_lower.startswith(prefix) for prefix in reasoning_prefixes)
 
     def _supports_max_completion_tokens(self) -> bool:
@@ -333,6 +335,18 @@ class OpenAIModel(LLMModelAbstract):
         messages = [self._convert_message(msg) for msg in request.messages]
         kwargs = self._build_request_kwargs(request, messages, stream=False)
 
+        # Log request to LLM
+        # request_log = {
+        #     "model": self._model_name,
+        #     "messages": messages,
+        #     "temperature": kwargs.get("temperature"),
+        #     "max_tokens": kwargs.get("max_tokens")
+        #     or kwargs.get("max_completion_tokens"),
+        #     "top_p": kwargs.get("top_p"),
+        #     "tools": kwargs.get("tools"),
+        #     "tool_choice": kwargs.get("tool_choice"),
+        # }
+
         try:
             response = await self.client.chat.completions.create(**kwargs)
 
@@ -359,6 +373,33 @@ class OpenAIModel(LLMModelAbstract):
                     for tc in message.tool_calls
                 ]
 
+            # Log response from LLM
+            # response_log = {
+            #     "content": content,
+            #     "tool_calls": [
+            #         {
+            #             "id": tc.id,
+            #             "type": tc.type,
+            #             "function": {
+            #                 "name": tc.function.name,
+            #                 "arguments": tc.function.arguments,
+            #             },
+            #         }
+            #         for tc in (message.tool_calls or [])
+            #     ],
+            #     "finish_reason": choice.finish_reason,
+            #     "usage": {
+            #         "prompt_tokens": response.usage.prompt_tokens,
+            #         "completion_tokens": response.usage.completion_tokens,
+            #         "total_tokens": response.usage.total_tokens,
+            #     }
+            #     if response.usage
+            #     else None,
+            # }
+            # logger.info(
+            #     f"LLM Response (generate): {json.dumps(response_log, ensure_ascii=False, indent=2)}"
+            # )
+
             return GenerateResponse(
                 content=content,
                 tool_calls=tool_calls,
@@ -384,6 +425,22 @@ class OpenAIModel(LLMModelAbstract):
         messages = [self._convert_message(msg) for msg in request.messages]
         kwargs = self._build_request_kwargs(request, messages, stream=True)
 
+        # Log request to LLM
+        # request_log = {
+        #     "model": self._model_name,
+        #     "messages": messages,
+        #     "temperature": kwargs.get("temperature"),
+        #     "max_tokens": kwargs.get("max_tokens")
+        #     or kwargs.get("max_completion_tokens"),
+        #     "top_p": kwargs.get("top_p"),
+        #     "tools": kwargs.get("tools"),
+        #     "tool_choice": kwargs.get("tool_choice"),
+        #     "stream": True,
+        # }
+        # logger.info(
+        #     f"LLM Request (generate_stream): {json.dumps(request_log, ensure_ascii=False, indent=2)}"
+        # )
+
         try:
             stream = await self.client.chat.completions.create(**kwargs)
 
@@ -399,6 +456,11 @@ class OpenAIModel(LLMModelAbstract):
 
                 # Yield text content if present
                 if delta.content:
+                    # Log chunk content: uncomment to debug
+                    # chunk_log = {"content": delta.content}
+                    # logger.info(
+                    #     f"LLM Stream Chunk: {json.dumps(chunk_log, ensure_ascii=False)}"
+                    # )
                     yield StreamChunk(content=delta.content, tool_calls=None)
 
                 # Accumulate tool calls
@@ -443,6 +505,21 @@ class OpenAIModel(LLMModelAbstract):
                     )
                     for tc_data in accumulated_tool_calls.values()
                 ]
+                # Log final tool calls
+                # tool_calls_log = [
+                #     {
+                #         "id": tc.id,
+                #         "type": tc.type,
+                #         "function": {
+                #             "name": tc.function.name,
+                #             "arguments": tc.function.arguments,
+                #         },
+                #     }
+                #     for tc in tool_calls
+                # ]
+                # logger.info(
+                #     f"LLM Stream Final Tool Calls: {json.dumps(tool_calls_log, ensure_ascii=False, indent=2)}"
+                # )
                 yield StreamChunk(content=None, tool_calls=tool_calls)
 
         except Exception as e:
Removed module (standalone Ollama provider integration test script):

@@ -1,442 +0,0 @@
-#!/usr/bin/env python3
-"""Test script for Ollama LLM provider integration."""
-
-import asyncio
-import json
-
-import pytest
-
-from llm_gate.models import ModelCapability
-from .factory import ModelFactory
-from .model_abstract import (
-    ContentPart,
-    ContentType,
-    EmbeddingRequest,
-    GenerateRequest,
-    Message,
-)
-from .openai_model import OpenAIModel
-
-
-def test_ollama_model_creation() -> None:
-    """Test that Ollama model initializes correctly."""
-    print("=" * 60)
-    print("TEST 1: Ollama Model Creation via Factory")
-    print("=" * 60)
-
-    credentials = {
-        "api_key": "ollama",
-        "base_url": "http://localhost:11434/v1",
-    }
-
-    try:
-        model = ModelFactory.create_model(
-            provider="ollama",
-            model_name="gpt-oss:20b-cloud",
-            credentials=credentials,
-        )
-        print(f"✅ Model created: {model.__class__.__name__}")
-        print(f" Model name: {model.model_name}")
-        print(f" Is OpenAI model: {isinstance(model, OpenAIModel)}")
-    except Exception as e:
-        print(f"❌ Failed to create model: {e}")
-        raise
-
-
-def test_openai_model_direct_creation() -> None:
-    """Test direct OpenAI model creation for Ollama."""
-    print("\n" + "=" * 60)
-    print("TEST 2: Direct OpenAI Model Creation for Ollama")
-    print("=" * 60)
-
-    try:
-        model = ModelFactory.create_openai_model(
-            model_name="gpt-oss:20b-cloud",
-            api_key="ollama",
-            base_url="http://localhost:11434/v1",
-        )
-        print(f"✅ Model created: {model.__class__.__name__}")
-        print(f" Model name: {model.model_name}")
-        print(f" Has client: {hasattr(model, 'client')}")
-        print(f" Client type: {type(model.client).__name__}")
-    except Exception as e:
-        print(f"❌ Failed to create model: {e}")
-        raise
-
-
-def test_embedding_model_creation() -> None:
-    """Test embedding model creation for Ollama."""
-    print("\n" + "=" * 60)
-    print("TEST 3: Embedding Model Creation")
-    print("=" * 60)
-
-    try:
-        # OpenAI embedding model with Ollama base URL
-        model = ModelFactory.create_embedding_model(
-            provider="openai",
-            model_name="nomic-embed-text",
-            api_key="ollama",
-            base_url="http://localhost:11434/v1",
-        )
-        print(f"✅ Embedding model created: {model.__class__.__name__}")
-        print(f" Model name: {model.model_name}")
-    except Exception as e:
-        print(f"❌ Failed to create embedding model: {e}")
-        raise
-
-
-def test_model_configuration() -> None:
-    """Test model configuration."""
-    print("\n" + "=" * 60)
-    print("TEST 4: Model Configuration")
-    print("=" * 60)
-
-    try:
-        model = ModelFactory.create_openai_model(
-            model_name="gpt-oss:20b-cloud",
-            api_key="ollama",
-            base_url="http://localhost:11434/v1",
-        )
-
-        config = {
-            "model_name": model.model_name,
-            "model_type": model.__class__.__name__,
-            "has_client": hasattr(model, "client"),
-            "capabilities": str(model.capabilities),
-        }
-
-        print("✅ Model configuration:")
-        print(json.dumps(config, indent=2, default=str))
-    except Exception as e:
-        print(f"❌ Failed to get model configuration: {e}")
-        raise
-
-
-def test_factory_provider_support() -> None:
-    """Test that factory supports ollama provider."""
-    print("\n" + "=" * 60)
-    print("TEST 5: Factory Provider Support")
-    print("=" * 60)
-
-    credentials = {
-        "api_key": "ollama",
-        "base_url": "http://localhost:11434/v1",
-    }
-
-    try:
-        # Test all supported providers
-        providers = ["openai", "ollama"]
-        for provider in providers:
-            model = ModelFactory.create_model(
-                provider=provider,
-                model_name="test-model",
-                credentials=credentials,
-            )
-            print(f"✅ Provider '{provider}' supported: {model.__class__.__name__}")
-    except Exception as e:
-        print(f"❌ Provider support test failed: {e}")
-        raise
-
-
-@pytest.mark.asyncio
-async def test_text_generation() -> None:
-    """Test text generation with Ollama."""
-    print("\n" + "=" * 60)
-    print("TEST 6: Text Generation Request")
-    print("=" * 60)
-
-    credentials = {
-        "api_key": "ollama",
-        "base_url": "http://localhost:11434/v1",
-    }
-
-    try:
-        model = ModelFactory.create_model(
-            provider="ollama",
-            model_name="gpt-oss:20b-cloud",
-            credentials=credentials,
-        )
-
-        request = GenerateRequest(
-            messages=[
-                Message(
-                    role="user", content="Say 'Hello from Ollama!' in one sentence."
-                )
-            ],
-            temperature=0.7,
-            max_tokens=100,
-        )
-
-        print("📤 Sending request to Ollama...")
-        print(" Model: gpt-oss:20b-cloud")
-        print(f" Message: {request.messages[0].content}")
-
-        response = await model.generate(request)
-
-        print("✅ Generation successful")
-        print(f" Response: {response.content}")
-
-    except Exception as e:
-        print(f"❌ Generation failed: {e}")
-        raise
-
-
-@pytest.mark.asyncio
-async def test_streaming_generation() -> None:
-    """Test streaming text generation with Ollama."""
-    print("\n" + "=" * 60)
-    print("TEST 7: Streaming Text Generation")
-    print("=" * 60)
-
-    credentials = {
-        "api_key": "ollama",
-        "base_url": "http://localhost:11434/v1",
-    }
-
-    try:
-        model = ModelFactory.create_model(
-            provider="ollama",
-            model_name="gpt-oss:20b-cloud",
-            credentials=credentials,
-        )
-
-        request = GenerateRequest(
-            messages=[
-                Message(role="user", content="Count from 1 to 5, one number per line.")
-            ],
-            temperature=0.5,
-            max_tokens=100,
-        )
-
-        print("📤 Sending streaming request to Ollama...")
-        print(" Model: gpt-oss:20b-cloud")
-        print(f" Message: {request.messages[0].content}")
-        print("\n Response stream:")
-
-        full_content = ""
-        async for chunk in model.generate_stream(request):
-            if chunk.content:
-                print(f" {chunk.content}", end="", flush=True)
-                full_content += chunk.content
-
-        print("\n\n✅ Streaming complete")
-        print(f" Total content length: {len(full_content)} characters")
-
-    except Exception as e:
-        print(f"❌ Streaming failed: {e}")
-        raise
-
-
-@pytest.mark.asyncio
-async def test_embedding_generation() -> None:
-    """Test embedding generation with Ollama."""
-    print("\n" + "=" * 60)
-    print("TEST 8: Embedding Generation")
-    print("=" * 60)
-    try:
-        model = ModelFactory.create_embedding_model(
-            provider="openai",
-            model_name="embeddinggemma",
-            api_key="ollama",
-            base_url="http://localhost:11434/v1",
-        )
-
-        texts = [
-            "Hello from Ollama!",
-            "This is a test embedding.",
-            "Embedding models are useful for semantic search.",
-        ]
-
-        request = EmbeddingRequest(input=texts)
-
-        print("📤 Sending embedding request to Ollama...")
-        print(" Model: nomic-embed-text")
-        print(f" Texts to embed: {len(texts)}")
-        for i, text in enumerate(texts, 1):
-            print(f" {i}. {text}")
-
-        response = await model.embed(request)
-
-        print("✅ Embedding successful")
-        print(f" Number of embeddings: {len(response.embeddings)}")
-        print(f" Embedding dimension: {len(response.embeddings[0])}")
-        print(f" First embedding (first 5 values): {response.embeddings[0][:5]}")
-
-    except Exception as e:
-        print(f"❌ Embedding failed: {e}")
-        raise
-
-
-@pytest.mark.asyncio
-async def test_multimodal_vision() -> None:
-    """Test multimodal vision capabilities with Ollama."""
-    print("\n" + "=" * 60)
-    print("TEST 9: Multimodal Vision (Image Understanding)")
-    print("=" * 60)
-
-    credentials = {
-        "api_key": "ollama",
-        "base_url": "http://localhost:11434/v1",
-    }
-
-    try:
-        model = ModelFactory.create_model(
-            provider="ollama",
-            model_name="qwen3-vl:235b-cloud",
-            credentials=credentials,
-        )
-
-        # Check if model supports vision
-        print("📋 Model capabilities:")
-        print(f" Vision support: {model.supports_capability(ModelCapability.VISION)}")
-        print(
-            f" Multimodal input: {model.supports_capability(ModelCapability.MULTIMODAL_INPUT)}"
-        )
-
-        # Create a message with image URL (using a public test image)
-        image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/1200px-Cat03.jpg"
-
-        content_parts = [
-            ContentPart(
-                content_type=ContentType.TEXT,
-                content="What do you see in this image? Describe it briefly.",
-            ),
-            ContentPart(
-                content_type=ContentType.IMAGE_URL,
-                content=image_url,
-                mime_type="image/jpeg",
-            ),
-        ]
-
-        message = Message(role="user", content=content_parts)
-
-        request = GenerateRequest(
-            messages=[message],
-            temperature=0.7,
-            max_tokens=200,
-        )
-
-        print("\n📤 Sending multimodal request to Ollama...")
-        print(" Model: gpt-oss:20b-cloud")
-        print(f" Message parts: {len(content_parts)}")
-        print(" - Text: What do you see in this image?")
-        print(f" - Image: {image_url}")
-
-        response = await model.generate(request)
-
-        print("\n✅ Vision analysis successful")
-        print(f" Response: {response.content}")
-
-    except Exception as e:
-        print(f"❌ Vision test failed: {e}")
-        raise
-
-
-@pytest.mark.asyncio
-async def test_multimodal_with_base64() -> None:
-    """Test multimodal with base64 encoded image."""
-    print("\n" + "=" * 60)
-    print("TEST 10: Multimodal with Base64 Image")
-    print("=" * 60)
-
-    credentials = {
-        "api_key": "ollama",
-        "base_url": "http://localhost:11434/v1",
-    }
-
-    try:
-        model = ModelFactory.create_model(
-            provider="ollama",
-            model_name="qwen3-vl:235b-cloud",
-            credentials=credentials,
-        )
-
-        # For this test, we'll use a simple base64 encoded 1x1 pixel image
-        # In real usage, you would encode an actual image
-        base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
-
-        content_parts = [
-            ContentPart(
-                content_type=ContentType.TEXT,
-                content="Analyze this image and tell me what you see.",
-            ),
-            ContentPart(
-                content_type=ContentType.IMAGE_BASE64,
-                content=base64_image,
-                mime_type="image/png",
-            ),
-        ]
-
-        message = Message(role="user", content=content_parts)
-
-        request = GenerateRequest(
-            messages=[message],
-            temperature=0.5,
-            max_tokens=100,
-        )
-
-        print("📤 Sending base64 image request to Ollama...")
-        print(" Model: gpt-oss:20b-cloud")
-        print(f" Message parts: {len(content_parts)}")
-        print(" - Text: Analyze this image")
-        print(f" - Image (base64): {len(base64_image)} characters")
-
-        response = await model.generate(request)
-
-        print("✅ Base64 image analysis successful")
-        print(f" Response: {response.content}")
-
-    except Exception as e:
-        print(f"❌ Base64 image test failed: {e}")
-        raise
-
-
-async def run_async_tests() -> None:
-    """Run async tests."""
-    try:
-        # await test_text_generation()
-        # await test_streaming_generation()
-        # await test_embedding_generation()
-        await test_multimodal_vision()
-        await test_multimodal_with_base64()
-    except Exception as e:
-        print(f"\n❌ Async tests failed: {e}")
-        raise
-
-
-def main() -> None:
-    """Run all tests."""
-    print("\n")
-    print("╔" + "=" * 58 + "╗")
-    print("║" + " " * 58 + "║")
-    print("║" + " OLLAMA LLM PROVIDER INTEGRATION TESTS".center(58) + "║")
-    print("║" + " " * 58 + "║")
-    print("╚" + "=" * 58 + "╝")
-
-    try:
-        # Sync tests
-        # test_ollama_model_creation()
-        # test_model_configuration()
-        # test_embedding_model_creation()
-        # test_factory_provider_support()
-        # test_openai_model_direct_creation()
-
-        # Async tests (actual API calls)
-        print("\n" + "=" * 60)
-        print("Running async tests (actual API calls)...")
-        print("=" * 60)
-        asyncio.run(run_async_tests())
-
-        print("\n" + "=" * 60)
-        print("✅ ALL TESTS PASSED!")
-        print("=" * 60 + "\n")
-
-    except Exception as e:
-        print("\n" + "=" * 60)
-        print(f"❌ TESTS FAILED: {e}")
-        print("=" * 60 + "\n")
-        raise
-
-
-if __name__ == "__main__":
-    main()