pop-python 1.0.0-py3-none-any.whl → 1.0.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- POP/Embedder.py +229 -0
- POP/LLMClient.py +403 -0
- POP/POP.py +392 -0
- POP/__init__.py +22 -0
- POP/prompts/2024-11-19-content_finder.md +46 -0
- POP/prompts/2024-11-19-get_content.md +71 -0
- POP/prompts/2024-11-19-get_title_and_url.md +62 -0
- POP/prompts/CLI_AI_helper.md +75 -0
- POP/prompts/content_finder.md +42 -0
- POP/prompts/corpus_splitter.md +28 -0
- POP/prompts/fabric-improve_prompt.md +518 -0
- POP/prompts/function_code_generator.md +51 -0
- POP/prompts/function_description_generator.md +45 -0
- POP/prompts/get_content.md +75 -0
- POP/prompts/get_title_and_url.md +62 -0
- POP/prompts/json_formatter_prompt.md +36 -0
- POP/prompts/openai-function_description_generator.md +126 -0
- POP/prompts/openai-json_schema_generator.md +165 -0
- POP/prompts/openai-prompt_generator.md +49 -0
- POP/schemas/biomedical_ner_extractor.json +37 -0
- POP/schemas/entity_extraction_per_sentence.json +92 -0
- {pop_python-1.0.0.dist-info → pop_python-1.0.2.dist-info}/METADATA +1 -1
- pop_python-1.0.2.dist-info/RECORD +26 -0
- pop_python-1.0.2.dist-info/top_level.txt +1 -0
- pop_python-1.0.0.dist-info/RECORD +0 -5
- pop_python-1.0.0.dist-info/top_level.txt +0 -1
- {pop_python-1.0.0.dist-info → pop_python-1.0.2.dist-info}/WHEEL +0 -0
- {pop_python-1.0.0.dist-info → pop_python-1.0.2.dist-info}/licenses/LICENSE +0 -0
POP/Embedder.py
ADDED
@@ -0,0 +1,229 @@
# Embedder.py
import numpy as np
import openai
import requests as HTTPRequests  ## some packages already have "requests"
from os import getenv
from backoff import on_exception, expo


from transformers import AutoTokenizer, AutoModel

MAX_TOKENS = 8194

class Embedder:
    def __init__(self, model_name=None, use_api=None, to_cuda=False, attn_implementation=None):
        """
        Initializes the Embedder class, which supports multiple embedding methods, including the Jina API,
        the OpenAI API, and local model embeddings.

        Args:
            model_name (str): Name of the model to use for embedding.
            use_api (str): Flag to determine whether to use an API for embedding ('jina', 'openai') or a local model (None).
            to_cuda (bool): If True, use GPU; otherwise use CPU. (Some models must run on GPU.)
            attn_implementation (str): Attention implementation method for the transformer model.
        """
        self.use_api = use_api
        self.model_name = model_name
        self.to_cuda = to_cuda

        # API-based embedding initialization
        if self.use_api is not None:
            supported_apis = ["", 'jina', 'openai']
            if self.use_api not in supported_apis:
                raise ValueError(f"API type '{self.use_api}' not supported. Supported APIs: {supported_apis}")

            if self.use_api == "":  # default
                self.use_api = 'openai'

            if self.use_api == 'jina':
                pass  # maybe add something later
            elif self.use_api == 'openai':
                self.client = openai.Client(api_key=getenv("OPENAI_API_KEY"))
        else:
            # Load PyTorch model for local embedding generation
            if not model_name:
                raise ValueError("Model name must be provided when using a local model.")
            self.attn_implementation = attn_implementation
            self._initialize_local_model()

    def _initialize_local_model(self):
        """Initializes the PyTorch model and tokenizer for local embedding generation."""
        import torch  # Importing PyTorch only when needed

        if self.attn_implementation:
            self.model = AutoModel.from_pretrained(self.model_name,
                                                   trust_remote_code=True,
                                                   attn_implementation=self.attn_implementation,
                                                   torch_dtype=torch.float16).to('cuda' if self.to_cuda else 'cpu')
        else:
            self.model = AutoModel.from_pretrained(self.model_name,
                                                   trust_remote_code=True,
                                                   torch_dtype=torch.float16).to('cuda' if self.to_cuda else 'cpu')
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model.eval()

    def get_embedding(self, texts: list) -> np.ndarray:
        """
        Generates embeddings for a list of texts.

        Args:
            texts (list of str): A list of texts to be embedded.

        Returns:
            np.ndarray: The embeddings as a numpy array of shape (len(texts), embedding_dim).
        """
        if not isinstance(texts, list):
            raise ValueError("Input must be a list of strings.")

        if self.use_api:
            if self.use_api == 'jina':
                if not self.model_name:
                    self.model_name = "jina-embeddings-v3"
                    print(f"use default model: {self.model_name}")
                return self._get_jina_embedding(texts)
            elif self.use_api == 'openai':
                # set the default to be the GPT embedding model
                if not self.model_name:
                    self.model_name = "text-embedding-3-small"
                    print(f"use default model: {self.model_name}")
                return self._get_openai_embedding(texts)
            else:
                raise ValueError(f"API type '{self.use_api}' is not supported.")
        else:
            return self._get_torch_embedding(texts)

    ## Below are model-specific functions

    @on_exception(expo, HTTPRequests.exceptions.RequestException, max_time=30)
    def _get_jina_embedding(self, texts: list) -> np.ndarray:
        """Fetches embeddings from the Jina API. Requires a Jina API key in the .env file."""
        url = 'https://api.jina.ai/v1/embeddings'

        headers = {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {getenv("JINAAI_API_KEY")}'
        }

        input_texts = [text for text in texts]
        data = {
            "model": "jina-embeddings-v3",
            "task": "text-matching",
            "dimensions": 1024,
            "late_chunking": False,
            "embedding_type": "float",
            "input": input_texts
        }
        response = HTTPRequests.post(url, headers=headers, json=data)

        # Process the response
        if response.status_code == 200:
            # Extract embeddings from the response and convert them to a single NumPy array
            embeddings = response.json().get('data', [])
            embeddings_np = np.array([embedding_data['embedding'] for embedding_data in embeddings], dtype="f")
            return embeddings_np
        elif response.status_code == 429:
            raise HTTPRequests.exceptions.RequestException(
                f"Rate limit exceeded: {response.status_code}, {response.text}"
            )

        ## When the input is too long, we need to segment the text
        elif response.status_code == 400:
            ebd = []
            for text in texts:
                chunks = self._Jina_segmenter(text, max_token=MAX_TOKENS)
                token_counts = [len(chunk) for chunk in chunks]
                chunk_embedding = self.get_embedding(chunks)
                weighted_avg = np.average(chunk_embedding, weights=token_counts, axis=0)
                ebd.append(weighted_avg)
            return np.array(ebd, dtype="f")

        else:
            print(f"Error: {response.status_code}, {response.text}")
            raise Exception(f"Failed to get embedding from Jina API: {response.status_code}, {response.text}")

    @on_exception(expo, HTTPRequests.exceptions.RequestException, max_time=30)
    def _get_openai_embedding(self, texts: list) -> np.ndarray:
        """Fetches embeddings from the OpenAI API and returns them as a NumPy array. Requires an OpenAI API key in the .env file."""
        # The OpenAI embedding API limits a single request to 2048 texts, so we may need to batch here
        batch_size = 2048
        if len(texts) > batch_size:
            all_embeddings = []
            for i in range(0, len(texts), batch_size):
                batch_texts = texts[i:i+batch_size]
                batch_embeddings = self._get_openai_embedding(batch_texts)
                all_embeddings.append(batch_embeddings)
            return np.vstack(all_embeddings)

        texts = [text.replace("\n", " ") for text in texts]  # Clean text input
        response = self.client.embeddings.create(input=texts, model=self.model_name)

        # Extract embeddings from the response
        embeddings = [item.embedding for item in response.data]

        # Convert the list of embeddings to a NumPy array with the desired data type
        return np.array(embeddings, dtype="f")

    def _get_torch_embedding(self, texts: list) -> np.ndarray:
        """Generates embeddings using a local PyTorch model."""
        import torch  # Importing PyTorch only when needed
        import torch.nn.functional as F

        @torch.no_grad()
        def _encode(self, input_texts):
            """
            Generates embeddings for a list of texts using a local PyTorch model.

            Args:
                input_texts (list of str): A list of texts to encode.

            Returns:
                np.ndarray: An array of embeddings.
            """
            batch_dict = self.tokenizer(input_texts, max_length=512, padding=True, truncation=True, return_tensors='pt', return_attention_mask=True).to('cuda' if self.to_cuda else 'cpu')

            outputs = self.model(**batch_dict)
            attention_mask = batch_dict["attention_mask"]
            hidden = outputs.last_hidden_state

            reps = _weighted_mean_pooling(hidden, attention_mask)
            embeddings = F.normalize(reps, p=2, dim=1).detach().cpu().numpy()
            return embeddings

        def _weighted_mean_pooling(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
            """
            Computes weighted mean pooling over the hidden states.

            Args:
                hidden (torch.Tensor): The hidden states output from the transformer model.
                attention_mask (torch.Tensor): The attention mask for the input sequences.

            Returns:
                torch.Tensor: The pooled representation of the input.
            """
            attention_mask_ = attention_mask * attention_mask.cumsum(dim=1)
            s = torch.sum(hidden * attention_mask_.unsqueeze(-1).float(), dim=1)
            d = attention_mask_.sum(dim=1, keepdim=True).float()
            reps = s / d
            return reps

        return _encode(self, texts)

    @on_exception(expo, HTTPRequests.exceptions.RequestException, max_time=30)
    def _Jina_segmenter(self, text: str, max_token: int) -> list[str]:
        """Segments text into chunks using the Jina API. (Free, but an API key is needed.)"""
        url = 'https://segment.jina.ai/'
        headers = {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {getenv("JINAAI_API_KEY")}'
        }
        data = {
            "content": text,
            "return_tokens": True,
            "return_chunks": True,
            "max_chunk_length": max_token
        }
        response = HTTPRequests.post(url, headers=headers, json=data)
        return response.json().get('chunks', [])
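A minimal usage sketch for the new Embedder class, inferred from the signatures above. It assumes the module is importable as POP.Embedder (as the file listing suggests), that OPENAI_API_KEY is set in the environment, and the local model name is a hypothetical placeholder rather than anything shipped with the package.

from POP.Embedder import Embedder

# API-backed embeddings; with no model_name the code falls back to "text-embedding-3-small"
embedder = Embedder(use_api='openai')
vectors = embedder.get_embedding(["first document", "second document"])
print(vectors.shape)  # (2, embedding_dim)

# Local Hugging Face model instead of an API (model_name is required in this mode)
local = Embedder(model_name="your-org/your-encoder", to_cuda=False)  # placeholder model name
local_vectors = local.get_embedding(["first document"])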
POP/LLMClient.py
ADDED
@@ -0,0 +1,403 @@
from abc import ABC, abstractmethod
from xml.parsers.expat import model  # unused but kept for consistency
from pydantic import BaseModel
from os import getenv
import requests

# Note: POP uses the ``openai`` and ``google.genai`` libraries to access
# external LLM services. These imports are optional: if a package is not
# installed, the corresponding client raises ImportError when it is
# constructed. To use remote LLMs, make sure the required packages are
# installed and the API keys are set in your environment.

try:
    from openai import OpenAI
except Exception:
    OpenAI = None  # type: ignore
try:
    from google import genai
    from google.genai import types
except Exception:
    genai = None  # type: ignore
    types = None  # type: ignore


##############################################
# LLM Client Interface and Implementations
##############################################

class LLMClient(ABC):
    """
    Abstract base class for LLM clients.
    """
    @abstractmethod
    def chat_completion(self, messages: list, model: str, temperature: float, **kwargs):
        pass


class OpenAIClient(LLMClient):
    """
    OpenAI API client implementation. Requires the ``openai`` package.
    """
    def __init__(self):
        if OpenAI is None:
            raise ImportError(
                "openai package is not installed. Install it to use OpenAIClient."
            )
        # Instantiate a new OpenAI client with the API key.
        self.client = OpenAI(api_key=getenv("OPENAI_API_KEY"))

    def chat_completion(self, messages: list, model: str, temperature: float = 0.7, **kwargs):
        request_payload = {
            "model": model,
            "messages": [],
            "temperature": temperature
        }

        # Optional images
        images = kwargs.pop("images", None)

        # Build OpenAI-style messages
        for msg in messages:
            content = msg.get("content", "")
            role = msg.get("role", "user")

            # If images are provided, attach them to the user messages
            if images and role == "user":
                multi_content = [{"type": "text", "text": content}]
                for img in images:
                    if isinstance(img, str) and img.startswith("http"):
                        multi_content.append({"type": "image_url", "image_url": {"url": img}})
                    else:
                        multi_content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img}"}})
                content = multi_content
            request_payload["messages"].append({"role": role, "content": content})

        # Handle response_format (JSON schema)
        fmt = kwargs.get("response_format", None)
        if fmt:
            if isinstance(fmt, BaseModel):
                request_payload["response_format"] = fmt
            else:
                request_payload["response_format"] = {"type": "json_schema", "json_schema": fmt}

        # Handle function tools
        tools = kwargs.get("tools", None)
        if tools:
            request_payload["tools"] = [{"type": "function", "function": tool} for tool in tools]
            request_payload["tool_choice"] = "auto"

        # Temporary patch for models not supporting system roles
        if model == "o1-mini" and request_payload["messages"] and request_payload["messages"][0]["role"] == "system":
            request_payload["messages"][0]["role"] = "user"

        # Execute request
        try:
            response = self.client.chat.completions.create(**request_payload)
        except Exception as e:
            raise RuntimeError(f"OpenAI chat_completion error: {e}")
        return response


class LocalPyTorchClient(LLMClient):
    """
    Local PyTorch-based LLM client.
    (Placeholder implementation.)
    """
    def chat_completion(self, messages: list, model: str, temperature: float, **kwargs):
        return "Local PyTorch LLM response (stub)"


class DeepseekClient(LLMClient):
    """
    Deepseek API client. Requires ``openai`` with the Deepseek API base URL.
    """
    def __init__(self):
        if OpenAI is None:
            raise ImportError(
                "openai package is not installed. Install it to use DeepseekClient."
            )
        self.client = OpenAI(api_key=getenv("DEEPSEEK_API_KEY"), base_url="https://api.deepseek.com")

    def chat_completion(self, messages: list, model: str, temperature: float, **kwargs):
        request_payload = {
            "model": model,
            "messages": [],
            "temperature": temperature
        }

        # Optional images
        images = kwargs.pop("images", None)
        if images:
            raise NotImplementedError("DeepseekClient does not support images yet.")

        # Build OpenAI-style messages
        for msg in messages:
            content = msg.get("content", "")
            role = msg.get("role", "user")
            request_payload["messages"].append({"role": role, "content": content})

        # Execute request
        try:
            response = self.client.chat.completions.create(**request_payload)
        except Exception as e:
            raise RuntimeError(f"Deepseek chat_completion error: {e}")
        return response


class GeminiClient(LLMClient):
    """
    GCP Gemini API client. Requires the ``google-genai`` library.
    """
    def __init__(self, model="gemini-2.5-flash"):
        if genai is None or types is None:
            raise ImportError(
                "google-genai package is not installed. Install it to use GeminiClient."
            )
        self.client = genai.Client(api_key=getenv("GEMINI_API_KEY"))
        self.model_name = model

    def chat_completion(self, messages: list, model: str = None, temperature: float = 0.7, **kwargs):
        model_name = model or self.model_name

        # Extract system instruction and user content
        system_instruction = None
        user_contents = []
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            if role == "system" and system_instruction is None:
                system_instruction = content
            else:
                user_contents.append(content)

        # Prepare multimodal contents
        contents = []
        images = kwargs.pop("images", None)

        if images:
            try:
                from PIL import Image
                import base64
                from io import BytesIO
            except Exception:
                raise ImportError("PIL and base64 are required for image support in GeminiClient.")

            for img in images:
                # Accept base64 strings or PIL images
                if isinstance(img, Image.Image):
                    contents.append(img)
                elif isinstance(img, str):
                    try:
                        # Base64 -> PIL Image
                        img_data = base64.b64decode(img)
                        image = Image.open(BytesIO(img_data))
                        contents.append(image)
                    except Exception:
                        # Assume URL string
                        contents.append(img)
        # Add text content last
        if user_contents:
            contents.append("\n".join(user_contents))

        # Config with system instruction and temperature
        gen_config = types.GenerateContentConfig(
            temperature=temperature,
            system_instruction=system_instruction
        )

        try:
            response = self.client.models.generate_content(
                model=model_name,
                contents=contents,
                config=gen_config
            )
        except Exception as e:
            raise RuntimeError(f"Gemini chat_completion error: {e}")

        # Wrap the response in an OpenAI-like structure for PromptFunction compatibility
        class FakeMessage:
            def __init__(self, content):
                self.content = content
                self.tool_calls = None

        class FakeChoice:
            def __init__(self, message):
                self.message = message

        class FakeResponse:
            def __init__(self, text):
                self.choices = [FakeChoice(FakeMessage(text))]

        return FakeResponse(response.text or "")


class DoubaoClient(LLMClient):
    """
    Doubao (Volcengine Ark) API client.
    Uses the OpenAI-compatible Ark endpoint, so it requires the ``openai``
    package and a ``DOUBAO_API_KEY`` in the environment. Tools (function
    calling) are not supported yet.
    """
    def __init__(self):
        if OpenAI is None:
            raise ImportError(
                "openai package is not installed. Install it to use DoubaoClient."
            )
        # The base URL for Doubao's API; the API key must be provided via the environment
        self.client = OpenAI(base_url="https://ark.cn-beijing.volces.com/api/v3",
                             api_key=getenv('DOUBAO_API_KEY'))

    def chat_completion(self, messages: list, model: str, temperature: float = 0.7, **kwargs):
        payload = {
            "model": model,
            "messages": [],
            "temperature": temperature,
        }
        images = kwargs.pop("images", None)

        ## If the images are passed as a single URL or base64 string, wrap them in a list
        if images and not isinstance(images, list):
            images = [images]

        # Pass through common knobs if present
        passthrough = [
            "top_p", "max_tokens", "stop", "frequency_penalty", "presence_penalty",
            "logprobs", "top_logprobs", "logit_bias", "service_tier", "thinking",
            "stream", "stream_options",
        ]
        for k in passthrough:
            if k in kwargs and kwargs[k] is not None:
                payload[k] = kwargs[k]

        # Messages (attach images on user turns, same structure used for OpenAI)
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            if images and role == "user":
                multi = [{"type": "text", "text": content}]
                for img in images:
                    if isinstance(img, str) and img.startswith("http"):
                        multi.append({"type": "image_url", "image_url": {"url": img}})
                    else:
                        multi.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img}"}})
                content = multi
            payload["messages"].append({"role": role, "content": content})

        # Tools (function calling)
        tools = kwargs.get("tools")
        if tools:
            raise NotImplementedError("DoubaoClient does not support tools yet.")
        try:
            response = self.client.chat.completions.create(**payload)
        except Exception as e:
            raise RuntimeError(f"Doubao chat_completion error: {e}")
        return response


class OllamaClient(LLMClient):
    """
    Ollama-compatible LLM client using the /api/generate endpoint.
    """

    def __init__(self, model="llama3:latest", base_url="http://localhost:11434", default_options=None, timeout=300):
        self.model = model
        self.base_url = base_url
        # Sensible defaults for extraction tasks
        self.default_options = default_options or {
            "num_ctx": 8192,          # larger context window so long docs don't truncate
            "temperature": 0.02,      # low variance, more literal
            "top_p": 0.9,
            "top_k": 40,
            "repeat_penalty": 1.05,
            "mirostat": 0             # disable mirostat for predictable outputs
        }
        self.timeout = timeout

    def chat_completion(self, messages: list, model: str = None, temperature: float = 0.7, **kwargs):
        # Extract a proper system string for /api/generate
        system_parts, user_assistant_lines = [], []
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            if role == "system":
                system_parts.append(content)
            elif role == "assistant":
                user_assistant_lines.append(f"[Assistant]: {content}")
            else:
                user_assistant_lines.append(f"[User]: {content}")

        system = "\n".join(system_parts) if system_parts else None
        prompt = "\n".join(user_assistant_lines)

        # Merge caller-provided options with our defaults.
        # Allow both a dict under 'ollama_options' and top-level knobs (max_tokens -> num_predict).
        caller_opts = kwargs.pop("ollama_options", {}) or {}
        options = {**self.default_options, **caller_opts}

        # Keep the legacy temperature kwarg in sync with options
        if temperature is not None:
            options["temperature"] = temperature

        payload = {
            "model": model or self.model,
            "prompt": prompt,
            "stream": False,
            # num_predict must be top level for /api/generate
            "num_predict": kwargs.get("max_tokens", 1024),
            "options": options,
        }

        # Pass system separately (clearer than prepending it to the prompt)
        if system:
            payload["system"] = system

        # JSON mode / schema
        fmt = kwargs.get("response_format")
        if fmt:
            fmt = self._normalize_schema(fmt)
            payload["format"] = fmt  # raw JSON schema or "json"

        # Optional stops and keep-alive
        if "stop" in kwargs and kwargs["stop"]:
            payload["stop"] = kwargs["stop"]
        if "keep_alive" in kwargs and kwargs["keep_alive"]:
            payload["keep_alive"] = kwargs["keep_alive"]

        try:
            response = requests.post(f"{self.base_url}/api/generate", json=payload, timeout=self.timeout)
            response.raise_for_status()
            content = response.json().get("response", "")
            return self._wrap_response(content)
        except Exception as e:
            raise RuntimeError(f"OllamaClient error: {e}")

    # Normalize: accept a raw dict, {"schema": {...}}, a JSON string, or a file path
    def _normalize_schema(self, fmt):
        import json, os
        if fmt is None:
            return None
        if isinstance(fmt, str):
            # Try a file path first, otherwise parse as a JSON string
            if os.path.exists(fmt):
                return json.load(open(fmt, "r", encoding="utf-8"))
            return json.loads(fmt)
        if isinstance(fmt, dict) and "schema" in fmt and isinstance(fmt["schema"], dict):
            return fmt["schema"]  # unwrap the OpenAI-style wrapper
        if isinstance(fmt, dict):
            return fmt  # already a schema object
        raise TypeError("response_format must be a JSON schema dict, a JSON string, or a file path")

    def _wrap_response(self, content: str):
        class Message:
            def __init__(self, content):
                self.content = content
                self.tool_calls = None

        class Choice:
            def __init__(self, message):
                self.message = message

        class Response:
            def __init__(self, content):
                self.choices = [Choice(Message(content))]

        return Response(content)
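A minimal sketch of how the new clients might be driven, based only on the shared chat_completion signature above. The model names are placeholders; OpenAIClient assumes OPENAI_API_KEY is set, and OllamaClient assumes a local Ollama server on its default port.

from POP.LLMClient import OpenAIClient, OllamaClient

messages = [
    {"role": "system", "content": "You are a terse assistant."},
    {"role": "user", "content": "Summarize what a wheel file is in one sentence."},
]

# Hosted model through the OpenAI API (model name is a placeholder)
client = OpenAIClient()
response = client.chat_completion(messages, model="gpt-4o-mini", temperature=0.2)
print(response.choices[0].message.content)

# Local model through Ollama's /api/generate endpoint; the reply is wrapped
# so the same choices[0].message.content access path works
ollama = OllamaClient(model="llama3:latest")
local_reply = ollama.chat_completion(messages, max_tokens=256)
print(local_reply.choices[0].message.content)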