openai-gabriel 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gabriel/__init__.py +61 -0
- gabriel/_version.py +1 -0
- gabriel/api.py +2284 -0
- gabriel/cli/__main__.py +60 -0
- gabriel/core/__init__.py +7 -0
- gabriel/core/llm_client.py +34 -0
- gabriel/core/pipeline.py +18 -0
- gabriel/core/prompt_template.py +152 -0
- gabriel/prompts/__init__.py +1 -0
- gabriel/prompts/bucket_prompt.jinja2 +113 -0
- gabriel/prompts/classification_prompt.jinja2 +50 -0
- gabriel/prompts/codify_prompt.jinja2 +95 -0
- gabriel/prompts/comparison_prompt.jinja2 +60 -0
- gabriel/prompts/deduplicate_prompt.jinja2 +41 -0
- gabriel/prompts/deidentification_prompt.jinja2 +112 -0
- gabriel/prompts/extraction_prompt.jinja2 +61 -0
- gabriel/prompts/filter_prompt.jinja2 +31 -0
- gabriel/prompts/ideation_prompt.jinja2 +80 -0
- gabriel/prompts/merge_prompt.jinja2 +47 -0
- gabriel/prompts/paraphrase_prompt.jinja2 +17 -0
- gabriel/prompts/rankings_prompt.jinja2 +49 -0
- gabriel/prompts/ratings_prompt.jinja2 +50 -0
- gabriel/prompts/regional_analysis_prompt.jinja2 +40 -0
- gabriel/prompts/seed.jinja2 +43 -0
- gabriel/prompts/snippets.jinja2 +117 -0
- gabriel/tasks/__init__.py +63 -0
- gabriel/tasks/_attribute_utils.py +69 -0
- gabriel/tasks/bucket.py +432 -0
- gabriel/tasks/classify.py +562 -0
- gabriel/tasks/codify.py +1033 -0
- gabriel/tasks/compare.py +235 -0
- gabriel/tasks/debias.py +1460 -0
- gabriel/tasks/deduplicate.py +341 -0
- gabriel/tasks/deidentify.py +316 -0
- gabriel/tasks/discover.py +524 -0
- gabriel/tasks/extract.py +455 -0
- gabriel/tasks/filter.py +169 -0
- gabriel/tasks/ideate.py +782 -0
- gabriel/tasks/merge.py +464 -0
- gabriel/tasks/paraphrase.py +531 -0
- gabriel/tasks/rank.py +2041 -0
- gabriel/tasks/rate.py +347 -0
- gabriel/tasks/seed.py +465 -0
- gabriel/tasks/whatever.py +344 -0
- gabriel/utils/__init__.py +64 -0
- gabriel/utils/audio_utils.py +42 -0
- gabriel/utils/file_utils.py +464 -0
- gabriel/utils/image_utils.py +22 -0
- gabriel/utils/jinja.py +31 -0
- gabriel/utils/logging.py +86 -0
- gabriel/utils/mapmaker.py +304 -0
- gabriel/utils/media_utils.py +78 -0
- gabriel/utils/modality_utils.py +148 -0
- gabriel/utils/openai_utils.py +5470 -0
- gabriel/utils/parsing.py +282 -0
- gabriel/utils/passage_viewer.py +2557 -0
- gabriel/utils/pdf_utils.py +20 -0
- gabriel/utils/plot_utils.py +2881 -0
- gabriel/utils/prompt_utils.py +42 -0
- gabriel/utils/word_matching.py +158 -0
- openai_gabriel-1.0.1.dist-info/METADATA +443 -0
- openai_gabriel-1.0.1.dist-info/RECORD +67 -0
- openai_gabriel-1.0.1.dist-info/WHEEL +5 -0
- openai_gabriel-1.0.1.dist-info/entry_points.txt +2 -0
- openai_gabriel-1.0.1.dist-info/licenses/LICENSE +201 -0
- openai_gabriel-1.0.1.dist-info/licenses/NOTICE +13 -0
- openai_gabriel-1.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,304 @@
|
|
|
1
|
+
"""
|
|
2
|
+
mapmaker.py
|
|
3
|
+
~~~~~~~~~~~~
|
|
4
|
+
|
|
5
|
+
This module implements a self‑contained helper class for generating
|
|
6
|
+
choropleth maps from ranked GABRIEL outputs. The ``MapMaker`` can
|
|
7
|
+
produce county‑, state‑ or country‑level maps depending on the
|
|
8
|
+
``map_type`` parameter, and it optionally normalises values as
|
|
9
|
+
z‑scores. All mapping logic resides in this file, replacing the need
|
|
10
|
+
for the separate ``create_county_choropleth`` function.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import os
|
|
16
|
+
import json
|
|
17
|
+
import requests
|
|
18
|
+
from typing import Iterable, Optional
|
|
19
|
+
|
|
20
|
+
import numpy as np
|
|
21
|
+
import pandas as pd
|
|
22
|
+
import plotly.express as px
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class MapMaker:
    """Utility for generating geographic choropleth maps from a data frame.

    The input data frame should contain at least one column of numeric
    scores and one or more columns identifying the geographic unit
    (county FIPS codes, two-letter US state abbreviations or ISO-3
    country codes).  Individual maps are rendered using Plotly and
    written to ``save_dir`` with names derived from the value column.

    Parameters
    ----------
    df:
        DataFrame containing the data to plot.  Each row should
        correspond to a geographic region.
    fips_col:
        Name of the column containing five-digit county FIPS codes.
    state_col:
        Name of the column containing two-letter US state abbreviations.
    country_col:
        Name of the column containing ISO-3 country codes.
    save_dir:
        Directory to which map files will be written.  If ``None``,
        a ``maps`` subdirectory in the current working directory is used.
    z_score:
        Whether to convert values to z-scores before plotting.
    color_scale:
        Name of the Plotly colour scale to apply.  Defaults to
        ``"RdBu"`` (diverging) when z-scores are enabled and
        ``"Viridis"`` otherwise.
    map_type:
        Determines the map produced: ``"county"``, ``"state"``
        or ``"country"`` (with ``"global"`` as an alias).
    """

    def __init__(
        self,
        df: pd.DataFrame,
        *,
        fips_col: Optional[str] = None,
        state_col: Optional[str] = None,
        country_col: Optional[str] = None,
        save_dir: Optional[str] = None,
        z_score: bool = True,
        color_scale: str = "RdBu",
        map_type: str = "county",
    ) -> None:
        self.df = df.copy()
        self.fips_col = fips_col
        self.state_col = state_col
        self.country_col = country_col

        # Normalise map_type and validate; "global" is an alias of "country".
        map_type = map_type.lower()
        if map_type not in {"county", "state", "country", "global"}:
            raise ValueError(
                "map_type must be one of 'county', 'state', 'country' or 'global'"
            )
        self.map_type = "country" if map_type == "global" else map_type

        # Resolve the output directory, creating it if necessary.
        if save_dir is None:
            save_dir = os.path.join(os.getcwd(), "maps")
        save_dir = os.path.expandvars(os.path.expanduser(save_dir))
        os.makedirs(save_dir, exist_ok=True)
        self.save_dir = save_dir

        self.z_score = z_score
        self.color_scale = color_scale

    def _compute_zscore(self, values: np.ndarray) -> np.ndarray:
        """Compute z-scores with safe handling of NaNs and constant arrays.

        Returns an all-zero array when the input has fewer than two
        elements or a zero (or NaN) standard deviation.
        """
        vals = values.astype(float)
        # Hoist nanstd so it is computed once instead of twice per call.
        std = np.nanstd(vals) if len(vals) > 1 else 0.0
        if std > 0:
            return (vals - np.nanmean(vals)) / std
        return np.zeros_like(vals)

    def _prepare_plot_data(self, df: pd.DataFrame, value_col: str) -> tuple:
        """Return ``(df_local, plot_col, colour_scale)`` with optional z-scoring.

        Shared by all three map flavours; this logic was previously
        duplicated verbatim in each ``_create_*_choropleth`` method.
        """
        df_local = df.copy()
        plot_col = value_col
        colour_scale = self.color_scale
        if self.z_score:
            plot_col = f"_zscore_{value_col}"
            df_local[plot_col] = self._compute_zscore(df_local[value_col].values)
            # Use a diverging palette when plotting z-scores.
            colour_scale = "RdBu" if self.color_scale == "RdBu" else "PuOr"
        return df_local, plot_col, colour_scale

    def _finalize_figure(self, fig, save_path: str) -> None:
        """Centre the colour axis for z-scores and write the figure to disk.

        Raster extensions (.png/.jpg/.jpeg) go through ``write_image``;
        anything else is written as interactive HTML.
        """
        if self.z_score:
            fig.update_coloraxes(cmid=0)
        ext = os.path.splitext(save_path)[1].lower()
        if ext in {".png", ".jpg", ".jpeg"}:
            fig.write_image(save_path, scale=3)
        else:
            fig.write_html(save_path)

    def _create_state_choropleth(
        self,
        df: pd.DataFrame,
        state_col: str,
        value_col: str,
        title: str,
        save_path: str,
    ) -> None:
        """Create and save a state-level choropleth."""
        df_local, plot_col, colour_scale = self._prepare_plot_data(df, value_col)
        fig = px.choropleth(
            df_local,
            locations=state_col,
            locationmode="USA-states",
            color=plot_col,
            color_continuous_scale=colour_scale,
            scope="usa",
            hover_data={state_col: True, value_col: True},
            title=title,  # BUG FIX: title was accepted but never applied
        )
        self._finalize_figure(fig, save_path)

    def _create_country_choropleth(
        self,
        df: pd.DataFrame,
        country_col: str,
        value_col: str,
        title: str,
        save_path: str,
    ) -> None:
        """Create and save a global choropleth using ISO-3 codes."""
        df_local, plot_col, colour_scale = self._prepare_plot_data(df, value_col)
        fig = px.choropleth(
            df_local,
            locations=country_col,
            locationmode="ISO-3",
            color=plot_col,
            color_continuous_scale=colour_scale,
            scope="world",
            hover_data={country_col: True, value_col: True},
            title=title,  # BUG FIX: title was accepted but never applied
        )
        self._finalize_figure(fig, save_path)

    def _create_county_choropleth(
        self,
        df: pd.DataFrame,
        fips_col: str,
        value_col: str,
        title: str,
        save_path: str,
    ) -> None:
        """Create and save a county-level choropleth keyed on FIPS codes.

        This method inlines the logic of the old ``create_county_choropleth``
        function to avoid external dependencies.  It downloads a GeoJSON of
        U.S. counties on first use and caches it in ``~/.cache/county_geo.json``.
        """
        # Pad FIPS codes to five digits (leading zeros are commonly lost).
        df_local = df.copy()
        df_local[fips_col] = df_local[fips_col].astype(str).str.zfill(5)

        # Find a county-name column for hover text, or synthesise an empty one.
        county_col = next(
            (c for c in ("county", "County", "region", "Region") if c in df_local.columns),
            None,
        )
        if county_col is None:
            county_col = "_county_name"
            df_local[county_col] = ""

        counties = self._load_county_geojson()

        df_local, plot_col, colour_scale = self._prepare_plot_data(df_local, value_col)
        fig = px.choropleth(
            df_local,
            geojson=counties,
            locations=fips_col,
            color=plot_col,
            color_continuous_scale=colour_scale,
            scope="usa",
            hover_data={county_col: True, fips_col: True, value_col: True},
            title=title,  # BUG FIX: title was accepted but never applied
        )
        self._finalize_figure(fig, save_path)

    @staticmethod
    def _load_county_geojson() -> dict:
        """Load the U.S. county GeoJSON, downloading and caching it on first use."""
        geojson_url = (
            "https://raw.githubusercontent.com/plotly/datasets/master/geojson-counties-fips.json"
        )
        cache_path = os.path.join(os.path.expanduser("~"), ".cache", "county_geo.json")
        if not os.path.exists(cache_path):
            os.makedirs(os.path.dirname(cache_path), exist_ok=True)
            resp = requests.get(geojson_url, timeout=30)
            resp.raise_for_status()
            with open(cache_path, "w", encoding="utf-8") as f:
                f.write(resp.text)
        with open(cache_path, encoding="utf-8") as f:
            return json.load(f)

    def make_maps(self, value_cols: Iterable[str]) -> None:
        """Generate and save maps for each specified numeric column.

        The map type is determined by ``self.map_type``: ``"county"``
        uses FIPS codes; ``"state"`` uses two-letter abbreviations; and
        ``"country"`` uses ISO-3 codes.

        Raises
        ------
        ValueError
            If the identifier column required by ``self.map_type`` was
            not supplied at construction time.
        """
        for value_col in value_cols:
            if self.map_type == "county":
                if not self.fips_col:
                    raise ValueError("fips_col must be provided for county maps")
                save_path = os.path.join(self.save_dir, f"county_map_{value_col}.html")
                self._create_county_choropleth(
                    self.df,
                    self.fips_col,
                    value_col,
                    title=f"County Map for {value_col}",
                    save_path=save_path,
                )
            elif self.map_type == "state":
                if not self.state_col:
                    raise ValueError("state_col must be provided for state maps")
                save_path = os.path.join(self.save_dir, f"state_map_{value_col}.html")
                self._create_state_choropleth(
                    self.df,
                    self.state_col,
                    value_col,
                    title=f"State Map for {value_col}",
                    save_path=save_path,
                )
            elif self.map_type == "country":
                if not self.country_col:
                    raise ValueError("country_col must be provided for country maps")
                save_path = os.path.join(self.save_dir, f"country_map_{value_col}.html")
                self._create_country_choropleth(
                    self.df,
                    self.country_col,
                    value_col,
                    title=f"Country Map for {value_col}",
                    save_path=save_path,
                )
            else:  # unreachable after __init__ validation
                raise ValueError(f"Unsupported map type: {self.map_type}")
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def create_county_choropleth(
    df: pd.DataFrame,
    *,
    fips_col: str,
    value_col: str,
    title: str,
    save_path: str,
    z_score: bool = True,
) -> None:
    """Render a single county-level choropleth (legacy entry point).

    Kept for backward compatibility with callers of the old standalone
    function: it simply builds a :class:`MapMaker` configured for county
    maps and hands off to its internal county renderer.
    """
    maker = MapMaker(
        df,
        fips_col=fips_col,
        z_score=z_score,
        save_dir=None,
        map_type="county",
    )
    maker._create_county_choropleth(df, fips_col, value_col, title, save_path)
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from typing import Any, Dict, List
|
|
5
|
+
|
|
6
|
+
from .image_utils import encode_image
|
|
7
|
+
from .audio_utils import encode_audio
|
|
8
|
+
from .pdf_utils import encode_pdf
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def load_image_inputs(val: Any) -> List[str]:
    """Return a list of base64-encoded images from a DataFrame cell.

    ``val`` may be a single file path, a list of file paths, or a list of
    pre-encoded base64 strings.  Strings naming existing files are run
    through ``encode_image``; any other string is assumed to already be
    base64 and passed through unchanged.  Non-string items are dropped.
    """
    if not val:
        return []
    candidates = val if isinstance(val, list) else [val]
    results: List[str] = []
    for item in candidates:
        if not isinstance(item, str):
            continue
        if os.path.exists(item):
            payload = encode_image(item)
            if payload:
                results.append(payload)
        else:
            # Not a path on disk: treat as an already-encoded image.
            results.append(item)
    return results
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def load_audio_inputs(val: Any) -> List[Dict[str, str]]:
    """Return a list of audio payload dicts from a DataFrame cell.

    ``val`` may be a single file path, a list of file paths, or a list of
    dicts that were already encoded.  Dicts pass through untouched;
    strings naming existing files are encoded via ``encode_audio``;
    everything else is dropped.
    """
    if not val:
        return []
    items = val if isinstance(val, list) else [val]
    out: List[Dict[str, str]] = []
    for item in items:
        if isinstance(item, dict):
            out.append(item)
        elif isinstance(item, str) and os.path.exists(item):
            payload = encode_audio(item)
            if payload:
                out.append(payload)
    return out
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def load_pdf_inputs(val: Any) -> List[Dict[str, str]]:
    """Return a list of PDF payload dicts from a DataFrame cell.

    Accepts a single item or a list of items, each of which may be a
    payload dict (passed through), a local file path (encoded via
    ``encode_pdf``), a ``data:application/pdf`` URI, an http(s) URL
    ending in ``.pdf``, or a bare base64 string (wrapped into a data URI).
    """
    if not val:
        return []
    items = val if isinstance(val, list) else [val]
    out: List[Dict[str, str]] = []
    for item in items:
        if isinstance(item, dict):
            out.append(item)
            continue
        if not isinstance(item, str):
            continue
        if os.path.exists(item):
            payload = encode_pdf(item)
            if payload:
                out.append(payload)
            continue
        lowered = item.lower()
        if lowered.startswith("data:application/pdf"):
            out.append({"filename": "document.pdf", "file_data": item})
        elif lowered.startswith(("http://", "https://")):
            # Only direct .pdf links can be attached by URL.
            if lowered.endswith(".pdf"):
                out.append({"file_url": item})
        else:
            # Assume a bare base64 payload and wrap it in a data URI.
            out.append(
                {
                    "filename": "document.pdf",
                    "file_data": f"data:application/pdf;base64,{item}",
                }
            )
    return out
|
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import re
|
|
5
|
+
from typing import Any, Iterable, List, Tuple
|
|
6
|
+
|
|
7
|
+
from .file_utils import AUDIO_EXTENSIONS, IMAGE_EXTENSIONS, PDF_EXTENSIONS
|
|
8
|
+
|
|
9
|
+
# Lower-cased extension sets used for suffix-based modality detection below.
IMAGE_SUFFIXES = {suffix.lower() for suffix in IMAGE_EXTENSIONS}
AUDIO_SUFFIXES = {suffix.lower() for suffix in AUDIO_EXTENSIONS}
PDF_SUFFIXES = {suffix.lower() for suffix in PDF_EXTENSIONS}
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def warn_if_modality_mismatch(
    values: Iterable[Any],
    modality: str,
    *,
    column_name: str,
    sample_size: int = 100,
) -> None:
    """Inspect a sample of values and warn if the declared modality likely
    mismatches what the column actually contains.

    Parameters
    ----------
    values:
        Cell values from the column (paths, URLs, data URIs, payload
        dicts or plain text).
    modality:
        The modality the caller declared (e.g. 'pdf', 'image', 'audio',
        'text', 'entity', 'web').
    column_name:
        Column name used in the printed warnings.
    sample_size:
        Number of leading values to inspect.
    """
    sample = list(values)[:sample_size]
    if not sample:
        return

    pdf_hits = 0
    image_hits = 0
    audio_hits = 0
    text_word_counts: List[int] = []

    for value in sample:
        candidate = _coerce_first_value(value)
        if candidate is None:
            continue
        kind = _detect_media_kind(candidate)
        if kind == "pdf":
            pdf_hits += 1
        elif kind == "image":
            image_hits += 1
        elif kind == "audio":
            audio_hits += 1
        elif isinstance(candidate, str):
            # BUG FIX: the pattern was written with doubled backslashes
            # inside a raw string (r"\\b\\w+\\b"), which matches literal
            # backslash characters and therefore never counted any words.
            text_word_counts.append(len(re.findall(r"\b\w+\b", candidate)))

    # NOTE: the original re-checked `total == 0` here, which is unreachable
    # after the `if not sample` guard above; that dead branch was removed.
    total = len(sample)

    if pdf_hits == total and modality != "pdf":
        print(
            f"[gabriel] Detected PDFs in column '{column_name}'. "
            "Set modality='pdf' to attach files directly, or set modality='text' "
            "(or 'entity'/'web') to extract plain text instead."
        )
    if image_hits == total and modality != "image":
        print(
            f"[gabriel] Detected image-like inputs in column '{column_name}'. "
            "Set modality='image' to attach image files correctly."
        )
    if audio_hits == total and modality != "audio":
        print(
            f"[gabriel] Detected audio-like inputs in column '{column_name}'. "
            "Set modality='audio' to attach audio files correctly."
        )
    if modality == "pdf" and pdf_hits == 0:
        print(
            f"[gabriel] Column '{column_name}' doesn't look like PDF inputs. "
            "If this is text, consider modality='text' (or 'entity'/'web')."
        )
    if modality == "image" and image_hits == 0:
        print(
            f"[gabriel] Column '{column_name}' doesn't look like image inputs. "
            "If this is text, consider modality='text' (or 'entity'/'web')."
        )
    if modality == "audio" and audio_hits == 0:
        print(
            f"[gabriel] Column '{column_name}' doesn't look like audio inputs. "
            "If this is text, consider modality='text' (or 'entity'/'web')."
        )

    if modality == "text" and text_word_counts:
        avg_words = sum(text_word_counts) / max(1, len(text_word_counts))
        if avg_words < 10:
            print(
                f"[gabriel] Average word count in column '{column_name}' is {avg_words:.1f} "
                "words. Confirm you intended modality='text' (rather than 'entity' or 'web')."
            )
    if modality in {"entity", "web"} and text_word_counts:
        avg_words = sum(text_word_counts) / max(1, len(text_word_counts))
        if avg_words > 30:
            print(
                f"[gabriel] Average word count in column '{column_name}' is {avg_words:.1f} "
                "words. Consider modality='text' for long-form passages."
            )
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _coerce_first_value(value: Any) -> Any:
|
|
103
|
+
if value is None:
|
|
104
|
+
return None
|
|
105
|
+
if isinstance(value, (list, tuple)):
|
|
106
|
+
return value[0] if value else None
|
|
107
|
+
return value
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _detect_media_kind(value: Any) -> str:
    """Classify ``value`` as ``'pdf'``, ``'image'`` or ``'audio'``.

    Handles payload dicts, data URIs, http(s) URLs and local file paths;
    returns an empty string when no media kind can be determined.
    """
    if isinstance(value, dict):
        data = value.get("file_data")
        url = value.get("file_url")
        if isinstance(data, str) and data.startswith("data:application/pdf"):
            return "pdf"
        if isinstance(url, str) and url.lower().endswith(".pdf"):
            return "pdf"
        return "audio" if value.get("format") else ""
    if not isinstance(value, str):
        return ""
    lowered = value.lower().strip()
    for prefix, kind in (
        ("data:application/pdf", "pdf"),
        ("data:image/", "image"),
        ("data:audio/", "audio"),
    ):
        if lowered.startswith(prefix):
            return kind
    if _looks_like_url(lowered):
        for suffixes, kind in (
            (PDF_SUFFIXES, "pdf"),
            (IMAGE_SUFFIXES, "image"),
            (AUDIO_SUFFIXES, "audio"),
        ):
            if lowered.endswith(tuple(suffixes)):
                return kind
    # Note: existence is checked on the original (unstripped) string.
    if os.path.exists(value):
        ext = os.path.splitext(value)[1].lower()
        for suffixes, kind in (
            (PDF_SUFFIXES, "pdf"),
            (IMAGE_SUFFIXES, "image"),
            (AUDIO_SUFFIXES, "audio"),
        ):
            if ext in suffixes:
                return kind
    return ""
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def _looks_like_url(value: str) -> bool:
|
|
148
|
+
return value.startswith("http://") or value.startswith("https://")
|