entityxtract 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
File without changes
entityxtract/config.py ADDED
@@ -0,0 +1,46 @@
1
+ """
2
+ Gets configurations from .env file or environment variables.
3
+ """
4
+
5
+ import os
6
+ from pathlib import Path
7
+ from dotenv import load_dotenv
8
+ from typing import Any, Optional
9
+ from entityxtract.logging_config import get_logger
10
+
11
+ logger = get_logger(__name__)
12
+
13
+ DOTENV_PATH = Path(__file__).parent.parent.parent / ".env"
14
+ if not DOTENV_PATH.exists():
15
+ logger.info(
16
+ f"'.env' file not found at {DOTENV_PATH}. Skipping loading environment variables from .env file."
17
+ )
18
+ else:
19
+ logger.info(f"Loading environment variables from {DOTENV_PATH}")
20
+
21
+ load_dotenv(dotenv_path=DOTENV_PATH, override=True)
22
+
23
+
24
def get_config(key: str, default: Optional[Any] = None) -> Optional[Any]:
    """Get a configuration value from the process environment.

    Values loaded from the .env file are visible here because ``load_dotenv``
    copies them into ``os.environ`` at import time.

    Args:
        key (str): The environment variable key to retrieve.
        default: Value to return when the key is not set. Defaults to None,
            which preserves the original behavior for existing callers.

    Returns:
        Optional[Any]: The value of the environment variable, or ``default``
        if not found.
    """

    env_value = os.environ.get(key)

    if env_value is not None:
        return env_value

    # Log at warning level so missing configuration is visible in operation,
    # even when the caller supplied a fallback.
    logger.warning(f"Environment variable '{key}' not found.")

    return default
43
+
44
+
45
+ if __name__ == "__main__":
46
+ print("OPENAI_DEFAULT_MODEL =", get_config("OPENAI_DEFAULT_MODEL"))
@@ -0,0 +1,433 @@
1
+ import base64
2
+ import json
3
+ import time
4
+ import concurrent.futures
5
+ from typing import Any, Dict, Optional, Tuple
6
+
7
+ from langchain_openai import ChatOpenAI
8
+ from langchain_core.messages import HumanMessage, SystemMessage
9
+
10
+ from . import extractor_types
11
+ from .config import get_config
12
+ from .prompts import get_prompt, get_system_prompt
13
+ from entityxtract.logging_config import get_logger
14
+
15
+
16
+ logger = get_logger(__name__)
17
+
18
+
19
def pil_img_to_base64(img) -> str:
    """
    Convert a PIL image (or first image in a list) to a base64-encoded JPEG string.

    Args:
        img: PIL Image object, list of PIL Images, or a bytes-like object
            containing encoded image data

    Returns:
        Base64-encoded string of the JPEG-serialized image

    Raises:
        ValueError: If an empty image list is provided.
    """
    from io import BytesIO

    # If a list of images is provided, use the first one
    if isinstance(img, list):
        if not img:
            raise ValueError("Empty image list provided")
        img = img[0]

    # If it's not a PIL image, attempt to open from bytes-like object.
    # PIL is imported lazily here (instead of unconditionally at the top of
    # the function) so callers passing an object that already exposes .save
    # never trigger the import; this also removes the original's duplicate
    # `from io import BytesIO as _BytesIO`.
    if not hasattr(img, "save"):
        from PIL import Image as PILImage

        try:
            img = PILImage.open(BytesIO(img))
        except Exception as e:
            logger.error(f"Unable to open image from provided data: {e}")
            raise

    # Ensure RGB mode for JPEG compatibility
    if getattr(img, "mode", None) != "RGB":
        img = img.convert("RGB")

    buffered = BytesIO()
    img.save(buffered, format="JPEG", quality=85)
    return base64.b64encode(buffered.getvalue()).decode()
54
+
55
+
56
+ # ---------------------------
57
+ # Internal helpers
58
+ # ---------------------------
59
+
60
+
61
def _build_model(config: extractor_types.ExtractionConfig) -> ChatOpenAI:
    """Build the chat model for an extraction run, forcing JSON-object output."""
    # Credentials/endpoint come from the environment; model selection and
    # sampling temperature come from the extraction config.
    return ChatOpenAI(
        openai_api_key=get_config("OPENAI_API_KEY"),
        openai_api_base=get_config("OPENAI_API_BASE"),
        model_name=config.model_name,
        temperature=config.temperature,
        model_kwargs={"response_format": {"type": "json_object"}},
    )
70
+
71
+
72
def _build_messages(
    doc: extractor_types.Document,
    object_to_extract: extractor_types.ExtractableObjectTypes,
    config: extractor_types.ExtractionConfig,
):
    """Assemble the system/human message pair sent to the model.

    Depending on config.file_input_modes, the human message may carry the
    document text inlined into the prompt, an image attachment, and/or the
    raw PDF bytes as a file attachment.

    Args:
        doc: Document providing text / image / binary content.
        object_to_extract: Object whose prompt template is fetched.
        config: Extraction configuration (input modes, etc.).

    Returns:
        List of [SystemMessage, HumanMessage] ready for model.invoke().
    """
    system_prompt = get_system_prompt()
    prompt = get_prompt(object_to_extract)

    if extractor_types.FileInputMode.TEXT in config.file_input_modes:
        prompt = prompt.replace("{{text}}", f"\n\n{doc.text}")

    # Add Attachments (keep existing structure as requested)
    attachments = []
    if extractor_types.FileInputMode.IMAGE in config.file_input_modes:
        if doc.image is not None:
            try:
                attachments.append(
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{pil_img_to_base64(doc.image)}"
                        },
                    }
                )
            except Exception as e:
                logger.warning(f"Skipping image attachment due to error: {e}")
        else:
            logger.debug("No image data available; skipping image attachment")

    if extractor_types.FileInputMode.FILE in config.file_input_modes:
        # Guard against missing binary data, mirroring the image branch above:
        # base64.b64encode(None) would otherwise raise TypeError.
        if doc.binary is not None:
            attachments.append(
                {
                    "type": "file",
                    "file": {
                        "filename": "document.pdf",
                        "file_data": f"data:application/pdf;base64,{base64.b64encode(doc.binary).decode()}",
                    },
                }
            )
        else:
            logger.debug("No binary data available; skipping file attachment")

    messages = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=attachments + [{"type": "text", "text": prompt}]),
    ]
    return messages
117
+
118
+
119
+ def _parse_token_usage(
120
+ response: Any, response_dict: Optional[Dict[str, Any]]
121
+ ) -> Tuple[Optional[int], Optional[int], Dict[str, Any], Any]:
122
+ input_tokens = None
123
+ output_tokens = None
124
+ usage_meta = (
125
+ response_dict.get("usage_metadata") if isinstance(response_dict, dict) else None
126
+ ) or getattr(response, "usage_metadata", None)
127
+ if isinstance(usage_meta, dict):
128
+ input_tokens = usage_meta.get("input_tokens") or usage_meta.get("prompt_tokens")
129
+ output_tokens = usage_meta.get("output_tokens") or usage_meta.get(
130
+ "completion_tokens"
131
+ )
132
+ resp_meta = (
133
+ (
134
+ response_dict.get("response_metadata")
135
+ if isinstance(response_dict, dict)
136
+ else None
137
+ )
138
+ or getattr(response, "response_metadata", None)
139
+ or {}
140
+ )
141
+ if (input_tokens is None or output_tokens is None) and isinstance(resp_meta, dict):
142
+ tu = resp_meta.get("token_usage", {}) or {}
143
+ input_tokens = input_tokens or tu.get("input_tokens") or tu.get("prompt_tokens")
144
+ output_tokens = (
145
+ output_tokens or tu.get("output_tokens") or tu.get("completion_tokens")
146
+ )
147
+ return input_tokens, output_tokens, resp_meta, usage_meta
148
+
149
+
150
+ def _extract_cost_from_metadata(
151
+ response_dict: Optional[Dict[str, Any]],
152
+ resp_meta: Dict[str, Any],
153
+ ) -> Optional[float]:
154
+ """Prefer inline provider-reported cost when available.
155
+
156
+ OpenRouter currently includes cost in response metadata token_usage.cost for
157
+ many requests, which is more reliable than a follow-up generation lookup.
158
+ """
159
+ for source in (resp_meta, response_dict):
160
+ if not isinstance(source, dict):
161
+ continue
162
+ tu = source.get("token_usage")
163
+ if isinstance(tu, dict) and tu.get("cost") is not None:
164
+ return tu["cost"]
165
+ rm = source.get("response_metadata")
166
+ if isinstance(rm, dict):
167
+ tu = rm.get("token_usage")
168
+ if isinstance(tu, dict) and tu.get("cost") is not None:
169
+ return tu["cost"]
170
+ return None
171
+
172
+
173
+ def _clean_response_content(response: Any) -> str:
174
+ content = getattr(response, "content", "")
175
+ if not isinstance(content, str):
176
+ content = str(content)
177
+ # Strip potential markdown code fences
178
+ return content.replace("```json", "").replace("```", "").strip()
179
+
180
+
181
def _fetch_generation_cost(
    config: extractor_types.ExtractionConfig, resp_meta: Dict[str, Any]
) -> Tuple[Optional[float], Optional[Dict[str, Any]]]:
    """Look up a generation's cost via the provider's /generation endpoint.

    Runs only when config.calculate_costs is enabled and the response metadata
    carries a generation id. Entirely best-effort: every failure mode is
    logged and swallowed, so this never raises to the caller.

    Returns:
        (cost, generation_stats): cost is data.total_cost from the stats
        payload when available; generation_stats is the raw JSON payload
        (or None if the lookup was skipped or failed).
    """
    cost = None
    generation_stats = None
    try:
        generation_id = resp_meta.get("id") if isinstance(resp_meta, dict) else None
        if config.calculate_costs and generation_id:
            api_base = get_config("OPENAI_API_BASE")
            api_key = get_config("OPENAI_API_KEY")
            headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
            url = f"{api_base.rstrip('/')}/generation"
            logger.debug(
                f"Cost lookup: id={generation_id} base={api_base} auth={'yes' if api_key else 'no'}"
            )
            try:
                # Imported lazily so `requests` stays an optional dependency.
                import requests as _requests

                # Retry a few times in case the generation record isn't immediately available
                delays = [0.5, 1.0, 2.0, 4.0]
                last_status = None
                last_text = ""
                for attempt, delay in enumerate(delays, start=1):
                    resp = _requests.get(
                        url, params={"id": generation_id}, headers=headers, timeout=10
                    )
                    last_status = resp.status_code
                    last_text = resp.text[:200]
                    if resp.ok:
                        generation_stats = resp.json()
                        try:
                            data = generation_stats.get("data", {})
                            cost = data.get("total_cost")
                            logger.debug(f"Cost lookup success: total_cost={cost}")
                        except Exception:
                            logger.warning(
                                "Generation stats JSON missing 'data.total_cost'."
                            )
                        break
                    # Retry on 404 as record may not be indexed yet
                    if resp.status_code == 404 and attempt < len(delays):
                        logger.debug(
                            f"Generation not found yet (404). Retry {attempt}/{len(delays) - 1} after {delay}s..."
                        )
                        time.sleep(delay)
                        continue
                    # Non-retryable or final attempt
                    logger.warning(
                        f"Generation stats request failed for id={generation_id}: {resp.status_code} {resp.text[:200]}"
                    )
                    break
                else:
                    # NOTE(review): every loop iteration ends in break or continue,
                    # and the final iteration always breaks, so this for-else branch
                    # looks unreachable — confirm before relying on it.
                    logger.warning(
                        f"Generation stats request failed for id={generation_id}: {last_status} {last_text}"
                    )
            except ImportError as ie:
                logger.warning(
                    f"Requests library not installed; skipping generation stats fetch: {ie}"
                )
            except Exception as e:
                logger.warning(f"Error calling generation stats endpoint: {e}")
        elif config.calculate_costs and not generation_id:
            logger.debug(
                "calculate_costs enabled but no generation id found in response metadata."
            )
    except Exception as e:
        logger.warning(f"Failed to process generation stats: {e}")
    return cost, generation_stats
249
+
250
+
251
+ # ---------------------------
252
+ # Public API
253
+ # ---------------------------
254
+
255
+
256
def extract_object(
    doc: extractor_types.Document,
    object_to_extract: extractor_types.ExtractableObjectTypes,
    config: extractor_types.ExtractionConfig,
) -> extractor_types.ExtractionResult:
    """
    Extract specified objects from the document using the provided configuration.

    Retries up to config.max_retries times with capped exponential backoff.
    Metadata (tokens, cost, raw payload) from the most recent attempt is kept
    so a run that ultimately fails still reports whatever was observed.

    Args:
        doc: Document object containing the data to extract from
        object_to_extract: The object (e.g., table, string, etc) to extract
        config: Configuration for the extraction process
    Returns:
        Extracted data as an ExtractionResult object
    """

    logger.debug(
        f"Extracting {object_to_extract.name} {type(object_to_extract)}. Config: {config}"
    )

    messages = _build_messages(doc, object_to_extract, config)
    model = _build_model(config)

    # Retry loop using config.max_retries
    max_retries = max(1, int(config.max_retries or 1))
    last_error_msg: Optional[str] = None
    last_input_tokens: Optional[int] = None
    last_output_tokens: Optional[int] = None
    last_response_raw_payload: Optional[Dict[str, Any]] = None
    last_cost: Optional[float] = None

    for attempt in range(1, max_retries + 1):
        try:
            response = model.invoke(messages)
            # Serialize response to dict for complete metadata (incl. generation id)
            try:
                response_dict = response.dict()
            except Exception:
                response_dict = None

            input_tokens, output_tokens, resp_meta, usage_meta = _parse_token_usage(
                response, response_dict
            )

            # Prefer cost reported inline in the response metadata; only fall
            # back to the follow-up /generation lookup when it is absent.
            inline_cost = _extract_cost_from_metadata(response_dict, resp_meta)
            if inline_cost is not None:
                logger.debug(f"Using inline response cost from metadata: {inline_cost}")
                cost, generation_stats = inline_cost, None
            else:
                cost, generation_stats = _fetch_generation_cost(config, resp_meta)

            content_str = _clean_response_content(response)

            if isinstance(response_dict, dict):
                logger.debug(f"Raw chat response dict: {response_dict}")
                response_raw_payload = dict(response_dict)
            else:
                # Minimal payload when the response could not be serialized.
                response_raw_payload = {
                    "content": getattr(response, "content", ""),
                    "response_metadata": resp_meta,
                    "usage_metadata": usage_meta,
                }
            if generation_stats is not None:
                response_raw_payload["generation_stats"] = generation_stats

            # Save last-attempt metadata in case of JSON parse failure
            last_input_tokens = input_tokens
            last_output_tokens = output_tokens
            last_response_raw_payload = response_raw_payload
            last_cost = cost

            response_json = json.loads(content_str)
            return extractor_types.ExtractionResult(
                extracted_data=response_json,
                response_raw=response_raw_payload,
                success=True,
                message="Extraction successful",
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                cost=cost,
            )

        except json.JSONDecodeError as e:
            # Content preview for debugging. content_str is assigned before
            # json.loads can raise, but guard anyway.
            preview = ""
            try:
                preview = content_str[:200].replace("\n", " ")
            except Exception:
                pass
            last_error_msg = (
                f"Response was not valid JSON: {e}. Content preview: {preview}"
            )
            logger.error(
                f"Failed to parse JSON from model response on attempt {attempt}/{max_retries}: {e}"
            )

        except Exception as e:
            last_error_msg = f"Model invocation failed: {e}"
            logger.error(
                f"Model invocation failed on attempt {attempt}/{max_retries}: {e}"
            )

        # Backoff and retry if attempts remain
        if attempt < max_retries:
            # Exponential backoff: 1s, 2s, 4s, capped at 8s.
            sleep_s = min(2 ** (attempt - 1), 8)
            logger.debug(f"Retrying extraction in {sleep_s}s...")
            time.sleep(sleep_s)

    # All attempts failed
    return extractor_types.ExtractionResult(
        extracted_data=None,
        response_raw=last_response_raw_payload,
        success=False,
        message=last_error_msg or f"Failed after {max_retries} attempts",
        input_tokens=last_input_tokens,
        output_tokens=last_output_tokens,
        cost=last_cost,
    )
373
+
374
+
375
def extract_objects(
    doc: extractor_types.Document,
    objects_to_extract: extractor_types.ObjectsToExtract,
) -> extractor_types.ExtractionResults:
    """
    Extract multiple objects from the document concurrently.
    Args:
        doc: Document object containing the data to extract from
        objects_to_extract: ObjectsToExtract object containing the list of objects and config
    Returns:
        ExtractionResults object containing the results of the extractions
    """

    config = objects_to_extract.config
    # NOTE: config.parallel_requests controls the thread-pool size.
    worker_count = max(1, int(config.parallel_requests or 1))

    results: Dict[str, extractor_types.ExtractionResult] = {}
    with concurrent.futures.ThreadPoolExecutor(max_workers=worker_count) as pool:
        pending = {
            pool.submit(extract_object, doc, target, config): target.name
            for target in objects_to_extract.objects
        }

        for done in concurrent.futures.as_completed(pending):
            name = pending[done]
            try:
                results[name] = done.result()
            except Exception as e:
                # A worker crash becomes a failed per-object result rather
                # than aborting the whole batch.
                logger.error(f"Error extracting {name}: {e}")
                results[name] = extractor_types.ExtractionResult(
                    extracted_data=None,
                    response_raw=None,
                    success=False,
                    message=str(e),
                )

    overall_success = all(entry.success for entry in results.values())

    # Aggregate token counts over the results that reported them; cost totals
    # stay None when no per-object cost was available.
    input_total = sum(
        entry.input_tokens for entry in results.values() if entry.input_tokens is not None
    )
    output_total = sum(
        entry.output_tokens for entry in results.values() if entry.output_tokens is not None
    )
    known_costs = [entry.cost for entry in results.values() if entry.cost is not None]

    return extractor_types.ExtractionResults(
        results=results,
        success=overall_success,
        message=None if overall_success else "Some extractions failed",
        total_input_tokens=input_total,
        total_output_tokens=output_total,
        total_cost=sum(known_costs) if known_costs else None,
    )