groundx 2.4.5__py3-none-any.whl → 2.4.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of groundx might be problematic. Click here for more details.
- groundx/core/client_wrapper.py +2 -2
- groundx/extract/__init__.py +38 -0
- groundx/extract/agents/__init__.py +7 -0
- groundx/extract/agents/agent.py +202 -0
- groundx/extract/classes/__init__.py +27 -0
- groundx/extract/classes/agent.py +22 -0
- groundx/extract/classes/api.py +15 -0
- groundx/extract/classes/document.py +311 -0
- groundx/extract/classes/field.py +88 -0
- groundx/extract/classes/groundx.py +123 -0
- groundx/extract/classes/post_process.py +33 -0
- groundx/extract/classes/prompt.py +36 -0
- groundx/extract/classes/settings.py +169 -0
- groundx/extract/classes/test_document.py +126 -0
- groundx/extract/classes/test_field.py +43 -0
- groundx/extract/classes/test_groundx.py +188 -0
- groundx/extract/classes/test_prompt.py +68 -0
- groundx/extract/classes/test_settings.py +515 -0
- groundx/extract/classes/test_utility.py +81 -0
- groundx/extract/classes/utility.py +193 -0
- groundx/extract/services/.DS_Store +0 -0
- groundx/extract/services/__init__.py +14 -0
- groundx/extract/services/csv.py +76 -0
- groundx/extract/services/logger.py +127 -0
- groundx/extract/services/logging_cfg.py +55 -0
- groundx/extract/services/ratelimit.py +104 -0
- groundx/extract/services/sheets_client.py +160 -0
- groundx/extract/services/status.py +197 -0
- groundx/extract/services/upload.py +73 -0
- groundx/extract/services/upload_minio.py +122 -0
- groundx/extract/services/upload_s3.py +84 -0
- groundx/extract/services/utility.py +52 -0
- {groundx-2.4.5.dist-info → groundx-2.4.10.dist-info}/METADATA +17 -1
- {groundx-2.4.5.dist-info → groundx-2.4.10.dist-info}/RECORD +36 -5
- {groundx-2.4.5.dist-info → groundx-2.4.10.dist-info}/LICENSE +0 -0
- {groundx-2.4.5.dist-info → groundx-2.4.10.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
import dateparser, typing
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ExtractedField(BaseModel):
    """A single extracted key/value pair, plus any conflicting values seen.

    Values are normalised on assignment: ints become floats and, for keys
    containing "date", parseable date strings are canonicalised to YYYY-MM-DD.
    """

    confidence: typing.Optional[str] = None
    conflicts: typing.List[typing.Any] = []
    key: str

    value: typing.Union[str, float, typing.List[typing.Any]] = ""

    def __init__(
        self,
        value: typing.Union[str, float, typing.List[typing.Any]],
        **data: typing.Any,
    ):
        super().__init__(**data)

        # Route the initial value through set_value so normalisation applies.
        self.set_value(value)

    def contains(self, other: "ExtractedField") -> bool:
        """Return True when *other*'s value equals this field's value or is
        already recorded among its conflicts."""
        mine = self.get_value()
        theirs = other.get_value()
        if not isinstance(mine, (str, float, int)):
            raise Exception(f"unexpected self field value type [{type(mine)}]")

        return self.equal_to_value(theirs) or theirs in self.conflicts

    def equal_to_field(self, other: "ExtractedField") -> bool:
        """Return True when *other* holds a value equal to this field's."""
        mine = self.get_value()
        if not isinstance(mine, (str, float, int)):
            raise Exception(f"unexpected self field value type [{type(mine)}]")

        return self.equal_to_value(other.get_value())

    def equal_to_value(self, other: typing.Any) -> bool:
        """Compare *other* against the stored value.

        Ints are widened to float and strings compared case-insensitively;
        after normalisation the types must still match exactly.
        """
        if not isinstance(other, (str, float, int)):
            raise Exception(f"unexpected value type [{type(other)}]")

        mine = self.get_value()
        if isinstance(mine, int):
            mine = float(mine)
        if isinstance(other, int):
            other = float(other)
        if isinstance(mine, str):
            mine = mine.lower()
        if isinstance(other, str):
            other = other.lower()

        return type(other) == type(mine) and other == mine

    def get_value(self) -> typing.Union[str, float, typing.List[typing.Any]]:
        """Return the stored (already normalised) value."""
        return self.value

    def remove_conflict(self, value: typing.Any) -> None:
        """Drop *value* from the conflicts; when the current value differs
        from it, record the current value as a conflict instead."""
        if value not in self.conflicts:
            return
        self.conflicts.remove(value)
        if not self.equal_to_value(value):
            self.conflicts.append(self.get_value())

    def set_value(
        self, value: typing.Union[str, float, typing.List[typing.Any]]
    ) -> None:
        """Assign *value*, coercing ints to float and normalising date-like
        strings (keys containing "date") to YYYY-MM-DD when parseable."""
        if isinstance(value, int):
            self.value = float(value)
            return

        if isinstance(value, str) and "date" in self.key.lower():
            try:
                parsed = dateparser.parse(value)
                # Unparseable dates are stored verbatim.
                self.value = value if parsed is None else parsed.strftime("%Y-%m-%d")
            except Exception as e:
                print(f"date error [{value}]: [{e}]")
                self.value = value
            return

        self.value = value


ExtractedField.model_rebuild()
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
import json, requests, typing
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
4
|
+
from pydantic import BaseModel, ConfigDict, Field
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class GroundXDocument(BaseModel):
    """Reference to a processed GroundX document, addressed by task ID and
    document ID."""

    model_config = ConfigDict(populate_by_name=True)
    base_url: str
    document_id: str = Field(alias="documentID")
    task_id: str = Field(alias="taskID")

    def xray_url(self, base: typing.Optional[str] = None) -> str:
        """Build the URL of this document's X-ray JSON.

        Uses *base* when given (falling back to ``base_url``); one trailing
        slash is stripped before the path is appended.
        """
        root = base or self.base_url
        if root.endswith("/"):
            root = root[:-1]
        return f"{root}/layout/processed/{self.task_id}/{self.document_id}-xray.json"

    def xray(
        self,
        clear_cache: bool = False,
        is_test: bool = False,
        base: typing.Optional[str] = None,
    ) -> "XRayDocument":
        """Download (or load from the local cache) the parsed X-ray document."""
        return XRayDocument.download(
            self, base=base, clear_cache=clear_cache, is_test=is_test
        )
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class GroundXResponse(BaseModel):
    """Response payload returned by the GroundX processing API.

    Field aliases mirror the camelCase keys of the wire format.
    """

    code: int  # numeric status code from the service
    document_id: str = Field(alias="documentID")
    model_id: int = Field(alias="modelID")
    processor_id: int = Field(alias="processorID")
    result_url: str = Field(alias="resultURL")  # where the result can be fetched
    task_id: str = Field(alias="taskID")
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class BoundingBox(BaseModel):
    """Axis-aligned box locating a chunk region on a page.

    Field names follow the X-ray payload's camelCase keys.
    """

    bottomRightX: float
    bottomRightY: float
    topLeftX: float
    topLeftY: float
    # In pydantic v2 an Optional annotation WITHOUT a default is still a
    # required field; default to None so payloads may omit these keys
    # (consistent with Chunk's optional fields in this module).
    corrected: typing.Optional[bool] = None
    pageNumber: typing.Optional[int] = None
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class Chunk(BaseModel):
    """One content chunk of an X-ray document (text plus layout metadata)."""

    boundingBoxes: typing.Optional[typing.List[BoundingBox]] = []
    chunk: typing.Optional[str] = None
    contentType: typing.Optional[typing.List[str]] = []
    # The wire key "json" would shadow the json module; stored as json_ via alias.
    json_: typing.Optional[typing.List[typing.Any]] = Field(None, alias="json")
    multimodalUrl: typing.Optional[str] = None
    narrative: typing.Optional[typing.List[str]] = None
    pageNumbers: typing.Optional[typing.List[int]] = []
    sectionSummary: typing.Optional[str] = None
    suggestedText: typing.Optional[str] = None
    text: typing.Optional[str] = None
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class DocumentPage(BaseModel):
    """A single rendered page of an X-ray document."""

    chunks: typing.List[Chunk]
    height: float  # page height (units per the X-ray payload)
    pageNumber: int
    pageUrl: str  # URL of the rendered page image
    width: float
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class XRayDocument(BaseModel):
    """Parsed X-ray payload for a document, with a download-and-cache helper."""

    chunks: typing.List[Chunk]
    documentPages: typing.List[DocumentPage] = []
    sourceUrl: str
    fileKeywords: typing.Optional[str] = None
    fileName: typing.Optional[str] = None
    fileType: typing.Optional[str] = None
    fileSummary: typing.Optional[str] = None
    language: typing.Optional[str] = None

    @classmethod
    def download(
        cls,
        gx_doc: GroundXDocument,
        clear_cache: bool = False,
        is_test: bool = False,
        base: typing.Optional[str] = None,
    ) -> "XRayDocument":
        """Fetch the X-ray JSON for *gx_doc*, using a local file cache.

        Cached copies live under ``<package>/cache/<documentID>-xray.json``.
        ``clear_cache=True`` forces a re-download; ``is_test=True`` skips
        writing the cache. Raises ``RuntimeError`` on cache-read, fetch, or
        JSON-decode failure.
        """
        cache_dir = Path(__file__).resolve().parent.parent / "cache"
        cache_dir.mkdir(parents=True, exist_ok=True)
        cache_file = cache_dir / f"{gx_doc.document_id}-xray.json"

        if not clear_cache and cache_file.exists():
            try:
                with cache_file.open("r", encoding="utf-8") as f:
                    payload = json.load(f)

            except Exception as e:
                raise RuntimeError(
                    f"Error loading cached X-ray JSON from {cache_file}: {e}"
                )
        else:
            url = gx_doc.xray_url(base=base)
            try:
                # requests has no default timeout; without one a stalled
                # server hangs the caller forever.
                resp = requests.get(url, timeout=60)
                resp.raise_for_status()
            except requests.RequestException as e:
                raise RuntimeError(f"Error fetching X-ray JSON from {url}: {e}")

            try:
                payload = resp.json()
            except ValueError as e:
                raise RuntimeError(f"Invalid JSON returned from {url}: {e}")

            if not is_test:
                # Best-effort cache write: failure here must not fail the
                # download itself.
                try:
                    with cache_file.open("w", encoding="utf-8") as f:
                        json.dump(payload, f)
                except Exception as e:
                    print(
                        f"Warning: failed to write X-ray JSON cache to {cache_file}: {e}"
                    )

        return cls(**payload)
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
import typing
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def check_map(
    fty: str,
    sty: str,
    val: str,
    mp: typing.Dict[str, typing.Dict[str, str]],
    should_warn: bool = True,
) -> typing.Optional[str]:
    """Look up *val* (lowercased, stripped) in the sub-map for *sty*.

    Falls back to the "" sub-map when *sty* is absent. Returns the mapped
    string, or None when no sub-map or entry exists (optionally printing a
    warning tagged with *fty*).
    """
    key = sty if sty in mp else ""
    if key not in mp:
        return None

    needle = val.lower().strip()

    submap = mp[key]
    if needle in submap:
        return submap[needle]

    if should_warn:
        print(f"[arcadia-v1] {fty} not found [{key}] [{needle}]")
    return None
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def check_valid(sty: str, val: str, valid: typing.Dict[str, typing.List[str]]) -> bool:
    """Return True when *val* (lowercased, stripped) appears in the valid
    list for *sty*, falling back to the "" bucket for unknown *sty*."""
    needle = val.lower().strip()
    bucket = sty if sty in valid else ""
    return bucket in valid and needle in valid[bucket]
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
import typing

from pydantic import BaseModel, ConfigDict

from .utility import str_to_type_sequence
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Prompt(BaseModel):
    """An extraction prompt bound to a target attribute and expected type(s)."""

    # ConfigDict form matches the package's other models (e.g. GroundXDocument
    # uses model_config = ConfigDict(populate_by_name=True)) and avoids the
    # class-based Config, which pydantic v2 deprecates.
    model_config = ConfigDict(populate_by_name=True)

    attr_name: str
    prompt: str
    type: typing.Union[str, typing.List[str]]

    def valid_value(self, value: typing.Any) -> bool:
        """Return True when *value*'s runtime type matches this prompt's
        declared type (or any of them, for a list of type names).

        "int" and "float" each accept both numeric types.
        """
        ty = self.type

        types: typing.List[typing.Type[typing.Any]] = []
        if isinstance(ty, list):
            for t in ty:
                if t in ("int", "float"):
                    types.extend([int, float])
                elif t == "str":
                    types.append(str)

            return isinstance(value, tuple(types))

        for et in str_to_type_sequence(ty):
            if et in (int, float):
                types.extend([int, float])
            else:
                types.append(et)
        # De-duplicate while preserving order.
        types = list(dict.fromkeys(types))
        return isinstance(value, tuple(types))
|
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
import json, typing, os
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
# Environment-variable names read by the settings classes below.

# AWS region fallbacks (used when no GroundX-specific region is set).
AWS_REGION: str = "AWS_REGION"
AWS_DEFAULT_REGION: str = "AWS_DEFAULT_REGION"


# GroundX credentials and configuration.
GX_AGENT_KEY: str = "GROUNDX_AGENT_API_KEY"
CALLBACK_KEY: str = "GROUNDX_CALLBACK_API_KEY"
GCP_CREDENTIALS: str = "GCP_CREDENTIALS"
GX_API_KEY: str = "GROUNDX_API_KEY"
GX_KEY: str = "GROUNDX_ACCESS_KEY_ID"
GX_REGION: str = "GROUNDX_REGION"
GX_DEFAULT_REGION: str = "GROUNDX_DEFAULT_REGION"
GX_SECRET: str = "GROUNDX_SECRET_ACCESS_KEY"
GX_TOKEN: str = "GROUNDX_SESSION_TOKEN"
VALID_KEYS: str = "GROUNDX_VALID_API_KEYS"  # JSON array of accepted API keys
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class AgentSettings(BaseModel):
    """Configuration for the extraction agent (model, allowed imports, API
    access)."""

    api_base: typing.Optional[str] = None
    api_key: typing.Optional[str] = None
    # Module names the agent's generated code is allowed to import.
    imports: typing.List[str] = [
        "csv",
        "glob",
        "io",
        "json",
        "markdown",
        "numpy",
        "os",
        "pandas",
        "posixpath",
        "open",
        "builtins.open",
        "utils.safe_open",
        "pydantic",
        "typing",
    ]
    max_steps: int = 7
    model_id: str = "gpt-5-mini"

    def get_api_key(self) -> str:
        """Return the agent API key, falling back to the
        GROUNDX_AGENT_API_KEY environment variable; raise when neither is set."""
        if self.api_key:
            return self.api_key

        env_key = os.environ.get(GX_AGENT_KEY)
        if env_key:
            return env_key

        raise Exception(f"you must set a valid agent api_key")
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class ContainerSettings(BaseModel):
    """Runtime configuration for a worker container: broker, timeouts,
    upload target, and API keys."""

    broker: str
    cache_to: int = 300
    google_sheets_drive_id: typing.Optional[str] = None
    google_sheets_template_id: typing.Optional[str] = None
    log_level: str = "info"
    metrics_broker: typing.Optional[str] = None
    refresh_to: int = 60
    service: str
    task_to: int = 600
    upload: "ContainerUploadSettings"
    workers: int

    callback_api_key: typing.Optional[str] = None
    valid_api_keys: typing.Optional[typing.List[str]] = None

    def get_callback_api_key(self) -> str:
        """Return the callback API key, falling back to the
        GROUNDX_CALLBACK_API_KEY environment variable; raise when unset."""
        if self.callback_api_key:
            return self.callback_api_key

        env_key = os.environ.get(CALLBACK_KEY)
        if env_key:
            return env_key

        raise Exception(f"you must set a callback_api_key")

    def get_valid_api_keys(self) -> typing.List[str]:
        """Return the accepted API keys, from config or the
        GROUNDX_VALID_API_KEYS env var (a JSON array); raise when missing or
        unparseable."""
        if self.valid_api_keys:
            return self.valid_api_keys

        raw: typing.Optional[str] = os.environ.get(VALID_KEYS)
        if not raw:
            raise Exception(f"you must set an array of valid_api_keys")

        try:
            parsed: typing.List[str] = json.loads(raw)
        except Exception as e:
            raise Exception(f"you must set an array of valid_api_keys: {e}")

        return parsed

    def loglevel(self) -> str:
        """Log level, upper-cased for the logging module."""
        return self.log_level.upper()

    def status_broker(self) -> str:
        """Broker used for status/metrics; defaults to the main broker."""
        return self.metrics_broker if self.metrics_broker else self.broker
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class ContainerUploadSettings(BaseModel):
    """Destination settings for uploading processed artifacts."""

    base_domain: str
    base_path: str = "layout/processed/"
    bucket: str
    ssl: bool = False
    type: str
    url: str

    key: typing.Optional[str] = None
    region: typing.Optional[str] = None
    secret: typing.Optional[str] = None
    token: typing.Optional[str] = None

    def get_key(self) -> typing.Optional[str]:
        """Access key, falling back to GROUNDX_ACCESS_KEY_ID."""
        return self.key if self.key else os.environ.get(GX_KEY)

    def get_region(self) -> typing.Optional[str]:
        """Region, resolved in order: explicit setting, GROUNDX_REGION,
        AWS_REGION, GROUNDX_DEFAULT_REGION, then AWS_DEFAULT_REGION."""
        if self.region:
            return self.region

        for env_name in (GX_REGION, AWS_REGION, GX_DEFAULT_REGION):
            found = os.environ.get(env_name)
            if found:
                return found

        return os.environ.get(AWS_DEFAULT_REGION)

    def get_secret(self) -> typing.Optional[str]:
        """Secret key, falling back to GROUNDX_SECRET_ACCESS_KEY."""
        return self.secret if self.secret else os.environ.get(GX_SECRET)

    def get_token(self) -> typing.Optional[str]:
        """Session token, falling back to GROUNDX_SESSION_TOKEN."""
        return self.token if self.token else os.environ.get(GX_TOKEN)
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
class GroundXSettings(BaseModel):
    """Connection settings for the GroundX API."""

    api_key: typing.Optional[str] = None
    base_url: typing.Optional[str] = None
    upload_url: str = "https://upload.eyelevel.ai"

    def get_api_key(self) -> str:
        """API key from config or the GROUNDX_API_KEY environment variable;
        raise when neither is set."""
        if self.api_key:
            return self.api_key

        env_key = os.environ.get(GX_API_KEY)
        if env_key:
            return env_key

        raise Exception(f"you must set a valid GroundX api_key")
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
import pytest, typing, unittest
|
|
2
|
+
|
|
3
|
+
pytest.importorskip("PIL")
|
|
4
|
+
|
|
5
|
+
from io import BytesIO
|
|
6
|
+
from PIL import Image
|
|
7
|
+
from unittest.mock import patch
|
|
8
|
+
|
|
9
|
+
from .document import Document, DocumentRequest
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def DR(**data: typing.Any) -> DocumentRequest:
    """Shorthand fixture builder: validate keyword data into a DocumentRequest."""
    return DocumentRequest.model_validate(data)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def test_request() -> DocumentRequest:
    """Build a minimal DocumentRequest fixture.

    NOTE(review): the test_ prefix means pytest also collects this as a test
    function (it passes but returns a value) — consider renaming.
    """
    return DR(documentID="D", fileName="F", modelID=1, processorID=1, taskID="T")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class DummyChunk:
    """Minimal stand-in for an X-ray chunk used by the document tests."""

    def __init__(self, json_str: str):
        self.sectionSummary = None
        self.suggestedText = json_str
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class DummyDocumentPage:
    """Minimal stand-in for an X-ray document page (URL only)."""

    def __init__(self, page_url: str):
        self.pageUrl = page_url
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class DummyXRay:
    """Minimal stand-in for an XRayDocument used by the document tests."""

    def __init__(
        self,
        source_url: str,
        # Mutable defaults ([]) are evaluated once and shared across calls;
        # use None and build a fresh list per instance instead. Annotations
        # are quoted so class definition order doesn't matter.
        chunks: typing.Optional[typing.List["DummyChunk"]] = None,
        document_pages: typing.Optional[typing.List[str]] = None,
    ):
        self.chunks = [] if chunks is None else chunks
        self.documentPages: typing.List["DummyDocumentPage"] = []
        if document_pages is not None:
            for p in document_pages:
                self.documentPages.append(DummyDocumentPage(p))
        self.sourceUrl = source_url
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class TestDocument(unittest.TestCase):
    """Document.from_request behaviour, with X-ray downloads mocked out."""

    def setUp(self) -> None:
        # Patch GroundXDocument.xray so no network fetch happens.
        patcher = patch(
            "groundx.extract.classes.document.GroundXDocument.xray", autospec=True
        )
        self.mock_xray = patcher.start()
        self.addCleanup(patcher.stop)
        self.mock_xray.return_value = DummyXRay("http://test.co", [])

    def test_init_name(self) -> None:
        # File names are preserved verbatim, with or without an extension.
        doc_plain: Document = Document.from_request(
            base_url="",
            req=test_request(),
        )
        self.assertEqual(doc_plain.file_name, "F")

        doc_ext: Document = Document.from_request(
            base_url="",
            req=DR(
                documentID="D", fileName="F.pdf", modelID=1, processorID=1, taskID="T"
            ),
        )
        self.assertEqual(doc_ext.file_name, "F.pdf")

        doc_dot: Document = Document.from_request(
            base_url="",
            req=DR(documentID="D", fileName="F.", modelID=1, processorID=1, taskID="T"),
        )
        self.assertEqual(doc_dot.file_name, "F.")
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class TestDocumentRequest(unittest.TestCase):
    """DocumentRequest.load_images: cache hits, downloads, and failures."""

    def test_load_images_cached(self) -> None:
        # Pre-populated images should be reused without any download.
        urls: typing.List[str] = [
            "http://example.com/page1.png",
            "http://example.com/page2.png",
        ]

        cached = Image.new("RGB", (10, 10), color="red")
        scratch = BytesIO()
        cached.save(scratch, format="PNG")

        req = test_request()
        req.page_images = [cached, cached]
        req.page_image_dict = {urls[0]: 0, urls[1]: 1}
        req.load_images(urls)
        self.assertEqual(len(req.page_images), 2)
        self.assertEqual(len(req.page_image_dict), 2)

    def test_load_images_download(self) -> None:
        # With requests.get patched, both pages decode into 10x10 images.
        urls = ["http://example.com/page1.png", "http://example.com/page2.png"]

        source = Image.new("RGB", (10, 10), color="red")
        scratch = BytesIO()
        source.save(scratch, format="PNG")
        png_bytes = scratch.getvalue()

        class DummyResp:
            content = png_bytes

            def raise_for_status(self) -> None:
                pass

        with patch("requests.get", return_value=DummyResp()):
            req = test_request()
            req.load_images(urls)

            self.assertEqual(len(req.page_images), 2)
            self.assertEqual(len(req.page_image_dict), 2)
            for img in req.page_images:
                self.assertIsInstance(img, Image.Image)
                self.assertEqual(img.size, (10, 10))

    def test_load_images_error(self) -> None:
        # Unpatched (failing) downloads should leave the request empty.
        urls = ["http://example.com/page1.png", "http://example.com/page2.png"]

        req = test_request()
        req.load_images(urls, should_sleep=False)
        self.assertEqual(len(req.page_images), 0)
        self.assertEqual(len(req.page_image_dict), 0)
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
import pytest, typing, unittest
|
|
2
|
+
|
|
3
|
+
pytest.importorskip("dateparser")
|
|
4
|
+
|
|
5
|
+
from .field import ExtractedField
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def TestField(
    name: str,
    value: typing.Union[str, float, typing.List[typing.Any]],
    conflicts: typing.Optional[typing.List[typing.Any]] = None,
) -> ExtractedField:
    """Build an ExtractedField fixture, turning underscores in *name* into
    spaces for the key.

    Uses a None default instead of the mutable ``[]`` so the conflict list
    is never shared between calls (ExtractedField mutates its conflicts).
    """
    return ExtractedField(
        key=name.replace("_", " "),
        value=value,
        conflicts=[] if conflicts is None else conflicts,
    )
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class TestExtractedField(unittest.TestCase):
    """Equality semantics and date normalisation of ExtractedField."""

    def test_equalToValue_string(self):
        field = TestField("test", "hello")
        self.assertTrue(field.equal_to_value("hello"))
        self.assertFalse(field.equal_to_value("world"))

    def test_equalToValue_int_float_equivalence(self):
        # Ints are widened to float, so 10 and 10.0 compare equal.
        field = TestField("test", int(10))
        self.assertTrue(field.equal_to_value(10.0))
        self.assertTrue(field.equal_to_value(10))

    def test_equalToValue_mismatch(self):
        field = TestField("test", 3.14)
        self.assertFalse(field.equal_to_value(2.71))

    def test_set_value_dates(self):
        # Date-like keys normalise parseable values to YYYY-MM-DD.
        slash_date = TestField("test date", "3/29/25")
        self.assertEqual(slash_date.get_value(), "2025-03-29")
        iso_date = TestField("test date", "2025-03-29")
        self.assertEqual(iso_date.get_value(), "2025-03-29")
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# Allow running this test module directly (outside pytest).
if __name__ == "__main__":
    unittest.main()
|