groundx-2.4.5-py3-none-any.whl → groundx-2.4.9-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
This version of groundx has been flagged as a potentially problematic release.
- groundx/core/client_wrapper.py +2 -2
- groundx/extract/__init__.py +38 -0
- groundx/extract/agents/__init__.py +7 -0
- groundx/extract/agents/agent.py +202 -0
- groundx/extract/classes/__init__.py +27 -0
- groundx/extract/classes/agent.py +22 -0
- groundx/extract/classes/api.py +15 -0
- groundx/extract/classes/document.py +311 -0
- groundx/extract/classes/field.py +88 -0
- groundx/extract/classes/groundx.py +123 -0
- groundx/extract/classes/post_process.py +33 -0
- groundx/extract/classes/prompt.py +36 -0
- groundx/extract/classes/settings.py +169 -0
- groundx/extract/classes/test_document.py +126 -0
- groundx/extract/classes/test_field.py +43 -0
- groundx/extract/classes/test_groundx.py +188 -0
- groundx/extract/classes/test_prompt.py +68 -0
- groundx/extract/classes/test_settings.py +515 -0
- groundx/extract/classes/test_utility.py +81 -0
- groundx/extract/classes/utility.py +193 -0
- groundx/extract/services/.DS_Store +0 -0
- groundx/extract/services/__init__.py +14 -0
- groundx/extract/services/csv.py +76 -0
- groundx/extract/services/logger.py +127 -0
- groundx/extract/services/logging_cfg.py +55 -0
- groundx/extract/services/ratelimit.py +104 -0
- groundx/extract/services/sheets_client.py +160 -0
- groundx/extract/services/status.py +197 -0
- groundx/extract/services/upload.py +73 -0
- groundx/extract/services/upload_minio.py +122 -0
- groundx/extract/services/upload_s3.py +84 -0
- groundx/extract/services/utility.py +52 -0
- {groundx-2.4.5.dist-info → groundx-2.4.9.dist-info}/METADATA +1 -1
- {groundx-2.4.5.dist-info → groundx-2.4.9.dist-info}/RECORD +36 -5
- {groundx-2.4.5.dist-info → groundx-2.4.9.dist-info}/LICENSE +0 -0
- {groundx-2.4.5.dist-info → groundx-2.4.9.dist-info}/WHEEL +0 -0
groundx/extract/classes/utility.py
@@ -0,0 +1,193 @@
+import typing
+
+if typing.TYPE_CHECKING:
+    from .prompt import Prompt
+
+
+def class_fields(cls: typing.Any) -> typing.Set[str]:
+    fields: typing.Set[str] = set()
+    if hasattr(cls, "model_fields"):
+        fields = set(cls.model_fields.keys())
+    elif hasattr(cls, "__fields__"):
+        fields = set(cls.__fields__.keys())  # type: ignore[reportDeprecated]
+    else:
+        fields = set()
+
+    return fields
+
+
+def clean_json(txt: str) -> str:
+    for p in ("json```\n", "```json\n", "json\n"):
+        if txt.startswith(p):
+            txt = txt[len(p) :]
+    if txt.endswith("```"):
+        txt = txt[:-3]
+    return txt.strip()
+
+
+def coerce_numeric_string(
+    value: typing.Any,
+    et: typing.Union[str, typing.List[str]],
+) -> typing.Union[typing.Any, typing.List[typing.Any]]:
+    expected_types = str_to_type_sequence(et)
+
+    if any(t in (int, float) for t in expected_types):
+        if isinstance(value, str):
+            value = value.replace(",", "")
+            try:
+                value = float(value)
+            except ValueError:
+                return value
+            if float in expected_types:
+                return value
+            return int(value)
+
+    if str in expected_types and isinstance(value, str) and value == "0":
+        return None
+
+    return value
+
+
+def from_attr_name(
+    name: str, prompts: typing.Sequence[typing.Mapping[str, "Prompt"]]
+) -> typing.Tuple[typing.Optional[str], typing.Optional[typing.Any]]:
+    for pmps in prompts:
+        for key, prompt in pmps.items():
+            if getattr(prompt, "attr_name", None) == name:
+                return key, prompt
+
+    return None, None
+
+
+def from_key(
+    name: str,
+    prompts: typing.Sequence[typing.Mapping[str, "Prompt"]],
+) -> typing.Tuple[typing.Optional[str], typing.Optional[typing.Any]]:
+    for pmps in prompts:
+        for k, prompt in pmps.items():
+            if k == name:
+                return k, prompt
+
+    key, pmp = from_attr_name(name, prompts)
+    if pmp:
+        return key, pmp
+
+    return None, None
+
+
+def str_to_type_sequence(
+    ty: typing.Union[str, typing.List[str]],
+) -> typing.Sequence[typing.Type[typing.Any]]:
+    if isinstance(ty, list):
+        tys: typing.List[typing.Any] = []
+        for t in ty:
+            tys.append(str_to_type(t))
+
+        return tys
+
+    return [str_to_type(ty)]
+
+
+def str_to_type(
+    ty: str,
+) -> typing.Type[typing.Any]:
+    if ty == "int":
+        return int
+    elif ty == "float":
+        return float
+    elif ty == "list":
+        return list
+    elif ty == "dict":
+        return dict
+
+    return str
+
+
+def type_to_str(
+    ty: typing.Union[typing.Type[typing.Any], typing.Sequence[typing.Type[typing.Any]]],
+) -> typing.Union[str, typing.List[str]]:
+    if isinstance(ty, list):
+        tys: typing.List[str] = []
+        for t in ty:
+            nt = type_to_str(t)
+            if isinstance(nt, str):
+                tys.append(nt)
+            else:
+                tys.append("list")
+
+    if ty == int:
+        return "int"
+    if ty == float:
+        return "float"
+    if ty == list:
+        return "list"
+    if ty == dict:
+        return "dict"
+
+    return "str"
+
+
+def validate_confidence(
+    key: str,
+    key_data: typing.Any,
+    fields: typing.Set[str],
+    value: typing.Any,
+    errors: typing.Dict[str, str],
+) -> typing.Tuple[
+    typing.Union[typing.Any, typing.List[typing.Any]],
+    typing.Optional[str],
+    typing.Optional[str],
+]:
+    if key_data.attr_name not in fields:
+        return None, None, f"unexpected attribute [{key_data.attr_name}]"
+
+    if value is None:
+        return None, None, None
+
+    if not isinstance(value, dict):
+        return (
+            None,
+            None,
+            f"unexpected value type [{key_data.attr_name}] [{type(value)}] [{key_data.type}]\n[{value}]",
+        )
+
+    if "value" not in value:
+        return (
+            None,
+            None,
+            f'value is missing "value" key [{key_data.attr_name}]\n[{value}]',
+        )
+
+    if value["value"] is None:
+        return None, None, None
+
+    final_value = coerce_numeric_string(value["value"], key_data.type)
+    if not key_data.valid_value(final_value):
+        return (
+            final_value,
+            None,
+            f"unexpected type for statement [{key}] value [{type(final_value)}]\n\n{final_value}",
+        )
+
+    if "confidence" not in value:
+        return (
+            final_value,
+            None,
+            f'value is missing "confidence" key [{key_data.attr_name}]\n[{value}]',
+        )
+
+    if not isinstance(value["confidence"], str):
+        return (
+            final_value,
+            None,
+            f"confidence is not type str [{key_data.attr_name}]\n[{value}]",
+        )
+
+    if value["confidence"] not in ["low", "medium", "high"]:
+        return (
+            final_value,
+            None,
+            f'confidence value is unsupported value [{key_data.attr_name}]\n[{value["confidence"]}]',
+        )
+
+    return final_value, value["confidence"], None
groundx/extract/services/.DS_Store
Binary file
groundx/extract/services/__init__.py
@@ -0,0 +1,14 @@
+from .logger import Logger
+from .sheets_client import SheetsClient
+from .ratelimit import RateLimit
+from .status import Status
+from .upload import Upload
+
+
+__all__ = [
+    "Logger",
+    "RateLimit",
+    "SheetsClient",
+    "Status",
+    "Upload",
+]
groundx/extract/services/csv.py
@@ -0,0 +1,76 @@
+import csv, typing
+from pathlib import Path
+
+
+def append_row(
+    csv_path: Path,
+    headers: typing.List[str],
+    row: typing.Dict[str, str],
+) -> None:
+    with csv_path.open("a", newline="") as f:
+        writer = csv.DictWriter(f, fieldnames=headers)
+        writer.writerow(row)
+
+
+def extraction_row(
+    record: typing.Mapping[str, typing.Any], keys_in_order: typing.Sequence[str]
+) -> typing.List[typing.Any]:
+    return [record.get(k, "") for k in keys_in_order]
+
+
+def find_rows(
+    query: typing.Dict[str, str],
+    csv_path: str,
+) -> typing.List[typing.Dict[str, str]]:
+    with open(csv_path, newline="", encoding="utf-8") as f:
+        reader = csv.DictReader(f)
+
+        rows: typing.List[typing.Dict[str, str]] = []
+        for row in reader:
+            matches: typing.List[str] = []
+            for k, v in query.items():
+                if str(row.get(k)) == str(v):
+                    matches.append(k)
+
+            if len(matches) == len(query):
+                rows.append(row)
+
+    return rows
+
+
+def load_row(
+    key: str,
+    match: typing.List[str],
+    csv_path: typing.Optional[Path] = None,
+    rows: typing.Optional[typing.List[typing.Dict[str, str]]] = None,
+) -> typing.Optional[typing.Dict[str, str]]:
+    if csv_path is None and rows is None:
+        raise Exception("csv_path and rows are None")
+
+    if rows is None and csv_path:
+        rows = load_rows(csv_path)
+
+    if not rows:
+        raise Exception("rows are None")
+
+    return next((r for r in rows if r.get(key) in match), None)
+
+
+def load_rows(csv_path: Path) -> typing.List[typing.Dict[str, str]]:
+    rows: typing.List[typing.Dict[str, str]] = []
+    with csv_path.open("r", newline="") as csvfile:
+        reader = csv.DictReader(csvfile)
+        for row in reader:
+            rows.append(row)
+
+    return rows
+
+
+def save_rows(
+    csv_path: Path, headers: typing.List[str], rows: typing.List[typing.Dict[str, str]]
+) -> None:
+    with csv_path.open("w", newline="") as csvfile:
+        writer = csv.DictWriter(csvfile, fieldnames=headers)
+        writer.writeheader()
+        for r in rows:
+            writer.writerow(r)
groundx/extract/services/logger.py
@@ -0,0 +1,127 @@
+import logging, logging.config, typing
+
+from .logging_cfg import logging_config
+
+
+class Logger:
+    def __init__(
+        self,
+        name: str,
+        level: str,
+    ) -> None:
+        logging.config.dictConfig(logging_config(name, level))
+
+        self.logger = logging.getLogger(name)
+
+    def debug_msg(
+        self,
+        msg: str,
+        name: typing.Optional[str] = None,
+        document_id: typing.Optional[str] = None,
+        task_id: typing.Optional[str] = None,
+    ) -> None:
+        self.print_msg("DEBUG", msg, name, document_id, task_id)
+
+    def error_msg(
+        self,
+        msg: str,
+        name: typing.Optional[str] = None,
+        document_id: typing.Optional[str] = None,
+        task_id: typing.Optional[str] = None,
+    ) -> None:
+        self.print_msg("ERROR", msg, name, document_id, task_id)
+
+    def info_msg(
+        self,
+        msg: str,
+        name: typing.Optional[str] = None,
+        document_id: typing.Optional[str] = None,
+        task_id: typing.Optional[str] = None,
+    ) -> None:
+        self.print_msg("INFO", msg, name, document_id, task_id)
+
+    def report_error(
+        self,
+        api_key: str,
+        callback_url: str,
+        req: typing.Optional[typing.Dict[str, typing.Any]],
+        msg: str,
+    ) -> None:
+        import requests
+
+        self.error_msg(msg)
+
+        if req is None or callback_url == "":
+            return
+
+        requests.post(
+            callback_url,
+            json=req,
+            headers={"X-API-Key": api_key},
+        )
+
+    def report_result(
+        self,
+        api_key: str,
+        callback_url: str,
+        result_url: str,
+        req: typing.Dict[str, typing.Any],
+    ):
+        import requests
+
+        if callback_url == "":
+            return
+
+        self.info_msg("calling back to [%s]" % (callback_url))
+
+        requests.post(
+            callback_url,
+            json=req,
+            headers={"X-API-Key": api_key},
+        )
+
+    def warning_msg(
+        self,
+        msg: str,
+        name: typing.Optional[str] = None,
+        document_id: typing.Optional[str] = None,
+        task_id: typing.Optional[str] = None,
+    ) -> None:
+        self.print_msg("WARNING", msg, name, document_id, task_id)
+
+    def print_msg(
+        self,
+        level: str,
+        msg: str,
+        name: typing.Optional[str] = None,
+        document_id: typing.Optional[str] = None,
+        task_id: typing.Optional[str] = None,
+    ) -> None:
+        prefix = ""
+        if name:
+            if prefix != "":
+                prefix += " "
+            prefix += f"[{name}]"
+        if document_id:
+            if prefix != "":
+                prefix += " "
+            prefix += f"d [{document_id}]"
+        if task_id:
+            if prefix != "":
+                prefix += " "
+            prefix += f"t [{task_id}]"
+
+        text = ""
+        if prefix != "":
+            text += f"{prefix} "
+        text += f"\n\n\t>> {msg}\n"
+
+        lvl = level.upper()
+        if lvl == "ERROR":
+            self.logger.error(text)
+        elif lvl in ("WARN", "WARNING"):
+            self.logger.warning(text)
+        elif lvl == "INFO":
+            self.logger.info(text)
+        else:
+            self.logger.debug(text)
groundx/extract/services/logging_cfg.py
@@ -0,0 +1,55 @@
+import typing
+
+
+def logging_config(name: str, level: str) -> typing.Dict[str, typing.Any]:
+    print(level)
+
+    return {
+        "version": 1,
+        "disable_existing_loggers": False,
+        "formatters": {
+            "default": {
+                "format": "%(asctime)s - [%(process)d] - %(levelname)s - %(message)s",
+            },
+        },
+        "handlers": {
+            "default": {
+                "level": level,
+                "formatter": "default",
+                "class": "logging.StreamHandler",
+                "stream": "ext://sys.stdout",
+            },
+        },
+        "loggers": {
+            "": {
+                "handlers": ["default"],
+                "level": "WARNING",
+            },
+            name: {
+                "handlers": ["default"],
+                "level": level,
+                "propagate": False,
+            },
+            "gunicorn.error": {
+                "level": "INFO",
+                "handlers": ["default"],
+                "propagate": False,
+            },
+            "gunicorn.access": {
+                "level": "WARNING",
+                "handlers": ["default"],
+                "propagate": False,
+            },
+            "uvicorn": {"level": "INFO", "handlers": ["default"], "propagate": False},
+            "uvicorn.error": {
+                "level": "INFO",
+                "handlers": ["default"],
+                "propagate": False,
+            },
+            "uvicorn.access": {
+                "level": "WARNING",
+                "handlers": ["default"],
+                "propagate": False,
+            },
+        },
+    }
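The same configuration dict can also be applied directly; a minimal sketch, assuming the module path matches the file listing above:

import logging, logging.config
from groundx.extract.services.logging_cfg import logging_config

# The root logger stays at WARNING; the named logger and the gunicorn/uvicorn
# loggers get dedicated stdout handlers at the levels defined in the dict above.
logging.config.dictConfig(logging_config("extract", "INFO"))
logging.getLogger("extract").info("configured")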
groundx/extract/services/ratelimit.py
@@ -0,0 +1,104 @@
+import typing
+
+from dataclasses import asdict
+from fastapi import Request, HTTPException
+from starlette.middleware.base import BaseHTTPMiddleware
+from starlette.responses import JSONResponse, Response
+from starlette.types import ASGIApp
+
+from ..classes.api import ProcessResponse
+from ..classes.settings import ContainerSettings
+from .logger import Logger
+from .status import Status
+from .utility import get_gunicorn_threads, get_thread_id, get_worker_id
+
+
+class RateLimit(BaseHTTPMiddleware):
+    def __init__(
+        self,
+        app: ASGIApp,
+        settings: ContainerSettings,
+        logger: Logger,
+    ) -> None:
+        super().__init__(app)
+
+        self.worker_id = get_worker_id()
+        num_threads = get_gunicorn_threads()
+        if num_threads > 1:
+            num_threads = num_threads - 1
+
+        self.status = Status(
+            settings,
+            logger,
+        )
+
+        self.settings = settings
+        self.logger = logger
+
+        self.thread_ids: typing.Dict[str, typing.Any] = {}
+
+        self.status.set_worker_available(self.worker_id)
+
+        self.logger.info_msg(
+            f"[{self.settings.service}] ratelimit init [{num_threads}]"
+        )
+
+    async def dispatch(
+        self,
+        request: Request,
+        call_next: typing.Callable[[Request], typing.Awaitable[Response]],
+    ) -> Response:
+        thread_id, self.thread_ids = get_thread_id(self.thread_ids)
+        wasSet = False
+
+        try:
+            if request.url.path == "/health":
+                response = await call_next(request)
+
+                self.status.refresh_worker(self.worker_id)
+
+                available, total = self.status.get_worker_state(self.worker_id)
+
+                response = self.status.set_headers(
+                    response, self.worker_id, available, total
+                )
+
+                return response
+
+            api_key = request.headers.get("X-API-Key") or request.headers.get(
+                "Authorization"
+            )
+            if api_key and api_key.startswith("Bearer "):
+                api_key = api_key.split("Bearer ")[1]
+            if not api_key or api_key not in self.settings.get_valid_api_keys():
+                raise HTTPException(status_code=403, detail="Invalid API key")
+
+            request.state.api_key = api_key
+
+            wasSet = True
+            self.status.set_worker_unavailable(self.worker_id)
+
+            response = await call_next(request)
+
+            wasSet = False
+            self.status.set_worker_available(self.worker_id)
+
+            available, total = self.status.get_service_state()
+
+            response.headers.update(
+                {
+                    "X-RateLimit-Limit-Requests": str(total),
+                    "X-RateLimit-Remaining-Requests": str(max(0, available)),
+                    "X-Worker-ID": f"{self.worker_id}:{thread_id}",
+                }
+            )
+
+            return response
+        except HTTPException as exc:
+            if wasSet:
+                self.status.set_worker_available(self.worker_id)
+
+            return JSONResponse(
+                status_code=exc.status_code,
+                content=asdict(ProcessResponse(message=exc.detail)),
+            )