ehs_llm_client-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,3 @@ ehs_llm_client/__init__.py
+ from .client import LLM
+
+ __all__ = ["LLM"]
@@ -0,0 +1,412 @@ ehs_llm_client/client.py
+ import os
+ import io
+ import json
+ import asyncio
+ from typing import Any
+
+ import openai
+ from google import genai
+ from dotenv import load_dotenv
+
+ from .config import load_llm_config
+ from .utils import to_timestamp
+ from .exceptions import LLMProviderError
+
+ load_dotenv()
+
+ LLM_TIMEOUT = int(os.getenv('LLM_TIMEOUT', '180'))
+ RETRIES = int(os.getenv('RETRIES', '3'))
+
+ class LLM:
+     def __init__(
+         self,
+         config_section_name: str,
+         config_file_path: str | None = None,
+         config: dict | None = None,
+     ):
+         self.config_section_name = config_section_name
+
+         cfg = load_llm_config(
+             config_section_name,
+             config_file_path=config_file_path,
+             config_dict=config,
+         )
+
+         self.provider = cfg.get("provider")
+         self.model = cfg.get("model")
+         self.model_batch = cfg.get("model_batch")
+         self.base_url = cfg.get("base_url")
+         self.api_version = cfg.get("api_version")
+
+         if not self.provider or not self.model:
+             raise ValueError("provider and model are required")
+
+         # The key is read from an env var named after the provider,
+         # e.g. "openaiAPIkey", "azure_openaiAPIkey" or "googleAPIkey"
+         key_env = f"{self.provider}APIkey"
+         self.api_key = os.getenv(key_env)
+
+         if not self.api_key:
+             raise ValueError(f"Missing API key env var: {key_env}")
+
+         self.client = self._init_client()
+
+     def _init_client(self):
+         if self.provider == "openai":
+             return openai.AsyncOpenAI(api_key=self.api_key)
+
+         if self.provider == "azure_openai":
+             return openai.AsyncAzureOpenAI(
+                 api_key=self.api_key,
+                 api_version=self.api_version,
+                 azure_endpoint=self.base_url,
+             )
+
+         if self.provider == "google":
+             return genai.Client(api_key=self.api_key)
+
+         raise LLMProviderError(f"Unsupported provider: {self.provider}")
+
+     def __repr__(self):
+         return (
+             f"LLM(provider={self.provider}, model={self.model}, "
+             f"api_key_set={bool(self.api_key)})"
+         )
+
+     async def call_async(self, messages: list[dict[str, str]], schema: dict | str | None = None) -> Any:
+         # 1. Handle Google provider
+         if self.provider == 'google':
+             # If schema is the special string, we enable JSON mode without a schema
+             if schema == "return json":
+                 gemini_schema = None
+             else:
+                 gemini_schema = schema
+                 # Clean up the schema if it's a dictionary (drop the key on a
+                 # copy so the caller's dict is not mutated)
+                 if isinstance(gemini_schema, dict) and "additionalProperties" in gemini_schema:
+                     gemini_schema = {k: v for k, v in gemini_schema.items() if k != "additionalProperties"}
+
+             system_instr = None
+             contents = []
+             for m in messages:
+                 if m["role"] == "system":
+                     system_instr = m["content"]
+                 else:
+                     role = "model" if m["role"] == "assistant" else "user"
+                     contents.append({"role": role, "parts": [{"text": m["content"]}]})
+
+             return await self.client.aio.models.generate_content(
+                 model=self.model,
+                 contents=contents,
+                 config={
+                     "system_instruction": system_instr,
+                     "response_mime_type": "application/json",
+                     "response_schema": gemini_schema,
+                     "thinking_config": {"include_thoughts": False}
+                 }
+             )
+
+         # 2. Handle OpenAI / Azure OpenAI provider
+         elif self.provider in ['openai', 'azure_openai']:
+             if schema == "return json":
+                 # Enable basic JSON mode (no strict schema adherence)
+                 text = {"format": {"type": "json_object"}}
+                 # OpenAI requires "JSON" to be mentioned in the prompt for json_object mode.
+                 # Build a new list rather than appending, so retries don't
+                 # accumulate duplicate system messages in the caller's list.
+                 if not any("json" in m["content"].lower() for m in messages):
+                     messages = messages + [{"role": "system", "content": "Respond in valid JSON format."}]
+
+             elif schema:
+                 # Enable Structured Outputs with a specific schema
+                 text = {
+                     "format": {
+                         "type": "json_schema",
+                         "name": "schema",
+                         "schema": schema
+                     }
+                 }
+             else:
+                 text = None
+
+             api_kwargs = {
+                 "model": self.model,
+                 "input": messages,
+                 # "reasoning": {"effort": "medium"},
+                 "text": text
+             }
+
+             return await self.client.responses.create(**api_kwargs)
+
+         raise LLMProviderError(f"Unsupported provider: {self.provider}")
+
+     async def get_response_async(self, messages: list[dict[str, str]], schema: dict | str | None = None, max_retries: int = RETRIES, timeout: int = LLM_TIMEOUT):
+         retries = 1
+         while retries <= max_retries:
+             try:
+                 response = await asyncio.wait_for(self.call_async(messages, schema), timeout=timeout)
+
+                 if self.provider == "google":
+                     # --- GOOGLE PARSING ---
+                     finish_reason = response.candidates[0].finish_reason
+                     if finish_reason in ["SAFETY", "RECITATION"]:
+                         retries += 1
+                         continue
+
+                     output_text = "".join(
+                         part.text for part in response.candidates[0].content.parts if part.text
+                     )
+                     in_tokens = response.usage_metadata.prompt_token_count
+                     out_tokens = response.usage_metadata.candidates_token_count
+
+                 elif self.provider in ("openai", "azure_openai"):
+                     # --- OPENAI PARSING ---
+                     # incomplete_details is an object with a `reason` field,
+                     # not a plain string
+                     incomplete = getattr(response, 'incomplete_details', None)
+                     if incomplete and getattr(incomplete, 'reason', None) == "content_filter":
+                         retries += 1
+                         continue
+
+                     output_text = response.output_text
+                     in_tokens = response.usage.input_tokens
+                     out_tokens = response.usage.output_tokens
+
+                 # Both providers return valid JSON strings when a strict schema is used
+                 if isinstance(output_text, dict):
+                     return output_text, in_tokens, out_tokens
+
+                 return output_text, in_tokens, out_tokens
+                 # # Otherwise (e.g., for Gemini), try to parse the string
+                 # try:
+                 #     return json.loads(output_text), in_tokens, out_tokens
+                 # except (json.JSONDecodeError, TypeError):
+                 #     # If parsing fails or it's an unexpected type, return as-is
+                 #     return output_text, in_tokens, out_tokens
+
+             except Exception as e:
+                 error_message = str(e)
+                 # Exponential backoff for rate limits (429)
+                 if "429" in error_message or "ResourceExhausted" in error_message:
+                     retry_seconds = int(2 ** retries)
+                 else:
+                     retry_seconds = 1
+
+                 if retries == max_retries:
+                     raise RuntimeError(f"LLM call failed after {max_retries} attempts: {e}")
+
+                 await asyncio.sleep(retry_seconds)
+                 retries += 1
+
+         # All retries were consumed by content-filter blocks; fail loudly
+         # instead of silently returning None
+         raise RuntimeError(f"LLM call failed after {max_retries} attempts (response blocked by content filter)")
+
+     # New code, needs testing
+     def create_batch_request(self, custom_id: str, messages: list[dict[str, str]], schema: dict | str):
+         """
+         Generates a single line for a Batch API JSONL file.
+         Matches the logic of call_async by supporting multi-turn messages and system instructions.
+         """
+         provider = self.provider.lower()
+
+         if schema == "return json":
+             # Enable basic JSON mode (no strict schema adherence)
+             text = {"format": {"type": "json_object"}}
+             # OpenAI requires "JSON" to be mentioned in the prompt for json_object mode.
+             # Build a new list rather than appending, so the caller's list is not mutated.
+             if not any("json" in m["content"].lower() for m in messages):
+                 messages = messages + [{"role": "system", "content": "Respond in valid JSON format."}]
+
+         elif schema:
+             text = {
+                 "format": {
+                     "type": "json_schema",
+                     "name": "schema",
+                     "schema": schema
+                 }
+             }
+         else:
+             text = None
+
+         if provider in ("openai", "azure_openai"):
+             return {
+                 "custom_id": custom_id,
+                 "method": "POST",
+                 "url": "/v1/responses",
+                 "body": {
+                     "model": self.model_batch,
+                     "input": messages,
+                     # "reasoning": {"effort": "medium"},
+                     "text": text
+                 }
+             }
+
+         elif provider == "google":
+             if not isinstance(schema, str):
+                 # 1. Clean the schema for Gemini compatibility
+                 gemini_schema = schema.copy() if schema else None
+                 if gemini_schema and "additionalProperties" in gemini_schema:
+                     del gemini_schema["additionalProperties"]
+             else:
+                 gemini_schema = None
+
+             # 2. Process messages into Gemini format (contents + system_instruction)
+             system_instr = None
+             contents = []
+             for m in messages:
+                 if m["role"] == "system":
+                     # In the Gemini Batch API, system_instruction is a Content object
+                     system_instr = {"parts": [{"text": m["content"]}]}
+                 else:
+                     role = "model" if m["role"] == "assistant" else "user"
+                     contents.append({
+                         "role": role,
+                         "parts": [{"text": m["content"]}]
+                     })
+
+             return {
+                 "key": custom_id,  # Gemini uses 'key' instead of 'custom_id'
+                 "request": {
+                     "model": self.model_batch,
+                     "contents": contents,
+                     "system_instruction": system_instr,
+                     "generation_config": {
+                         "response_mime_type": "application/json",
+                         "response_schema": gemini_schema,
+                         "thinking_config": {"include_thoughts": False}
+                     }
+                 }
+             }
+
+         else:
+             raise ValueError(f"Unsupported provider: {provider}")
+
+     async def run_batch_process(
+         self,
+         batch_lines: list[str],
+         submission_id: str,
+     ) -> str | None:
+         """
+         Unified function to upload and start a batch job.
+         Returns the batch_id (OpenAI) or batch_name (Google), or None if there is nothing to submit.
+         """
+         if not batch_lines:
+             return None
+
+         file_content_str = "\n".join(batch_lines)
+         provider = self.provider.lower()
+
+         if provider == "openai":
+             # 1. Upload file to OpenAI
+             batch_file = await self.client.files.create(
+                 file=file_content_str.encode('utf-8'),
+                 purpose="batch"
+             )
+             file_id = batch_file.id
+
+             # 2. Create OpenAI batch job
+             batch_job = await self.client.batches.create(
+                 input_file_id=file_id,
+                 endpoint="/v1/responses",
+                 completion_window="24h"
+             )
+             return batch_job.id
+
+         elif provider == "google":
+             # 1. Upload file to Google
+             # The Google SDK prefers a file path or a file-like object
+             file_io = io.BytesIO(file_content_str.encode('utf-8'))
+
+             # Note: Google's upload is currently synchronous in the SDK.
+             # We use a unique display name to avoid collisions.
+             uploaded_file = self.client.files.upload(
+                 file=file_io,
+                 config={'display_name': f"batch_{submission_id}", 'mime_type': 'application/jsonl'}
+             )
+
+             # 2. Create Google batch job
+             # Note: Gemini 3.0 Flash uses 'models/gemini-3-flash-preview'
+             batch_job = self.client.batches.create(
+                 model=self.model,
+                 src=uploaded_file.name,
+             )
+
+             return batch_job.name
+
+         else:
+             raise ValueError(f"Unsupported provider: {provider}")
+
+     async def cancel_batch(self, batch_id):
+         # Note: this awaits the OpenAI async client; the Google SDK's
+         # batches.cancel is synchronous and takes a `name` argument.
+         try:
+             await self.client.batches.cancel(batch_id)
+             return True
+         except Exception as e:
+             print(f"Error cancelling batch {batch_id}: {e}")
+             return False
+
+     async def download_batch_results(self, output_file_id: str) -> str:
+         """
+         Downloads the result file content as a string.
+         Works for OpenAI and Google Gemini.
+         """
+         provider = self.provider.lower()
+
+         if provider == "openai":
+             # OpenAI returns a response object; we must read and decode it
+             result_content = await self.client.files.content(output_file_id)
+             # result_content.read() returns bytes
+             return result_content.read().decode('utf-8')
+
+         elif provider == "google":
+             # Google's download method returns the full bytes directly.
+             # Note: the current Google Gen AI SDK download is synchronous.
+             # We wrap it in to_thread to keep the function non-blocking.
+             result_bytes = await asyncio.to_thread(
+                 self.client.files.download,
+                 file=output_file_id
+             )
+             return result_bytes.decode('utf-8')
+
+         else:
+             raise ValueError(f"Unsupported provider: {provider}")
+
+     async def get_batch_status(self, batch_id: str):
+         provider = self.provider.lower()
+
+         if provider == "openai":
+             job = await self.client.batches.retrieve(batch_id)
+             # OpenAI property: job.request_counts
+             total = job.request_counts.total if job.request_counts else 0
+             completed = job.request_counts.completed if job.request_counts else 0
+             failed = job.request_counts.failed if job.request_counts else 0
+
+             return {
+                 "status": job.status,
+                 "is_completed": job.status == "completed",
+                 "is_terminal": job.status in ["failed", "expired", "cancelled"],
+                 "start_time": job.in_progress_at,
+                 "output_file_id": job.output_file_id,
+                 "stats": {
+                     "total": total,
+                     "completed": completed,
+                     "failed": failed
+                 }
+             }
+
+         elif provider == "google":
+             job = await asyncio.to_thread(self.client.batches.get, name=batch_id)
+
+             # Gemini batches do not provide per-request stats
+             total = 0
+             completed = 0
+             failed = 0
+
+             # Job state
+             state = job.state.name if hasattr(job.state, 'name') else job.state
+
+             # The output file is in job.dest.file_name
+             output_file_id = job.dest.file_name if hasattr(job, 'dest') and job.dest else None
+
+             return {
+                 "status": state,
+                 "is_completed": state == "JOB_STATE_SUCCEEDED",
+                 "is_terminal": state in ["JOB_STATE_FAILED", "JOB_STATE_CANCELLED", "JOB_STATE_EXPIRED"],
+                 "start_time": to_timestamp(job.create_time) if hasattr(job, "create_time") else None,
+                 "output_file_id": output_file_id,
+                 "stats": {
+                     "total": total,
+                     "completed": completed,
+                     "failed": failed
+                 }
+             }
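
For orientation, a minimal usage sketch of this client. The section name "my_llm", the file name "config.ini", and the prompt are illustrative; they assume a config file like the example after the config module below, plus the matching API key env var:

    import asyncio
    from ehs_llm_client import LLM

    async def main():
        # "my_llm" is a hypothetical section in config.ini
        llm = LLM("my_llm", config_file_path="config.ini")
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Reply in JSON with a single 'greeting' key."},
        ]
        # schema="return json" enables JSON mode without a strict schema;
        # get_response_async returns (output_text, input_tokens, output_tokens)
        text, in_tokens, out_tokens = await llm.get_response_async(messages, schema="return json")
        print(text, in_tokens, out_tokens)

    asyncio.run(main())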
@@ -0,0 +1,59 @@ ehs_llm_client/config.py
+ import os
+ import configparser
+ from typing import Any
+ from .exceptions import LLMConfigError
+
+
+ def load_llm_config(
+     section_name: str,
+     config_file_path: str | None = None,
+     config_dict: dict | None = None,
+ ) -> dict[str, Any]:
+     """
+     Load LLM configuration from:
+     1. Explicit config dict
+     2. Explicit config file path
+     3. Environment variables fallback
+     """
+
+     parser = configparser.ConfigParser()
+
+     if config_dict:
+         parser.read_dict(config_dict)
+
+     elif config_file_path:
+         if not os.path.exists(config_file_path):
+             raise LLMConfigError(f"Config file not found: {config_file_path}")
+         parser.read(config_file_path)
+
+     else:
+         # ENV fallback (minimal but safe)
+         provider = os.getenv("LLM_PROVIDER")
+         model = os.getenv("LLM_MODEL")
+
+         if not provider or not model:
+             raise LLMConfigError(
+                 "No config provided. Set config_file_path, config dict, "
+                 "or env vars LLM_PROVIDER and LLM_MODEL."
+             )
+
+         return {
+             "provider": provider,
+             "model": model,
+             "model_batch": os.getenv("LLM_MODEL_BATCH"),
+             "base_url": os.getenv("LLM_BASE_URL"),
+             "api_version": os.getenv("LLM_API_VERSION"),
+         }
+
+     if not parser.has_section(section_name):
+         raise LLMConfigError(f"Config section '{section_name}' not found")
+
+     defaults = (
+         dict(parser.items("default_settings"))
+         if parser.has_section("default_settings")
+         else {}
+     )
+
+     section = dict(parser.items(section_name))
+
+     return {**defaults, **section}
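
As a sketch of what load_llm_config expects, an INI file could look like this (all section names and values are illustrative; the keys mirror what LLM.__init__ reads, and [default_settings] is merged into every section with the section's own keys taking precedence):

    [default_settings]
    api_version = 2024-06-01

    [my_llm]
    provider = azure_openai
    model = gpt-4o
    model_batch = gpt-4o-batch
    base_url = https://example.openai.azure.com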
@@ -0,0 +1,5 @@ ehs_llm_client/exceptions.py
+ class LLMConfigError(Exception):
+     pass
+
+ class LLMProviderError(Exception):
+     pass
@@ -0,0 +1,12 @@ ehs_llm_client/utils.py
+ from datetime import datetime
+ from dateutil.parser import isoparse
+
+
+ def to_timestamp(value):
+     if value is None:
+         return None
+     if isinstance(value, datetime):
+         return value.timestamp()
+     if isinstance(value, str):
+         return isoparse(value).timestamp()
+     raise TypeError(f"Unsupported time type: {type(value)}")
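
A quick illustration of to_timestamp's contract (example values chosen for this sketch):

    from datetime import datetime, timezone
    from ehs_llm_client.utils import to_timestamp

    to_timestamp("2024-01-01T00:00:00+00:00")                 # 1704067200.0
    to_timestamp(datetime(2024, 1, 1, tzinfo=timezone.utc))   # 1704067200.0
    to_timestamp(None)                                        # None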
@@ -0,0 +1,8 @@ ehs_llm_client-0.1.0.dist-info/METADATA
+ Metadata-Version: 2.4
+ Name: ehs-llm-client
+ Version: 0.1.0
+ Summary: Unified LLM client. Currently supports OpenAI, Azure OpenAI, and Google Gemini.
+ Requires-Dist: openai
+ Requires-Dist: google-genai
+ Requires-Dist: python-dotenv
+ Requires-Dist: python-dateutil
@@ -0,0 +1,9 @@ ehs_llm_client-0.1.0.dist-info/RECORD
+ ehs_llm_client/__init__.py,sha256=AzmovmTtvj21YseUb4RVYBt4pq2JT6RpfZb41dyLK6c,46
+ ehs_llm_client/client.py,sha256=fbyMOPTw_NayG2GCbiIQbwZHyBrU76s9Lirh5uBpv3o,16201
+ ehs_llm_client/config.py,sha256=1sYOvX5biBUoLYMa22CiJ_MElRlBfBgQsatfCTs_kN0,1698
+ ehs_llm_client/exceptions.py,sha256=UEihJBOZufiMPjkfStEuiE7rhJTUJXib6puGN65FykA,92
+ ehs_llm_client/utils.py,sha256=zVLF8obG1cKJ-9cyHulzoTZBT0LQtGkx4cdiG2UaKuk,352
+ ehs_llm_client-0.1.0.dist-info/METADATA,sha256=4Zte70VYMJw3Fe3rz5UgM_x9P8Q4Xx-SgfJxFIUMWqU,263
+ ehs_llm_client-0.1.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+ ehs_llm_client-0.1.0.dist-info/top_level.txt,sha256=8xnm7U82x1dZYaanP_IxTIw7e4PKG-GEpmaR0jU6FSE,15
+ ehs_llm_client-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@ ehs_llm_client-0.1.0.dist-info/WHEEL
+ Wheel-Version: 1.0
+ Generator: setuptools (80.10.2)
+ Root-Is-Purelib: true
+ Tag: py3-none-any
+
@@ -0,0 +1 @@ ehs_llm_client-0.1.0.dist-info/top_level.txt
+ ehs_llm_client