csvai-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- csvai/__init__.py +8 -0
- csvai/__main__.py +7 -0
- csvai/cli.py +57 -0
- csvai/io_utils.py +125 -0
- csvai/launch.py +13 -0
- csvai/processor.py +420 -0
- csvai/settings.py +42 -0
- csvai/ui.py +216 -0
- csvai-0.1.0.dist-info/METADATA +751 -0
- csvai-0.1.0.dist-info/RECORD +14 -0
- csvai-0.1.0.dist-info/WHEEL +5 -0
- csvai-0.1.0.dist-info/entry_points.txt +3 -0
- csvai-0.1.0.dist-info/licenses/LICENSE +339 -0
- csvai-0.1.0.dist-info/top_level.txt +1 -0
csvai/__init__.py
ADDED
csvai/__main__.py
ADDED
csvai/cli.py
ADDED
@@ -0,0 +1,57 @@
"""Command-line interface for CSVAI."""

import argparse
import asyncio
import logging
import signal

from .processor import CSVAIProcessor, ProcessorConfig
from .settings import Settings


def main() -> None:
    """Run the CSVAI processor via the command line."""
    settings = Settings()

    parser = argparse.ArgumentParser(
        description="Enrich CSV/Excel rows.",
    )
    parser.add_argument("input", help="Input CSV or Excel file path")
    parser.add_argument("--prompt", "-p", help="Prompt text file path")
    parser.add_argument("--output", "-o", help="Output CSV or Excel file path")
    parser.add_argument(
        "--schema",
        help="JSON schema file path (strict). If omitted, uses json_object.",
    )
    parser.add_argument("--limit", type=int, help="Limit number of new rows to process")
    parser.add_argument("--model", default=settings.default_model, help="Model to use")
    args = parser.parse_args()

    config = ProcessorConfig(
        input=args.input,
        prompt=args.prompt,
        output=args.output,
        schema=args.schema,
        limit=args.limit,
        model=args.model,
    )
    processor = CSVAIProcessor(config, settings=settings)

    def _handle_signal(sig, frame):
        logging.warning("Signal %s received: shutting down after current batch", sig)
        processor.stop()

    signal.signal(signal.SIGINT, _handle_signal)
    signal.signal(signal.SIGTERM, _handle_signal)

    try:
        asyncio.run(processor.run())
    except SystemExit as e:
        raise e
    except Exception as e:
        logging.error(f"Fatal error: {e}")
        raise SystemExit(1)


if __name__ == "__main__":
    main()
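For orientation, the same pipeline can be driven without argparse by building ProcessorConfig directly. A minimal sketch, assuming hypothetical files contacts.csv and contacts.prompt.txt in the working directory:

import asyncio

from csvai.processor import CSVAIProcessor, ProcessorConfig
from csvai.settings import Settings

# contacts.csv / contacts.prompt.txt are hypothetical example files.
# If prompt is omitted, choose_prompt_file() falls back to auto-discovery.
config = ProcessorConfig(
    input="contacts.csv",
    prompt="contacts.prompt.txt",
    limit=10,  # enrich only the first 10 not-yet-processed rows
)
asyncio.run(CSVAIProcessor(config, settings=Settings()).run())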
csvai/io_utils.py
ADDED
@@ -0,0 +1,125 @@
"""Utilities for reading and writing CSV/Excel files."""

import csv
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple

import os
import pandas as pd
import logging
from openpyxl import Workbook, load_workbook

# Configuration for file handling
DEFAULT_PROMPT_FILENAME = os.getenv("DEFAULT_PROMPT_FILENAME", "prompt.txt")
ALT_PROMPT_SUFFIX = os.getenv("ALT_PROMPT_SUFFIX", ".prompt.txt")
OUTPUT_FILE_SUFFIX = os.getenv("OUTPUT_FILE_SUFFIX", "_enriched.csv")

# ---------------------------------------------------------------------------
# Generic table loaders/writers
# ---------------------------------------------------------------------------

FILE_READERS = {
    ".csv": lambda p: pd.read_csv(p, dtype=str),
    ".xlsx": lambda p: pd.read_excel(p, dtype=str),
    ".xls": lambda p: pd.read_excel(p, dtype=str),
}

def _load_table(path: Path) -> pd.DataFrame:
    """Load a CSV/Excel file into a DataFrame with empty strings for NaN."""
    reader = FILE_READERS.get(path.suffix.lower())
    if not reader:
        raise ValueError(f"Unsupported file type: {path.suffix}")
    df = reader(path)
    return df.fillna("")


class RowWriter:
    """Efficiently append rows to CSV or Excel output files."""

    def __init__(self, path: Path, header: List[str]):
        self.path = path
        self.header = header
        self.is_excel = path.suffix.lower() in {".xlsx", ".xls"}
        self._open()

    def _open(self) -> None:
        exists = self.path.exists() and self.path.stat().st_size > 0
        if self.is_excel:
            self.wb = load_workbook(self.path) if exists else Workbook()
            self.ws = self.wb.active
            if not exists:
                self.ws.append(self.header)
        else:
            mode = "a" if exists else "w"
            self.f = open(self.path, mode, encoding="utf-8", newline="")
            self.writer = csv.DictWriter(self.f, fieldnames=self.header, extrasaction="ignore")
            if not exists:
                self.writer.writeheader()

    def append(self, rows: Iterable[Dict[str, Any]]) -> None:
        if self.is_excel:
            for row in rows:
                self.ws.append([row.get(k, "") for k in self.header])
        else:
            self.writer.writerows({k: r.get(k, "") for k in self.header} for r in rows)

    def close(self) -> None:
        if self.is_excel:
            self.wb.save(self.path)
        else:
            self.f.close()

def read_rows(filename: str) -> List[Dict[str, str]]:
    """Read rows from a CSV or Excel file as a list of dicts."""
    df = _load_table(Path(filename))
    return df.to_dict(orient="records")

def read_prompt(filename: str) -> str:
    """Read the prompt template from a text file."""
    with open(filename, "r", encoding="utf-8") as f:
        return f.read()

def choose_prompt_file(input_path: Path, user_prompt: Optional[str]) -> Path:
    """Return the prompt file to use, with auto-discovery if none provided."""
    path: Optional[Path] = None
    if user_prompt:
        path = Path(user_prompt).expanduser().resolve()
        if not path.exists():
            raise FileNotFoundError(f"Prompt file not found: {path}")
        return path

    c1 = input_path.with_suffix(ALT_PROMPT_SUFFIX)
    if c1.exists():
        logging.info("Auto-discovered prompt file: %s", c1)
        return c1

    c2 = Path(DEFAULT_PROMPT_FILENAME)
    if c2.exists():
        logging.info("Auto-discovered prompt file: %s", c2)
        return c2

    raise FileNotFoundError(
        f"No prompt file supplied and neither '{c1.name}' nor '{c2.name}' exist."
    )

def default_output_file(input_path: Path, user_output: Optional[str]) -> Path:
    """Determine the default output file path."""
    if user_output:
        return Path(user_output)
    suffix = OUTPUT_FILE_SUFFIX
    if input_path.suffix.lower() in {".xlsx", ".xls"} and suffix.endswith(".csv"):
        suffix = suffix[:-4] + ".xlsx"
    return input_path.with_name(f"{input_path.stem}{suffix}")

def collect_existing_ids_and_header(output_file: Path) -> Tuple[Set[str], List[str]]:
    """Return the set of IDs already in output and the existing header (if any)."""
    if not (output_file.exists() and output_file.stat().st_size > 0):
        return set(), []
    df = _load_table(output_file)
    header = list(df.columns)
    if "id" in header:
        ids = set(df["id"].astype(str))
    else:
        ids = {str(i) for i in range(len(df))}
    return ids, header
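RowWriter and collect_existing_ids_and_header together provide the resume-on-rerun behavior: IDs already present in the output are skipped, and new rows are appended. A minimal sketch, with a hypothetical out.csv:

from pathlib import Path

from csvai.io_utils import RowWriter, collect_existing_ids_and_header

out = Path("out.csv")  # hypothetical output file
ids, header = collect_existing_ids_and_header(out)  # (set(), []) on first run
writer = RowWriter(out, header or ["id", "name", "summary"])
writer.append([{"id": "1", "name": "Ada", "summary": "pioneer"}])
writer.close()
# On a second run, ids would contain "1", so that row would be skipped.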
csvai/launch.py
ADDED
@@ -0,0 +1,13 @@
from streamlit.web import cli as stcli
import os
import sys

def main():
    """Entry point for the csvai-ui command."""
    script_path = os.path.join(os.path.dirname(__file__), 'ui.py')
    args = ["run", script_path, "--global.developmentMode=false"]
    sys.argv = ["streamlit"] + args
    sys.exit(stcli.main())

if __name__ == "__main__":
    main()
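The wrapper only rewrites sys.argv and hands control to Streamlit's own CLI. Roughly the same effect can be had via subprocess, as a sketch that assumes the streamlit executable is on PATH:

import subprocess
from importlib.resources import files

ui_script = files("csvai") / "ui.py"  # the installed ui.py inside the package
subprocess.run(
    ["streamlit", "run", str(ui_script), "--global.developmentMode=false"],
    check=True,
)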
csvai/processor.py
ADDED
@@ -0,0 +1,420 @@
"""Asynchronous CSV row processor using OpenAI models."""

from __future__ import annotations

import asyncio
import json
import logging
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple

from jinja2 import Environment, StrictUndefined
from openai import AsyncOpenAI

from .io_utils import (
    read_rows,
    read_prompt,
    choose_prompt_file,
    default_output_file,
    collect_existing_ids_and_header,
    RowWriter,
)
from .settings import Settings

# =============================
# Logging
# =============================
logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")
for noisy in ("httpx", "openai", "urllib3"):
    logging.getLogger(noisy).setLevel(logging.WARNING)

# =============================
# Helpers
# =============================


def sanitize_key_name(key: str) -> str:
    return key.strip().replace('"', "").replace("'", "").replace(" ", "_")


def sanitize_keys(row: Dict[str, Any]) -> Dict[str, str]:
    return {sanitize_key_name(k): ("" if v is None else str(v)).strip() for k, v in (row or {}).items() if k is not None}


_CURLY_VAR_RE = re.compile(r"\{\{\s*([^}]+?)\s*\}\}")


def sanitize_prompt_placeholders(prompt_template: str, raw_keys: List[str]) -> str:
    """Map {{ Raw Header }} → {{ Raw_Header }} only for simple identifiers."""
    key_map = {k: sanitize_key_name(k) for k in raw_keys}

    def _replace(m: "re.Match[str]") -> str:
        expr = m.group(1).strip()
        if any(ch in expr for ch in ('"', "'", '[', ']', '.', '|', '(', ')', ':')):
            return m.group(0)
        if expr in key_map and key_map[expr] != expr:
            return "{{ " + key_map[expr] + " }}"
        return m.group(0)

    return _CURLY_VAR_RE.sub(_replace, prompt_template)


def render_prompt(prompt_template: str, row: Dict[str, str], raw_row: Dict[str, str]) -> str:
    env = Environment(undefined=StrictUndefined)
    context = dict(row, raw=raw_row)
    prompt_text = sanitize_prompt_placeholders(prompt_template, list(raw_row.keys()))
    return env.from_string(prompt_text).render(**context)


# =============================
# OpenAI client helpers
# =============================


def get_async_client(settings: Settings) -> AsyncOpenAI:
    if not settings.openai_api_key:
        logging.error("OPENAI_API_KEY missing. Set it in your environment or .env file.")
        raise SystemExit(2)
    return AsyncOpenAI(api_key=settings.openai_api_key, timeout=settings.request_timeout)


def _pick_text_from_response(resp: Any) -> str:
    text = getattr(resp, "output_text", None)
    if isinstance(text, str) and text.strip():
        return text.strip()
    output = getattr(resp, "output", None)
    if isinstance(output, list) and output:
        for item in output:
            content = getattr(item, "content", None)
            if isinstance(content, list):
                for part in content:
                    t = getattr(part, "text", None)
                    if isinstance(t, str) and t.strip():
                        return t.strip()
                    if isinstance(part, dict):
                        val = part.get("text")
                        if isinstance(val, str) and val.strip():
                            return val.strip()
    return ""


async def call_openai_responses(
    prompt: str,
    client: AsyncOpenAI,
    model: str,
    schema: Optional[Dict[str, Any]],
    settings: Settings,
) -> Optional[str]:
    backoff = settings.initial_backoff
    for attempt in range(1, settings.max_attempts + 1):
        try:
            if schema:
                text_cfg: Dict[str, Any] = {
                    "format": {
                        "type": "json_schema",
                        "name": "row_schema",
                        "schema": schema,
                        "strict": True,
                    }
                }
            else:
                text_cfg = {"format": {"type": "json_object"}}
            resp = await client.responses.create(
                model=model,
                input=prompt,
                temperature=settings.temperature,
                max_output_tokens=settings.max_output_tokens,
                text=text_cfg,
            )
            return _pick_text_from_response(resp) or None
        except Exception as e:
            if attempt == settings.max_attempts:
                logging.error(f"Responses API error (final): {e}")
                return None
            logging.warning(
                f"Responses API error (attempt {attempt}): {e} → retrying in {backoff:.1f}s"
            )
            await asyncio.sleep(backoff)
            backoff *= settings.backoff_factor
    return None


# =============================
# Schema helpers
# =============================


def load_schema(schema_path: Optional[str], prompt_path: Optional[Path]) -> Optional[Dict[str, Any]]:
    path: Optional[Path] = None
    if schema_path:
        path = Path(schema_path).expanduser().resolve()
        if not path.exists():
            raise FileNotFoundError(f"Schema file not found: {path}")
        logging.info("Using schema file: %s", path)
    elif prompt_path is not None:
        name = prompt_path.name
        if name.endswith(".prompt.txt"):
            cand = prompt_path.with_name(name[:-len(".prompt.txt")] + ".schema.json")
        else:
            cand = prompt_path.with_name(prompt_path.stem + ".schema.json")
        if cand.exists():
            path = cand
            logging.info("Auto-discovered schema file: %s", path)
    if path is None:
        return None
    with open(path, "r", encoding="utf-8") as f:
        schema = json.load(f)
    if not isinstance(schema, dict):
        raise ValueError("Schema root must be a JSON object")
    try:
        if schema.get("type") == "object" and isinstance(schema.get("properties"), dict):
            props = list(schema["properties"].keys())
            req = schema.get("required")
            if not isinstance(req, list):
                schema["required"] = props
            else:
                missing = [k for k in props if k not in req]
                if missing:
                    schema["required"] = req + missing
    except Exception:
        pass
    return schema


# =============================
# File helpers
# =============================


def batched(items: List[Any], size: int) -> Iterable[List[Any]]:
    for i in range(0, len(items), max(1, size)):
        yield items[i : i + size]


# =============================
# Row processing
# =============================


class RowResult:
    def __init__(self, id: str, data: Optional[Dict[str, Any]] = None, error: Optional[str] = None):
        self.id = id
        self.data = data or {}
        self.error = error

    @property
    def ok(self) -> bool:
        return self.error is None


async def process_row(
    row_idx: int,
    raw_row: Dict[str, Any],
    client: AsyncOpenAI,
    prompt_template: str,
    model: str,
    schema: Optional[Dict[str, Any]],
    settings: Settings,
) -> RowResult:
    row_id = (raw_row.get("id") or str(row_idx)).strip()
    sanitized = sanitize_keys(raw_row)
    sanitized["id"] = row_id
    try:
        prompt = render_prompt(prompt_template, sanitized, raw_row)
    except Exception as e:
        return RowResult(id=row_id, error=f"prompt_error: {e}")
    raw = await call_openai_responses(prompt, client, model, schema, settings)
    if not raw:
        return RowResult(id=row_id, error="api_empty")
    try:
        enriched = json.loads(raw)
        if isinstance(enriched, list):
            enriched = (enriched[0] if enriched else {})
        if not isinstance(enriched, dict):
            return RowResult(id=row_id, error="json_type")
    except Exception as e:
        return RowResult(id=row_id, error=f"json_parse: {e}")
    out = dict(sanitized)
    out.update(sanitize_keys(enriched))
    return RowResult(id=row_id, data=out)


# =============================
# Core processor class
# =============================


@dataclass
class ProcessorConfig:
    input: str
    prompt: Optional[str] = None
    output: Optional[str] = None
    schema: Optional[str] = None
    limit: Optional[int] = None
    model: Optional[str] = None


class CSVAIProcessor:
    def __init__(self, config: ProcessorConfig, settings: Optional[Settings] = None):
        self.settings = settings or Settings()
        self.config = config
        if not self.config.model:
            self.config.model = self.settings.default_model
        self.pause_event = asyncio.Event()
        self.pause_event.set()
        self.shutdown_event = asyncio.Event()

    def pause(self) -> None:
        self.pause_event.clear()

    def resume(self) -> None:
        self.pause_event.set()

    def stop(self) -> None:
        self.shutdown_event.set()

    async def run(self) -> None:
        args = self.config
        input_path = Path(args.input).expanduser().resolve()
        if not input_path.exists():
            logging.error(f"Input file not found: {input_path}")
            raise SystemExit(1)

        output_path = default_output_file(input_path, args.output)
        prompt_path = choose_prompt_file(input_path, args.prompt)

        all_rows = read_rows(str(input_path))
        if not all_rows:
            logging.info("No rows to process.")
            return

        prompt_template = read_prompt(str(prompt_path))
        schema = load_schema(args.schema, prompt_path)

        client = get_async_client(self.settings)

        existing_ids, existing_header = collect_existing_ids_and_header(output_path)

        def src_id(i: int, r: Dict[str, str]) -> str:
            return (r.get("id") or str(i)).strip()

        pending: List[Tuple[int, Dict[str, str]]] = [
            (i, r) for i, r in enumerate(all_rows) if src_id(i, r) not in existing_ids
        ]
        limit = args.limit if args.limit is not None and args.limit >= 0 else None
        if limit is not None:
            pending = pending[:limit]

        total_rows = len(all_rows)
        scheduled = len(pending)

        if scheduled == 0:
            if len(existing_ids) == total_rows:
                logging.info("Everything is already processed. Nothing to do.")
            else:
                logging.info("No rows scheduled (limit=0 or nothing new).")
            await client.close()
            return

        logging.info(
            "Plan → already_enriched=%d | pending_new=%d | scheduled_now=%d",
            len(existing_ids), len(all_rows) - len(existing_ids), scheduled,
        )
        logging.info(
            "Using model=%s, concurrency=%d, batch_size=%d",
            args.model,
            self.settings.max_concurrent_requests,
            self.settings.processing_batch_size,
        )

        base_keys = [sanitize_key_name(k) for k in all_rows[0].keys()]
        if "id" not in base_keys:
            base_keys = ["id"] + base_keys
        tentative_header: List[str] = list(dict.fromkeys(base_keys))

        sem = asyncio.Semaphore(self.settings.max_concurrent_requests)
        processed, failed = 0, 0
        written_ids_this_run: Set[str] = set()
        header: Optional[List[str]] = existing_header if existing_header else None
        writer: Optional[RowWriter] = RowWriter(output_path, header) if header else None

        try:
            for bi, batch in enumerate(batched(pending, self.settings.processing_batch_size), start=1):
                await self.pause_event.wait()
                if self.shutdown_event.is_set():
                    break

                async def run_one(idx: int, raw: Dict[str, Any]) -> RowResult:
                    async with sem:
                        return await process_row(idx, raw, client, prompt_template, args.model, schema, self.settings)

                results = await asyncio.gather(
                    *(run_one(idx, raw) for idx, raw in batch), return_exceptions=True
                )

                successes: List[Dict[str, Any]] = []
                batch_keys: List[str] = []
                for res in results:
                    if isinstance(res, Exception):
                        failed += 1
                        logging.error("Unexpected error: %s", res)
                        continue
                    if not res.ok:
                        failed += 1
                        continue
                    successes.append(res.data)
                    for k in res.data.keys():
                        if k not in batch_keys:
                            batch_keys.append(k)

                if not successes:
                    logging.info("Batch %d: no successful rows.", bi)
                    if self.shutdown_event.is_set():
                        logging.warning("Stopping after current batch.")
                        break
                    continue

                if header is None:
                    header = list(dict.fromkeys(tentative_header + batch_keys))
                    writer = RowWriter(output_path, header)

                if writer is not None:
                    writer.append(successes)

                processed += len(successes)
                for s in successes:
                    if "id" in s and s["id"] is not None:
                        written_ids_this_run.add(str(s["id"]).strip())

                logging.info(
                    "Batch %d done | wrote=%d | total_this_run=%d/%d | failed=%d",
                    bi,
                    len(successes),
                    processed,
                    scheduled,
                    failed,
                )

                if self.shutdown_event.is_set():
                    logging.warning("Stopping after current batch.")
                    break

        finally:
            overall_done = len(existing_ids | written_ids_this_run)
            remaining = max(0, total_rows - overall_done)

            logging.info("Summary: output=%s", output_path)
            logging.info(
                "Totals: input=%d | processed_this_run=%d | processed_overall=%d | remaining=%d | failed=%d",
                total_rows,
                processed,
                overall_done,
                remaining,
                failed,
            )
            if writer is not None:
                writer.close()
            await client.close()
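load_schema() pairs a schema with the prompt by file name: data.prompt.txt looks for data.schema.json beside it, and sanitize_prompt_placeholders() lets the template reference a raw header like "Company Name" as {{ Company Name }}. A sketch of a matching pair for a hypothetical data.csv (field names are illustrative):

import json
from pathlib import Path

# Prompt template auto-discovered for data.csv (input.with_suffix(".prompt.txt")).
Path("data.prompt.txt").write_text(
    "Company: {{ Company Name }}\n"
    "Return a JSON object with keys 'industry' and 'summary'.\n",
    encoding="utf-8",
)
# "required" may be omitted: load_schema() fills it in from "properties".
Path("data.schema.json").write_text(
    json.dumps(
        {
            "type": "object",
            "properties": {
                "industry": {"type": "string"},
                "summary": {"type": "string"},
            },
            "additionalProperties": False,
        },
        indent=2,
    ),
    encoding="utf-8",
)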
csvai/settings.py
ADDED
@@ -0,0 +1,42 @@
from __future__ import annotations

import os
from dataclasses import dataclass

from dotenv import find_dotenv, load_dotenv


@dataclass
class Settings:
    """Application configuration loaded from environment variables."""

    openai_api_key: str = ""
    default_model: str = "gpt-4o-mini"
    max_output_tokens: int = 800
    temperature: float = 0.2
    max_concurrent_requests: int = 10
    processing_batch_size: int = 50
    request_timeout: float = 45.0
    max_attempts: int = 4
    initial_backoff: float = 2.0
    backoff_factor: float = 1.7

    def __post_init__(self) -> None:
        load_dotenv(find_dotenv())
        self.openai_api_key = os.getenv("OPENAI_API_KEY", self.openai_api_key)
        self.default_model = os.getenv("DEFAULT_MODEL", self.default_model)
        self.max_output_tokens = int(
            os.getenv("MAX_OUTPUT_TOKENS", os.getenv("MAX_TOKENS", self.max_output_tokens))
        )
        self.temperature = float(os.getenv("TEMPERATURE", self.temperature))
        self.max_concurrent_requests = int(os.getenv("MAX_CONCURRENT_REQUESTS", self.max_concurrent_requests))
        self.processing_batch_size = int(os.getenv("PROCESSING_BATCH_SIZE", self.processing_batch_size))
        self.request_timeout = float(os.getenv("REQUEST_TIMEOUT", self.request_timeout))
        self.max_attempts = int(os.getenv("MAX_ATTEMPTS", self.max_attempts))
        self.initial_backoff = float(os.getenv("INITIAL_BACKOFF", self.initial_backoff))
        self.backoff_factor = float(os.getenv("BACKOFF_FACTOR", self.backoff_factor))
        if self.max_concurrent_requests <= 0:
            raise ValueError("max_concurrent_requests must be positive")
        if self.processing_batch_size <= 0:
            raise ValueError("processing_batch_size must be positive")
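Every field can be overridden via environment variables (or a .env file found by find_dotenv()). A minimal sketch with example values:

import os

# Example values only; a real OPENAI_API_KEY is required at run time.
os.environ["DEFAULT_MODEL"] = "gpt-4o"
os.environ["MAX_CONCURRENT_REQUESTS"] = "4"
os.environ["PROCESSING_BATCH_SIZE"] = "20"

from csvai.settings import Settings

s = Settings()
assert s.default_model == "gpt-4o"
assert s.max_concurrent_requests == 4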