fleurs-ds 0.0.1__tar.gz → 0.1.0__tar.gz
This diff shows the content of publicly released package versions as they appear in their public registries. It is provided for informational purposes only and reflects the changes between the two versions.
- {fleurs_ds-0.0.1 → fleurs_ds-0.1.0}/PKG-INFO +1 -1
- fleurs_ds-0.1.0/fleurs_ds/__init__.py +17 -0
- fleurs_ds-0.1.0/fleurs_ds/_get_dataset.py +356 -0
- fleurs_ds-0.1.0/fleurs_ds/types/__init__.py +0 -0
- fleurs_ds-0.1.0/fleurs_ds/types/features.py +9 -0
- {fleurs_ds-0.0.1 → fleurs_ds-0.1.0}/pyproject.toml +1 -1
- fleurs_ds-0.0.1/fleurs_ds/__init__.py +0 -1
- {fleurs_ds-0.0.1 → fleurs_ds-0.1.0}/LICENSE +0 -0
- {fleurs_ds-0.0.1 → fleurs_ds-0.1.0}/README.md +0 -0
fleurs_ds-0.1.0/fleurs_ds/__init__.py (new file, +17)

@@ -0,0 +1,17 @@
+import os
+from pathlib import Path
+from typing import Final
+
+from ._get_dataset import ALL_LANGUAGES, LANGUAGE_TYPES, get_dataset
+
+__version__ = "0.1.0"
+
+FLEURS_DATASETS_CACHE: Final[Path] = Path(
+    os.getenv("FLEURS_DATASETS_CACHE", "~/.cache/huggingface/datasets/fleurs")
+).expanduser()
+
+__all__ = [
+    "ALL_LANGUAGES",
+    "get_dataset",
+    "LANGUAGE_TYPES",
+]
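For orientation, a minimal usage sketch of the new public API. This is not from the package itself: `en_us` and the cache directory are example values, and the env-var override works because `FLEURS_DATASETS_CACHE` is evaluated at import time (see the hunk above).

```python
# Minimal usage sketch; assumes fleurs-ds and its `datasets` dependency
# are installed. "en_us" and "/tmp/fleurs-cache" are example values.
import os

# FLEURS_DATASETS_CACHE is read at import time, so any override must be
# set before the package is imported.
os.environ["FLEURS_DATASETS_CACHE"] = "/tmp/fleurs-cache"

from fleurs_ds import ALL_LANGUAGES, get_dataset

print(len(ALL_LANGUAGES))  # number of supported language codes (102)

# Builds (or loads from the on-disk cache) all three splits for the
# language, then returns the requested one.
ds = get_dataset(language="en_us", split="train")
print(ds.column_names)  # ["audio", "text", "language"]
```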
fleurs_ds-0.1.0/fleurs_ds/_get_dataset.py (new file, +356)

@@ -0,0 +1,356 @@
+import logging
+from pathlib import Path
+from typing import TYPE_CHECKING, Literal, TypeAlias
+
+if TYPE_CHECKING:
+    from datasets import Dataset
+
+
+logger = logging.getLogger(__name__)
+
+HF_DATASETS_PARQUET_PATTERN = (
+    "hf://datasets/google/fleurs@refs/convert/parquet/{lang}/{split}/*.parquet"
+)
+
+
+def get_dataset(
+    name: Literal["google/fleurs"] = "google/fleurs",
+    *,
+    language: "LANGUAGE_TYPES",
+    split: Literal["train", "dev", "test"],
+) -> "Dataset":
+    from datasets import DatasetDict
+
+    from fleurs_ds import FLEURS_DATASETS_CACHE
+    from fleurs_ds.types.features import features
+
+    if language not in ALL_LANGUAGES:
+        raise ValueError(
+            f"Invalid language: {language}, must be one of {ALL_LANGUAGES}"
+        )
+    if split not in ["train", "dev", "test"]:
+        raise ValueError(
+            f"Invalid split: {split}, must be one of ['train', 'dev', 'test']"
+        )
+
+    cache_path = FLEURS_DATASETS_CACHE.joinpath(language)
+
+    if cache_path.exists() and cache_path.is_dir() and any(cache_path.glob("*.json")):
+        return DatasetDict.load_from_disk(str(cache_path))[split]
+
+    train_dataset = _get_dataset_of_language_and_split(language, "train")
+    dev_dataset = _get_dataset_of_language_and_split(language, "dev")
+    test_dataset = _get_dataset_of_language_and_split(language, "test")
+
+    dataset_dict = DatasetDict(
+        {
+            "train": train_dataset.cast(features),
+            "dev": dev_dataset.cast(features),
+            "test": test_dataset.cast(features),
+        }
+    )
+
+    dataset_dict.save_to_disk(str(cache_path))
+    return DatasetDict.load_from_disk(str(cache_path))[split]
+
+
+def _get_dataset_of_language_and_split(
+    language: "LANGUAGE_TYPES",
+    split: Literal["train", "dev", "test"],
+) -> "Dataset":
+    from datasets import Audio, load_dataset
+
+    valid_parquet_split: Literal["train", "validation", "test"] = (
+        "validation" if split == "dev" else split
+    )
+
+    # Construct data files path
+    data_files: dict[str, str] = {
+        split: HF_DATASETS_PARQUET_PATTERN.format(
+            lang=language, split=valid_parquet_split
+        )
+    }
+
+    # Load dataset
+    ds = load_dataset("parquet", data_files=data_files, split=split, streaming=False)
+
+    # Rename transcription to text
+    ds = ds.select_columns(["audio", "transcription"]).rename_column(
+        "transcription", "text"
+    )
+
+    # Cast audio column to Audio
+    ds = ds.cast_column("audio", Audio(sampling_rate=16000))
+
+    ds = ds.add_column("language", [language] * len(ds))
+
+    # Select final columns
+    ds = ds.select_columns(["audio", "text", "language"])
+
+    # Cast to mp3 128k
+    logger.info(f"Starting conversion for {language}:{split}...")
+    ds = ds.map(
+        _convert_audio_to_mp3_128k,
+        num_proc=4,
+        keep_in_memory=True,
+        desc=f"Converting audio to mp3 128k for {language}:{split}",
+    )
+
+    original_count = len(ds)
+    ds = ds.filter(
+        lambda x: x["is_valid"],
+        desc=f"Filtering corrupted samples for {language}:{split}",
+    )
+    filtered_count = len(ds)
+    if filtered_count < original_count:
+        logger.warning(
+            f"Dropped {original_count - filtered_count} corrupted samples "
+            + f"from {language}:{split}"
+        )
+    ds = ds.remove_columns(["is_valid"])

+    return ds
+
+
+def _convert_audio_to_mp3_128k(sample: dict) -> dict:
+    import io
+    import tempfile
+
+    import soundfile as sf
+    from pydub import AudioSegment
+
+    error_result = {"audio": None, "is_valid": False}
+
+    with tempfile.TemporaryDirectory() as _temp_dir:
+        temp_dir = Path(_temp_dir)
+        wav_path = temp_dir / "audio.wav"
+
+        try:
+            audio_array = sample["audio"]["array"]
+        except Exception as e:
+            logger.exception(e)
+            logger.error(f"Error getting audio array for sample: {sample}")
+            return error_result
+
+        sampling_rate = sample["audio"]["sampling_rate"]
+
+        sf.write(wav_path, audio_array, sampling_rate)
+        # Audio processing logic (same as original)
+        audio_seg: AudioSegment = AudioSegment.from_file(wav_path)
+        audio_seg = audio_seg.set_channels(1).set_frame_rate(16000)
+        mp3_io = io.BytesIO()
+        audio_seg.export(mp3_io, format="mp3", bitrate="128k")
+        audio_bytes = mp3_io.getvalue()
+
+        # Return the processed audio bytes
+        return {"audio": audio_bytes, "is_valid": True}
+
+
+LANGUAGE_TYPES: TypeAlias = Literal[
+    "af_za",
+    "am_et",
+    "ar_eg",
+    "as_in",
+    "ast_es",
+    "az_az",
+    "be_by",
+    "bg_bg",
+    "bn_in",
+    "bs_ba",
+    "ca_es",
+    "ceb_ph",
+    "ckb_iq",
+    "cmn_hans_cn",
+    "cs_cz",
+    "cy_gb",
+    "da_dk",
+    "de_de",
+    "el_gr",
+    "en_us",
+    "es_419",
+    "et_ee",
+    "fa_ir",
+    "ff_sn",
+    "fi_fi",
+    "fil_ph",
+    "fr_fr",
+    "ga_ie",
+    "gl_es",
+    "gu_in",
+    "ha_ng",
+    "he_il",
+    "hi_in",
+    "hr_hr",
+    "hu_hu",
+    "hy_am",
+    "id_id",
+    "ig_ng",
+    "is_is",
+    "it_it",
+    "ja_jp",
+    "jv_id",
+    "ka_ge",
+    "kam_ke",
+    "kea_cv",
+    "kk_kz",
+    "km_kh",
+    "kn_in",
+    "ko_kr",
+    "ky_kg",
+    "lb_lu",
+    "lg_ug",
+    "ln_cd",
+    "lo_la",
+    "lt_lt",
+    "luo_ke",
+    "lv_lv",
+    "mi_nz",
+    "mk_mk",
+    "ml_in",
+    "mn_mn",
+    "mr_in",
+    "ms_my",
+    "mt_mt",
+    "my_mm",
+    "nb_no",
+    "ne_np",
+    "nl_nl",
+    "nso_za",
+    "ny_mw",
+    "oc_fr",
+    "om_et",
+    "or_in",
+    "pa_in",
+    "pl_pl",
+    "ps_af",
+    "pt_br",
+    "ro_ro",
+    "ru_ru",
+    "sd_in",
+    "sk_sk",
+    "sl_si",
+    "sn_zw",
+    "so_so",
+    "sr_rs",
+    "sv_se",
+    "sw_ke",
+    "ta_in",
+    "te_in",
+    "tg_tj",
+    "th_th",
+    "tr_tr",
+    "uk_ua",
+    "umb_ao",
+    "ur_pk",
+    "uz_uz",
+    "vi_vn",
+    "wo_sn",
+    "xh_za",
+    "yo_ng",
+    "yue_hant_hk",
+    "zu_za",
+]
+ALL_LANGUAGES = (
+    "af_za",
+    "am_et",
+    "ar_eg",
+    "as_in",
+    "ast_es",
+    "az_az",
+    "be_by",
+    "bg_bg",
+    "bn_in",
+    "bs_ba",
+    "ca_es",
+    "ceb_ph",
+    "ckb_iq",
+    "cmn_hans_cn",
+    "cs_cz",
+    "cy_gb",
+    "da_dk",
+    "de_de",
+    "el_gr",
+    "en_us",
+    "es_419",
+    "et_ee",
+    "fa_ir",
+    "ff_sn",
+    "fi_fi",
+    "fil_ph",
+    "fr_fr",
+    "ga_ie",
+    "gl_es",
+    "gu_in",
+    "ha_ng",
+    "he_il",
+    "hi_in",
+    "hr_hr",
+    "hu_hu",
+    "hy_am",
+    "id_id",
+    "ig_ng",
+    "is_is",
+    "it_it",
+    "ja_jp",
+    "jv_id",
+    "ka_ge",
+    "kam_ke",
+    "kea_cv",
+    "kk_kz",
+    "km_kh",
+    "kn_in",
+    "ko_kr",
+    "ky_kg",
+    "lb_lu",
+    "lg_ug",
+    "ln_cd",
+    "lo_la",
+    "lt_lt",
+    "luo_ke",
+    "lv_lv",
+    "mi_nz",
+    "mk_mk",
+    "ml_in",
+    "mn_mn",
+    "mr_in",
+    "ms_my",
+    "mt_mt",
+    "my_mm",
+    "nb_no",
+    "ne_np",
+    "nl_nl",
+    "nso_za",
+    "ny_mw",
+    "oc_fr",
+    "om_et",
+    "or_in",
+    "pa_in",
+    "pl_pl",
+    "ps_af",
+    "pt_br",
+    "ro_ro",
+    "ru_ru",
+    "sd_in",
+    "sk_sk",
+    "sl_si",
+    "sn_zw",
+    "so_so",
+    "sr_rs",
+    "sv_se",
+    "sw_ke",
+    "ta_in",
+    "te_in",
+    "tg_tj",
+    "th_th",
+    "tr_tr",
+    "uk_ua",
+    "umb_ao",
+    "ur_pk",
+    "uz_uz",
+    "vi_vn",
+    "wo_sn",
+    "xh_za",
+    "yo_ng",
+    "yue_hant_hk",
+    "zu_za",
+)
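One detail worth calling out in `_get_dataset_of_language_and_split` above: the package's `dev` split is mapped onto the Hub's `validation` parquet directory before the glob is formatted. A standalone illustration of the resulting paths (`hi_in` is an arbitrary example language, not anything special in the package):

```python
# Illustrates the globs handed to load_dataset("parquet", ...) above,
# including the dev -> validation directory mapping.
PATTERN = "hf://datasets/google/fleurs@refs/convert/parquet/{lang}/{split}/*.parquet"

for split in ("train", "dev", "test"):
    parquet_split = "validation" if split == "dev" else split
    print(f"{split:>5} -> {PATTERN.format(lang='hi_in', split=parquet_split)}")
```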
fleurs_ds-0.1.0/fleurs_ds/types/__init__.py: file without changes (new empty file, +0 -0)
fleurs_ds-0.0.1/fleurs_ds/__init__.py (removed, -1)

@@ -1 +0,0 @@
-__version__ = "0.0.1"
{fleurs_ds-0.0.1 → fleurs_ds-0.1.0}/LICENSE: file without changes

{fleurs_ds-0.0.1 → fleurs_ds-0.1.0}/README.md: file without changes