labelr 0.8.0__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- labelr/apps/datasets.py +56 -5
- labelr/apps/google_batch.py +296 -0
- labelr/apps/label_studio.py +1 -1
- labelr/export/classification.py +114 -0
- labelr/export/common.py +42 -0
- labelr/export/llm.py +91 -0
- labelr/{export.py → export/object_detection.py} +3 -138
- labelr/google_genai.py +421 -0
- labelr/main.py +6 -0
- labelr/sample/__init__.py +0 -0
- labelr/sample/classification.py +17 -0
- labelr/sample/common.py +14 -0
- labelr/sample/llm.py +75 -0
- labelr/{sample.py → sample/object_detection.py} +0 -17
- labelr/utils.py +85 -0
- {labelr-0.8.0.dist-info → labelr-0.10.0.dist-info}/METADATA +9 -1
- labelr-0.10.0.dist-info/RECORD +36 -0
- labelr-0.8.0.dist-info/RECORD +0 -27
- /labelr/{evaluate/llm.py → export/__init__.py} +0 -0
- {labelr-0.8.0.dist-info → labelr-0.10.0.dist-info}/WHEEL +0 -0
- {labelr-0.8.0.dist-info → labelr-0.10.0.dist-info}/entry_points.txt +0 -0
- {labelr-0.8.0.dist-info → labelr-0.10.0.dist-info}/licenses/LICENSE +0 -0
- {labelr-0.8.0.dist-info → labelr-0.10.0.dist-info}/top_level.txt +0 -0
|
@@ -1,34 +1,23 @@
|
|
|
1
1
|
import functools
|
|
2
2
|
import logging
|
|
3
3
|
import pickle
|
|
4
|
-
import random
|
|
5
4
|
import tempfile
|
|
6
5
|
from pathlib import Path
|
|
7
6
|
|
|
8
7
|
import datasets
|
|
9
8
|
import tqdm
|
|
10
9
|
from label_studio_sdk.client import LabelStudio
|
|
11
|
-
from openfoodfacts.images import download_image
|
|
12
|
-
from openfoodfacts.types import Flavor
|
|
13
|
-
from PIL import Image, ImageOps
|
|
10
|
+
from openfoodfacts.images import download_image
|
|
14
11
|
|
|
15
|
-
from labelr.
|
|
16
|
-
|
|
12
|
+
from labelr.export.common import _pickle_sample_generator
|
|
13
|
+
from labelr.sample.object_detection import (
|
|
17
14
|
HF_DS_OBJECT_DETECTION_FEATURES,
|
|
18
15
|
format_object_detection_sample_to_hf,
|
|
19
16
|
)
|
|
20
|
-
from labelr.types import TaskType
|
|
21
17
|
|
|
22
18
|
logger = logging.getLogger(__name__)
|
|
23
19
|
|
|
24
20
|
|
|
25
|
-
def _pickle_sample_generator(dir: Path):
|
|
26
|
-
"""Generator that yields samples from pickles in a directory."""
|
|
27
|
-
for pkl in dir.glob("*.pkl"):
|
|
28
|
-
with open(pkl, "rb") as f:
|
|
29
|
-
yield pickle.load(f)
|
|
30
|
-
|
|
31
|
-
|
|
32
21
|
def export_from_ls_to_hf_object_detection(
|
|
33
22
|
ls: LabelStudio,
|
|
34
23
|
repo_id: str,
|
|
@@ -331,127 +320,3 @@ def export_from_hf_to_ultralytics_object_detection(
|
|
|
331
320
|
f.write("names:\n")
|
|
332
321
|
for i, category_name in enumerate(category_names):
|
|
333
322
|
f.write(f" {i}: {category_name}\n")
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
def export_from_ultralytics_to_hf(
|
|
337
|
-
task_type: TaskType,
|
|
338
|
-
dataset_dir: Path,
|
|
339
|
-
repo_id: str,
|
|
340
|
-
label_names: list[str],
|
|
341
|
-
merge_labels: bool = False,
|
|
342
|
-
is_openfoodfacts_dataset: bool = False,
|
|
343
|
-
openfoodfacts_flavor: Flavor = Flavor.off,
|
|
344
|
-
) -> None:
|
|
345
|
-
if task_type != TaskType.classification:
|
|
346
|
-
raise NotImplementedError(
|
|
347
|
-
"Only classification task is currently supported for Ultralytics to HF export"
|
|
348
|
-
)
|
|
349
|
-
|
|
350
|
-
if task_type == TaskType.classification:
|
|
351
|
-
export_from_ultralytics_to_hf_classification(
|
|
352
|
-
dataset_dir=dataset_dir,
|
|
353
|
-
repo_id=repo_id,
|
|
354
|
-
label_names=label_names,
|
|
355
|
-
merge_labels=merge_labels,
|
|
356
|
-
is_openfoodfacts_dataset=is_openfoodfacts_dataset,
|
|
357
|
-
openfoodfacts_flavor=openfoodfacts_flavor,
|
|
358
|
-
)
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
def export_from_ultralytics_to_hf_classification(
|
|
362
|
-
dataset_dir: Path,
|
|
363
|
-
repo_id: str,
|
|
364
|
-
label_names: list[str],
|
|
365
|
-
merge_labels: bool = False,
|
|
366
|
-
is_openfoodfacts_dataset: bool = False,
|
|
367
|
-
openfoodfacts_flavor: Flavor = Flavor.off,
|
|
368
|
-
) -> None:
|
|
369
|
-
"""Export an Ultralytics classification dataset to a Hugging Face dataset.
|
|
370
|
-
|
|
371
|
-
The Ultralytics dataset directory should contain 'train', 'val' and/or
|
|
372
|
-
'test' subdirectories, each containing subdirectories for each label.
|
|
373
|
-
|
|
374
|
-
Args:
|
|
375
|
-
dataset_dir (Path): Path to the Ultralytics dataset directory.
|
|
376
|
-
repo_id (str): Hugging Face repository ID to push the dataset to.
|
|
377
|
-
label_names (list[str]): List of label names.
|
|
378
|
-
merge_labels (bool): Whether to merge all labels into a single label
|
|
379
|
-
named 'object'.
|
|
380
|
-
is_openfoodfacts_dataset (bool): Whether the dataset is from
|
|
381
|
-
Open Food Facts. If True, the `off_image_id` and `image_url` will
|
|
382
|
-
be generated automatically. `off_image_id` is extracted from the
|
|
383
|
-
image filename.
|
|
384
|
-
openfoodfacts_flavor (Flavor): Flavor of Open Food Facts dataset. This
|
|
385
|
-
is ignored if `is_openfoodfacts_dataset` is False.
|
|
386
|
-
"""
|
|
387
|
-
logger.info("Repo ID: %s, dataset_dir: %s", repo_id, dataset_dir)
|
|
388
|
-
|
|
389
|
-
if not any((dataset_dir / split).is_dir() for split in ["train", "val", "test"]):
|
|
390
|
-
raise ValueError(
|
|
391
|
-
f"Dataset directory {dataset_dir} does not contain 'train', 'val' or 'test' subdirectories"
|
|
392
|
-
)
|
|
393
|
-
|
|
394
|
-
# Save output as pickle
|
|
395
|
-
for split in ["train", "val", "test"]:
|
|
396
|
-
split_dir = dataset_dir / split
|
|
397
|
-
|
|
398
|
-
if not split_dir.is_dir():
|
|
399
|
-
logger.info("Skipping missing split directory: %s", split_dir)
|
|
400
|
-
continue
|
|
401
|
-
|
|
402
|
-
with tempfile.TemporaryDirectory() as tmp_dir_str:
|
|
403
|
-
tmp_dir = Path(tmp_dir_str)
|
|
404
|
-
for label_dir in (d for d in split_dir.iterdir() if d.is_dir()):
|
|
405
|
-
label_name = label_dir.name
|
|
406
|
-
if merge_labels:
|
|
407
|
-
label_name = "object"
|
|
408
|
-
if label_name not in label_names:
|
|
409
|
-
raise ValueError(
|
|
410
|
-
"Label name %s not in provided label names (label names: %s)"
|
|
411
|
-
% (label_name, label_names),
|
|
412
|
-
)
|
|
413
|
-
label_id = label_names.index(label_name)
|
|
414
|
-
|
|
415
|
-
for image_path in label_dir.glob("*"):
|
|
416
|
-
if is_openfoodfacts_dataset:
|
|
417
|
-
image_stem_parts = image_path.stem.split("_")
|
|
418
|
-
barcode = image_stem_parts[0]
|
|
419
|
-
off_image_id = image_stem_parts[1]
|
|
420
|
-
image_id = f"{barcode}_{off_image_id}"
|
|
421
|
-
image_url = generate_image_url(
|
|
422
|
-
barcode, off_image_id, flavor=openfoodfacts_flavor
|
|
423
|
-
)
|
|
424
|
-
else:
|
|
425
|
-
image_id = image_path.stem
|
|
426
|
-
barcode = ""
|
|
427
|
-
off_image_id = ""
|
|
428
|
-
image_url = ""
|
|
429
|
-
image = Image.open(image_path)
|
|
430
|
-
image.load()
|
|
431
|
-
|
|
432
|
-
if image.mode != "RGB":
|
|
433
|
-
image = image.convert("RGB")
|
|
434
|
-
|
|
435
|
-
# Rotate image according to exif orientation using Pillow
|
|
436
|
-
ImageOps.exif_transpose(image, in_place=True)
|
|
437
|
-
sample = {
|
|
438
|
-
"image_id": image_id,
|
|
439
|
-
"image": image,
|
|
440
|
-
"width": image.width,
|
|
441
|
-
"height": image.height,
|
|
442
|
-
"meta": {
|
|
443
|
-
"barcode": barcode,
|
|
444
|
-
"off_image_id": off_image_id,
|
|
445
|
-
"image_url": image_url,
|
|
446
|
-
},
|
|
447
|
-
"category_id": label_id,
|
|
448
|
-
"category_name": label_name,
|
|
449
|
-
}
|
|
450
|
-
with open(tmp_dir / f"{split}_{image_id}.pkl", "wb") as f:
|
|
451
|
-
pickle.dump(sample, f)
|
|
452
|
-
|
|
453
|
-
hf_ds = datasets.Dataset.from_generator(
|
|
454
|
-
functools.partial(_pickle_sample_generator, tmp_dir),
|
|
455
|
-
features=HF_DS_CLASSIFICATION_FEATURES,
|
|
456
|
-
)
|
|
457
|
-
hf_ds.push_to_hub(repo_id, split=split)
|
labelr/google_genai.py
ADDED
|
@@ -0,0 +1,421 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import mimetypes
|
|
3
|
+
from collections.abc import Iterator
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Literal
|
|
6
|
+
from urllib.parse import urlparse
|
|
7
|
+
|
|
8
|
+
import aiofiles
|
|
9
|
+
import jsonschema
|
|
10
|
+
import orjson
|
|
11
|
+
import typer
|
|
12
|
+
from gcloud.aio.storage import Storage
|
|
13
|
+
from openfoodfacts import Flavor
|
|
14
|
+
from openfoodfacts.images import generate_image_url
|
|
15
|
+
from tqdm.asyncio import tqdm
|
|
16
|
+
|
|
17
|
+
from labelr.sample.common import SampleMeta
|
|
18
|
+
from labelr.sample.llm import LLMImageExtractionSample
|
|
19
|
+
from labelr.utils import download_image_from_gcs
|
|
20
|
+
|
|
21
|
+
try:
|
|
22
|
+
import google.genai # noqa: F401
|
|
23
|
+
except ImportError:
|
|
24
|
+
raise ImportError(
|
|
25
|
+
"The 'google-genai' package is required to use this module. "
|
|
26
|
+
"Please install labelr with the 'google' extra: "
|
|
27
|
+
"`pip install labelr[google]`"
|
|
28
|
+
)
|
|
29
|
+
import aiohttp
|
|
30
|
+
from google import genai
|
|
31
|
+
from google.cloud import storage
|
|
32
|
+
from google.genai.types import CreateBatchJobConfig, HttpOptions
|
|
33
|
+
from google.genai.types import JSONSchema as GoogleJSONSchema
|
|
34
|
+
from google.genai.types import Schema as GoogleSchema
|
|
35
|
+
from openfoodfacts.types import JSONType
|
|
36
|
+
from pydantic import BaseModel
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class RawBatchSamplePart(BaseModel):
    """A single element of a raw batch sample: a text snippet or an image."""

    # "text": `data` is forwarded verbatim as a text part of the request.
    # "image": `data` is the image URL/path; the image is copied to GCS and
    # referenced from the request through a `file_data` part.
    type: Literal["text", "image"]
    data: str


class RawBatchSample(BaseModel):
    """One line of the raw JSONL file consumed by `generate_batch_dataset`."""

    # Sample identifier; reused in the GCS blob names of its images and in the
    # request key (`key:<key>`).
    key: str
    # Ordered parts that make up the user message sent to the model.
    parts: list[RawBatchSamplePart]
    # Free-form metadata (not read by this module).
    meta: JSONType = {}
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def convert_pydantic_model_to_google_schema(schema: type[BaseModel]) -> JSONType:
    """Convert a Pydantic model into a Google `Schema` dict.

    Google does not natively accept arbitrary OpenAPI/JSON schemas, so the
    model's JSON schema is converted to Google's `Schema` type (a subset of
    OpenAPI) and serialized back to plain JSON.

    Args:
        schema: Pydantic model class whose JSON schema is converted.

    Returns:
        JSONType: the Google schema as a JSON-compatible dict, without unset
        or None fields.
    """
    pydantic_json_schema = schema.model_json_schema()
    google_json_schema = GoogleJSONSchema.model_validate(pydantic_json_schema)
    google_schema = GoogleSchema.from_json_schema(json_schema=google_json_schema)
    return google_schema.model_dump(mode="json", exclude_none=True, exclude_unset=True)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
async def download_image(url: str, session: aiohttp.ClientSession) -> bytes:
    """Fetch `url` through `session` and return the raw response body.

    Args:
        url (str): URL of the image to download.
        session (aiohttp.ClientSession): HTTP session used for the request.

    Returns:
        bytes: Content of the downloaded image.

    Raises:
        aiohttp.ClientResponseError: if the server answers with an error
            status (raised by `raise_for_status`).
    """
    async with session.get(url) as resp:
        resp.raise_for_status()
        payload = await resp.read()
    return payload
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
async def download_image_from_filesystem(url: str, base_dir: Path) -> bytes:
    """Read an image from the local filesystem and return its content as bytes.

    The path component of `url` is resolved relative to `base_dir`, e.g.
    `https://images.example.org/a/b.jpg` with `base_dir=/data` reads
    `/data/a/b.jpg`. A bare relative path such as `a/b.jpg` is also accepted.

    Args:
        url (str): URL (or relative path) of the image to read.
        base_dir (Path): Base directory where images are stored.

    Returns:
        bytes: Content of the image file.

    Raises:
        FileNotFoundError: if the resolved file does not exist.
    """
    # Strip every leading slash instead of blindly dropping the first
    # character: a relative path ("a/b.jpg") must keep its first letter, and a
    # path like "//etc/x" must not turn into an absolute path outside base_dir.
    relative_path = urlparse(url).path.lstrip("/")
    full_file_path = base_dir / relative_path
    async with aiofiles.open(full_file_path, "rb") as f:
        return await f.read()
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
async def upload_to_gcs(
    image_url: str,
    bucket_name: str,
    blob_name: str,
    session: aiohttp.ClientSession,
    base_image_dir: Path | None = None,
) -> dict:
    """Fetch an image and upload its bytes to Google Cloud Storage.

    Args:
        image_url (str): URL of the image. It is downloaded over HTTP, or, when
            `base_image_dir` is set, its path is read from that directory.
        bucket_name (str): Name of the GCS bucket.
        blob_name (str): Name of the blob (object) in the bucket.
        session (aiohttp.ClientSession): HTTP session used both to download
            the image and to perform the GCS upload.
        base_image_dir (Path | None): If provided, images are read from the
            filesystem under this base directory instead of being downloaded
            from their URLs.

    Returns:
        dict: Status returned by the GCS upload call.

    Raises:
        FileNotFoundError: when reading from `base_image_dir` and the image
            file does not exist.
        aiohttp.ClientResponseError: when the HTTP download returns an error
            status.
    """
    if base_image_dir is None:
        image_data = await download_image(image_url, session)
    else:
        image_data = await download_image_from_filesystem(image_url, base_image_dir)

    client = Storage(session=session)
    status = await client.upload(
        bucket_name,
        blob_name,
        image_data,
    )
    return status
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
async def upload_to_gcs_format_async(
    sample: RawBatchSample,
    google_json_schema: JSONType,
    instructions: str | None,
    bucket_name: str,
    bucket_dir_name: str,
    session: aiohttp.ClientSession,
    base_image_dir: Path | None = None,
    skip_upload: bool = False,
    thinking_level: str | None = None,
) -> JSONType | None:
    """Turn a raw sample into a Gemini batch request, uploading its images.

    Args:
        sample: Raw sample whose parts become the request parts, in order.
        google_json_schema: Schema sent as `response_json_schema`.
        instructions: Optional text inserted as the first request part.
        bucket_name: GCS bucket receiving the images.
        bucket_dir_name: Blob prefix; images are stored as
            `<bucket_dir_name>/<sample.key>/<file name>`.
        session: HTTP session used to download and upload images.
        base_image_dir: If set, images are read from this directory instead
            of being downloaded.
        skip_upload: If True, no upload is performed; the request simply
            references the expected `gs://` URIs.
        thinking_level: Optional value for `thinkingConfig.thinkingLevel`.

    Returns:
        JSONType | None: The `{"key": ..., "request": ...}` record, or None
        when the upload raised FileNotFoundError (e.g. a missing file under
        `base_image_dir`).

    Raises:
        ValueError: If the mime type of an image cannot be guessed.
    """
    request_parts: list[JSONType] = [{"text": instructions}] if instructions else []

    for raw_part in sample.parts:
        if raw_part.type != "image":
            request_parts.append({"text": raw_part.data})
            continue

        source_uri = raw_part.data
        mime_type = mimetypes.guess_type(source_uri)[0]
        if mime_type is None:
            raise ValueError(f"Cannot guess mimetype of file: {raw_part.data}")

        blob_name = f"{bucket_dir_name}/{sample.key}/{Path(source_uri).name}"
        if not skip_upload:
            # Fetch the image (HTTP or local file) and copy it to the bucket.
            try:
                await upload_to_gcs(
                    image_url=source_uri,
                    bucket_name=bucket_name,
                    blob_name=blob_name,
                    session=session,
                    base_image_dir=base_image_dir,
                )
            except FileNotFoundError:
                return None

        gcs_file = {
            "file_uri": f"gs://{bucket_name}/{blob_name}",
            "mime_type": mime_type,
        }
        request_parts.append({"file_data": gcs_file})

    # Key order matters: the record is serialized as-is to JSONL.
    generation_config = {
        "responseMimeType": "application/json",
        "response_json_schema": google_json_schema,
    }
    if thinking_level is not None:
        generation_config["thinkingConfig"] = {"thinkingLevel": thinking_level}

    user_content = {"parts": request_parts, "role": "user"}
    return {
        "key": f"key:{sample.key}",
        "request": {
            "contents": [user_content],
            "generationConfig": generation_config,
        },
    }
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
async def _write_batch_results(tasks: list[asyncio.Task], output_file) -> int:
    """Await `tasks` in submission order and append their records as JSONL.

    Args:
        tasks: Tasks running `upload_to_gcs_format_async`, in input order.
        output_file: Binary file opened with `aiofiles`; each non-None record
            is written as one JSON object per line.

    Returns:
        int: Number of tasks that returned None (samples with a missing image).
    """
    missing = 0
    for task in tasks:
        record = await task
        if record is None:
            missing += 1
        else:
            await output_file.write(orjson.dumps(record) + b"\n")
    return missing


async def generate_batch_dataset(
    data_path: Path,
    output_path: Path,
    google_json_schema: JSONType,
    instructions: str | None,
    bucket_name: str,
    bucket_dir_name: str,
    max_concurrent_uploads: int = 30,
    base_image_dir: Path | None = None,
    from_key: str | None = None,
    skip_upload: bool = False,
    thinking_level: str | None = None,
):
    """Build a Gemini batch-inference JSONL file from a raw JSONL dataset.

    Each non-blank input line is parsed as a `RawBatchSample`. Its images are
    uploaded to GCS (unless `skip_upload`) and the resulting request record is
    written to `output_path`, in the same order as the input.

    Args:
        data_path: Input JSONL file of `RawBatchSample` records.
        output_path: Output JSONL file of batch request records (overwritten).
        google_json_schema: Response schema added to every request.
        instructions: Optional text prepended to every request.
        bucket_name: GCS bucket receiving the images.
        bucket_dir_name: Blob prefix for the uploaded images.
        max_concurrent_uploads: Samples are processed in batches of this
            size; a batch is fully awaited before more lines are read.
        base_image_dir: If set, images are read from this directory instead
            of being downloaded.
        from_key: If set, input lines are ignored until the sample with this
            key is reached (that sample is included), e.g. to resume a run.
        skip_upload: If True, images are assumed to be already in GCS.
        thinking_level: Optional `thinkingLevel` added to every request.
    """
    # True while we are still looking for `from_key`.
    skipping = from_key is not None
    missing_files = 0
    pending: list[asyncio.Task] = []
    async with aiohttp.ClientSession() as session:
        async with asyncio.TaskGroup() as tg:
            async with (
                aiofiles.open(data_path, "r") as input_file,
                aiofiles.open(output_path, "wb") as output_file,
            ):
                async for line in tqdm(input_file, desc="samples"):
                    if not line.strip():
                        continue
                    sample = RawBatchSample.model_validate_json(line)
                    if skipping:
                        if sample.key != from_key:
                            continue
                        skipping = False
                    pending.append(
                        tg.create_task(
                            upload_to_gcs_format_async(
                                sample=sample,
                                google_json_schema=google_json_schema,
                                instructions=instructions,
                                bucket_name=bucket_name,
                                bucket_dir_name=bucket_dir_name,
                                session=session,
                                base_image_dir=base_image_dir,
                                skip_upload=skip_upload,
                                thinking_level=thinking_level,
                            )
                        )
                    )
                    # Bound concurrency: drain the whole batch (in input order)
                    # before scheduling more uploads.
                    if len(pending) >= max_concurrent_uploads:
                        missing_files += await _write_batch_results(
                            pending, output_file
                        )
                        pending = []

                missing_files += await _write_batch_results(pending, output_file)

    if skipping:
        typer.echo(
            f"Warning: key {from_key!r} was not found in {data_path}, "
            "no sample was processed."
        )
    typer.echo(
        f"Upload and dataset update completed. Wrote updated dataset to {output_path}. "
        f"Missing files: {missing_files}."
    )
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
def launch_batch_job(
    run_name: str,
    dataset_path: Path,
    model: str,
    location: str,
    bucket_name: str = "robotoff-batch",
):
    """Launch a Gemini Batch Inference job on Vertex AI.

    The JSONL dataset is uploaded to
    `gs://<bucket_name>/gemini-batch/<run_name>/inputs.jsonl`, and the job
    writes its predictions under `gs://<bucket_name>/gemini-batch/<run_name>`.

    Args:
        run_name (str): Name of the batch run.
        dataset_path (Path): Path to the dataset file in JSONL format.
        model (str): Model to use for the batch job. Example:
            'gemini-2.5-flash'.
        location (str): Location for the Vertex AI resources. Example:
            'europe-west4'. Forced to 'global' for 'gemini-3-pro-preview'.
        bucket_name (str): GCS bucket holding the job input and outputs.
            Defaults to 'robotoff-batch'.
    """
    if model == "gemini-3-pro-preview" and location != "global":
        typer.echo(
            "Warning: only 'global' location is supported for 'gemini-3-pro-preview' model. Overriding location to 'global'."
        )
        location = "global"

    # The batch job reads its input from GCS, so upload the dataset first.
    storage_client = storage.Client()
    run_dir = f"gemini-batch/{run_name}"
    input_file_blob_name = f"{run_dir}/inputs.jsonl"
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(input_file_blob_name)
    blob.upload_from_filename(dataset_path)

    client = genai.Client(
        http_options=HttpOptions(api_version="v1"),
        vertexai=True,
        location=location,
    )
    output_uri = f"gs://{bucket_name}/{run_dir}"
    job = client.batches.create(
        model=model,
        src=f"gs://{bucket_name}/{input_file_blob_name}",
        config=CreateBatchJobConfig(dest=output_uri),
    )
    print(job)
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def generate_sample_iter(
    prediction_path: Path,
    json_schema: JSONType,
    skip: int = 0,
    limit: int | None = None,
    is_openfoodfacts_dataset: bool = False,
    openfoodfacts_flavor: Flavor = Flavor.off,
    raise_on_invalid_sample: bool = False,
) -> Iterator[LLMImageExtractionSample]:
    """Generate training samples from a Gemini Batch Inference prediction
    JSONL file.

    Args:
        prediction_path (Path): Path to the prediction JSONL file.
        json_schema (JSONType): JSON schema to validate the predictions.
        skip (int): Number of initial lines to skip.
        limit (int | None): Maximum number of lines to process after `skip`
            (invalid lines count towards the limit).
        is_openfoodfacts_dataset (bool): Whether the dataset is from Open Food
            Facts.
        openfoodfacts_flavor (Flavor): Flavor of the Open Food Facts dataset.
        raise_on_invalid_sample (bool): If True, re-raise the first error
            met while parsing or converting a line; otherwise the line is
            reported, counted and skipped.
    Yields:
        Iterator[LLMImageExtractionSample]: Generated samples.
    """
    skipped = 0
    invalid = 0
    storage_client = storage.Client()
    with prediction_path.open("r") as f_in:
        for i, sample_str in enumerate(f_in):
            if i < skip:
                skipped += 1
                continue
            if limit is not None and i >= skip + limit:
                break
            try:
                # JSON parsing is inside the `try` so a malformed line is
                # handled like any other invalid sample.
                sample = orjson.loads(sample_str)
                generated = generate_sample_from_prediction(
                    json_schema=json_schema,
                    sample=sample,
                    is_openfoodfacts_dataset=is_openfoodfacts_dataset,
                    openfoodfacts_flavor=openfoodfacts_flavor,
                    storage_client=storage_client,
                )
            except Exception as e:
                if raise_on_invalid_sample:
                    raise
                typer.echo(
                    f"Skipping invalid sample at line {i + 1} in {prediction_path}: {e}"
                )
                invalid += 1
                continue
            # Yield outside the `try`: errors thrown into the generator by its
            # consumer must not be reported as invalid samples.
            yield generated
    if skipped > 0:
        typer.echo(f"Skipped {skipped} samples.")
    if invalid > 0:
        typer.echo(f"Skipped {invalid} invalid samples.")
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
def generate_sample_from_prediction(
    json_schema: JSONType,
    sample: JSONType,
    is_openfoodfacts_dataset: bool = False,
    openfoodfacts_flavor: Flavor = Flavor.off,
    storage_client: storage.Client | None = None,
) -> LLMImageExtractionSample:
    """Generate a LLMImageExtractionSample from a prediction sample.

    Args:
        json_schema (JSONType): JSON schema to validate the predictions.
        sample (JSONType): Prediction sample (one line of the batch output).
        is_openfoodfacts_dataset (bool): Whether the dataset is from Open Food
            Facts. If True, the sample key is expected to be
            `<barcode>_<image id>`.
        openfoodfacts_flavor (Flavor): Flavor of the Open Food Facts dataset.
        storage_client (storage.Client | None): Optional Google Cloud Storage
            client, forwarded to `download_image_from_gcs`.
    Returns:
        LLMImageExtractionSample: Generated sample.
    Raises:
        ValueError: If the request contains no `file_data` (image) part.
        jsonschema.ValidationError: If the response does not match
            `json_schema`.
    """
    # Request keys are written as `key:<sample key>` by
    # `upload_to_gcs_format_async`.
    image_id = sample["key"].removeprefix("key:")
    response_str = sample["response"]["candidates"][0]["content"]["parts"][0]["text"]
    # Validate the model output before paying for the image download.
    response = orjson.loads(response_str)
    jsonschema.validate(response, json_schema)

    # The image is not always the second part: it comes first when no
    # instructions were given, and text parts may precede it. Use the first
    # `file_data` part instead of a fixed index.
    request_parts = sample["request"]["contents"][0]["parts"]
    image_uri = next(
        (part["file_data"]["file_uri"] for part in request_parts if "file_data" in part),
        None,
    )
    if image_uri is None:
        raise ValueError(f"No image (file_data) part found in request of {image_id}")
    image = download_image_from_gcs(image_uri=image_uri, client=storage_client)

    if is_openfoodfacts_dataset:
        image_stem_parts = image_id.split("_")
        barcode = image_stem_parts[0]
        off_image_id = image_stem_parts[1]
        image_id = f"{barcode}_{off_image_id}"
        image_url = generate_image_url(
            barcode, off_image_id, flavor=openfoodfacts_flavor
        )
    else:
        barcode = ""
        off_image_id = ""
        image_url = ""

    sample_meta = SampleMeta(
        barcode=barcode,
        off_image_id=off_image_id,
        image_url=image_url,
    )
    return LLMImageExtractionSample(
        image_id=image_id,
        image=image,
        output=orjson.dumps(response).decode("utf-8"),
        meta=sample_meta,
    )
|
labelr/main.py
CHANGED
|
@@ -5,6 +5,7 @@ from openfoodfacts.utils import get_logger
|
|
|
5
5
|
|
|
6
6
|
from labelr.apps import datasets as dataset_app
|
|
7
7
|
from labelr.apps import evaluate as evaluate_app
|
|
8
|
+
from labelr.apps import google_batch as google_batch_app
|
|
8
9
|
from labelr.apps import hugging_face as hf_app
|
|
9
10
|
from labelr.apps import label_studio as ls_app
|
|
10
11
|
from labelr.apps import train as train_app
|
|
@@ -84,6 +85,11 @@ app.add_typer(
|
|
|
84
85
|
name="evaluate",
|
|
85
86
|
help="Visualize and evaluate trained models.",
|
|
86
87
|
)
|
|
88
|
+
# Expose the Gemini batch sub-commands under the `google-batch` command name.
app.add_typer(
    google_batch_app.app,
    name="google-batch",
    help="Generate datasets and launch batch jobs on Google Gemini.",
)
|
|
87
93
|
|
|
88
94
|
if __name__ == "__main__":
|
|
89
95
|
app()
|
|
File without changes
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
import datasets
|
|
2
|
+
|
|
3
|
+
# Hugging Face `datasets` schema of an image-classification sample: the image
# and its size, provenance metadata, and the target class given both as an
# integer id and as a name.
HF_DS_CLASSIFICATION_FEATURES = datasets.Features(
    {
        "image_id": datasets.Value("string"),
        "image": datasets.features.Image(),
        "width": datasets.Value("int64"),
        "height": datasets.Value("int64"),
        # Provenance: product barcode, Open Food Facts image id and image URL.
        "meta": {
            "barcode": datasets.Value("string"),
            "off_image_id": datasets.Value("string"),
            "image_url": datasets.Value("string"),
        },
        "category_id": datasets.Value("int64"),
        "category_name": datasets.Value("string"),
    }
)
|
labelr/sample/common.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from pydantic import BaseModel, Field
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class SampleMeta(BaseModel):
    """Provenance metadata attached to a generated sample.

    Every field must be passed explicitly (no default) but may be None.
    """

    barcode: str | None = Field(
        default=...,
        description="The barcode of the product, if applicable",
    )
    off_image_id: str | None = Field(
        default=...,
        description="The Open Food Facts image ID associated with the image, if applicable",
    )
    image_url: str | None = Field(
        default=...,
        description="The URL of the image, if applicable",
    )
|