fleurs-ds 0.0.1__tar.gz → 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: fleurs-ds
- Version: 0.0.1
+ Version: 0.1.0
  Summary: Google Fleurs As Huggingface Dataset
  License: MIT
  License-File: LICENSE
@@ -0,0 +1,17 @@
+ import os
+ from pathlib import Path
+ from typing import Final
+
+ from ._get_dataset import ALL_LANGUAGES, LANGUAGE_TYPES, get_dataset
+
+ __version__ = "0.1.0"
+
+ FLEURS_DATASETS_CACHE: Final[Path] = Path(
+     os.getenv("FLEURS_DATASETS_CACHE", "~/.cache/huggingface/datasets/fleurs")
+ ).expanduser()
+
+ __all__ = [
+     "ALL_LANGUAGES",
+     "get_dataset",
+     "LANGUAGE_TYPES",
+ ]
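
The new module above resolves the cache directory once, at import time, from the FLEURS_DATASETS_CACHE environment variable (defaulting to ~/.cache/huggingface/datasets/fleurs, expanded with Path.expanduser). A minimal sketch of redirecting the cache; note the override has to happen before the first import of the package, and the path used here is purely illustrative:

    import os

    # Set before importing fleurs_ds: the module reads the variable at import time.
    os.environ["FLEURS_DATASETS_CACHE"] = "/data/fleurs-cache"  # illustrative path

    import fleurs_ds

    print(fleurs_ds.FLEURS_DATASETS_CACHE)  # PosixPath('/data/fleurs-cache')
    print(fleurs_ds.__version__)            # '0.1.0'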
@@ -0,0 +1,356 @@
+ import logging
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Literal, TypeAlias
+
+ if TYPE_CHECKING:
+     from datasets import Dataset
+
+
+ logger = logging.getLogger(__name__)
+
+ HF_DATASETS_PARQUET_PATTERN = (
+     "hf://datasets/google/fleurs@refs/convert/parquet/{lang}/{split}/*.parquet"
+ )
+
+
+ def get_dataset(
+     name: Literal["google/fleurs"] = "google/fleurs",
+     *,
+     language: "LANGUAGE_TYPES",
+     split: Literal["train", "dev", "test"],
+ ) -> "Dataset":
+     from datasets import DatasetDict
+
+     from fleurs_ds import FLEURS_DATASETS_CACHE
+     from fleurs_ds.types.features import features
+
+     if language not in ALL_LANGUAGES:
+         raise ValueError(
+             f"Invalid language: {language}, must be one of {ALL_LANGUAGES}"
+         )
+     if split not in ["train", "dev", "test"]:
+         raise ValueError(
+             f"Invalid split: {split}, must be one of ['train', 'dev', 'test']"
+         )
+
+     cache_path = FLEURS_DATASETS_CACHE.joinpath(language)
+
+     if cache_path.exists() and cache_path.is_dir() and any(cache_path.glob("*.json")):
+         return DatasetDict.load_from_disk(str(cache_path))[split]
+
+     train_dataset = _get_dataset_of_language_and_split(language, "train")
+     dev_dataset = _get_dataset_of_language_and_split(language, "dev")
+     test_dataset = _get_dataset_of_language_and_split(language, "test")
+
+     dataset_dict = DatasetDict(
+         {
+             "train": train_dataset.cast(features),
+             "dev": dev_dataset.cast(features),
+             "test": test_dataset.cast(features),
+         }
+     )
+
+     dataset_dict.save_to_disk(str(cache_path))
+     return DatasetDict.load_from_disk(str(cache_path))[split]
+
+
+ def _get_dataset_of_language_and_split(
+     language: "LANGUAGE_TYPES",
+     split: Literal["train", "dev", "test"],
+ ) -> "Dataset":
+     from datasets import Audio, load_dataset
+
+     valid_parquet_split: Literal["train", "validation", "test"] = (
+         "validation" if split == "dev" else split
+     )
+
+     # Construct data files path
+     data_files: dict[str, str] = {
+         split: HF_DATASETS_PARQUET_PATTERN.format(
+             lang=language, split=valid_parquet_split
+         )
+     }
+
+     # Load dataset
+     ds = load_dataset("parquet", data_files=data_files, split=split, streaming=False)
+
+     # Rename transcription to text
+     ds = ds.select_columns(["audio", "transcription"]).rename_column(
+         "transcription", "text"
+     )
+
+     # Cast audio column to Audio
+     ds = ds.cast_column("audio", Audio(sampling_rate=16000))
+
+     ds = ds.add_column("language", [language] * len(ds))
+
+     # Select final columns
+     ds = ds.select_columns(["audio", "text", "language"])
+
+     # Cast to mp3 128k
+     logger.info(f"Starting conversion for {language}:{split}...")
+     ds = ds.map(
+         _convert_audio_to_mp3_128k,
+         num_proc=4,
+         keep_in_memory=True,
+         desc=f"Converting audio to mp3 128k for {language}:{split}",
+     )
+
+     original_count = len(ds)
+     ds = ds.filter(
+         lambda x: x["is_valid"],
+         desc=f"Filtering corrupted samples for {language}:{split}",
+     )
+     filtered_count = len(ds)
+     if filtered_count < original_count:
+         logger.warning(
+             f"Dropped {original_count - filtered_count} corrupted samples "
+             + f"from {language}:{split}"
+         )
+     ds = ds.remove_columns(["is_valid"])
+
+     return ds
+
+
+ def _convert_audio_to_mp3_128k(sample: dict) -> dict:
+     import io
+     import tempfile
+
+     import soundfile as sf
+     from pydub import AudioSegment
+
+     error_result = {"audio": None, "is_valid": False}
+
+     with tempfile.TemporaryDirectory() as _temp_dir:
+         temp_dir = Path(_temp_dir)
+         wav_path = temp_dir / "audio.wav"
+
+         try:
+             audio_array = sample["audio"]["array"]
+         except Exception as e:
+             logger.exception(e)
+             logger.error(f"Error getting audio array for sample: {sample}")
+             return error_result
+
+         sampling_rate = sample["audio"]["sampling_rate"]
+
+         sf.write(wav_path, audio_array, sampling_rate)
+         # Audio processing logic (same as original)
+         audio_seg: AudioSegment = AudioSegment.from_file(wav_path)
+         audio_seg = audio_seg.set_channels(1).set_frame_rate(16000)
+         mp3_io = io.BytesIO()
+         audio_seg.export(mp3_io, format="mp3", bitrate="128k")
+         audio_bytes = mp3_io.getvalue()
+
+         # Return the processed audio bytes
+         return {"audio": audio_bytes, "is_valid": True}
+
+
+ LANGUAGE_TYPES: TypeAlias = Literal[
+     "af_za",
+     "am_et",
+     "ar_eg",
+     "as_in",
+     "ast_es",
+     "az_az",
+     "be_by",
+     "bg_bg",
+     "bn_in",
+     "bs_ba",
+     "ca_es",
+     "ceb_ph",
+     "ckb_iq",
+     "cmn_hans_cn",
+     "cs_cz",
+     "cy_gb",
+     "da_dk",
+     "de_de",
+     "el_gr",
+     "en_us",
+     "es_419",
+     "et_ee",
+     "fa_ir",
+     "ff_sn",
+     "fi_fi",
+     "fil_ph",
+     "fr_fr",
+     "ga_ie",
+     "gl_es",
+     "gu_in",
+     "ha_ng",
+     "he_il",
+     "hi_in",
+     "hr_hr",
+     "hu_hu",
+     "hy_am",
+     "id_id",
+     "ig_ng",
+     "is_is",
+     "it_it",
+     "ja_jp",
+     "jv_id",
+     "ka_ge",
+     "kam_ke",
+     "kea_cv",
+     "kk_kz",
+     "km_kh",
+     "kn_in",
+     "ko_kr",
+     "ky_kg",
+     "lb_lu",
+     "lg_ug",
+     "ln_cd",
+     "lo_la",
+     "lt_lt",
+     "luo_ke",
+     "lv_lv",
+     "mi_nz",
+     "mk_mk",
+     "ml_in",
+     "mn_mn",
+     "mr_in",
+     "ms_my",
+     "mt_mt",
+     "my_mm",
+     "nb_no",
+     "ne_np",
+     "nl_nl",
+     "nso_za",
+     "ny_mw",
+     "oc_fr",
+     "om_et",
+     "or_in",
+     "pa_in",
+     "pl_pl",
+     "ps_af",
+     "pt_br",
+     "ro_ro",
+     "ru_ru",
+     "sd_in",
+     "sk_sk",
+     "sl_si",
+     "sn_zw",
+     "so_so",
+     "sr_rs",
+     "sv_se",
+     "sw_ke",
+     "ta_in",
+     "te_in",
+     "tg_tj",
+     "th_th",
+     "tr_tr",
+     "uk_ua",
+     "umb_ao",
+     "ur_pk",
+     "uz_uz",
+     "vi_vn",
+     "wo_sn",
+     "xh_za",
+     "yo_ng",
+     "yue_hant_hk",
+     "zu_za",
+ ]
+ ALL_LANGUAGES = (
+     "af_za",
+     "am_et",
+     "ar_eg",
+     "as_in",
+     "ast_es",
+     "az_az",
+     "be_by",
+     "bg_bg",
+     "bn_in",
+     "bs_ba",
+     "ca_es",
+     "ceb_ph",
+     "ckb_iq",
+     "cmn_hans_cn",
+     "cs_cz",
+     "cy_gb",
+     "da_dk",
+     "de_de",
+     "el_gr",
+     "en_us",
+     "es_419",
+     "et_ee",
+     "fa_ir",
+     "ff_sn",
+     "fi_fi",
+     "fil_ph",
+     "fr_fr",
+     "ga_ie",
+     "gl_es",
+     "gu_in",
+     "ha_ng",
+     "he_il",
+     "hi_in",
+     "hr_hr",
+     "hu_hu",
+     "hy_am",
+     "id_id",
+     "ig_ng",
+     "is_is",
+     "it_it",
+     "ja_jp",
+     "jv_id",
+     "ka_ge",
+     "kam_ke",
+     "kea_cv",
+     "kk_kz",
+     "km_kh",
+     "kn_in",
+     "ko_kr",
+     "ky_kg",
+     "lb_lu",
+     "lg_ug",
+     "ln_cd",
+     "lo_la",
+     "lt_lt",
+     "luo_ke",
+     "lv_lv",
+     "mi_nz",
+     "mk_mk",
+     "ml_in",
+     "mn_mn",
+     "mr_in",
+     "ms_my",
+     "mt_mt",
+     "my_mm",
+     "nb_no",
+     "ne_np",
+     "nl_nl",
+     "nso_za",
+     "ny_mw",
+     "oc_fr",
+     "om_et",
+     "or_in",
+     "pa_in",
+     "pl_pl",
+     "ps_af",
+     "pt_br",
+     "ro_ro",
+     "ru_ru",
+     "sd_in",
+     "sk_sk",
+     "sl_si",
+     "sn_zw",
+     "so_so",
+     "sr_rs",
+     "sv_se",
+     "sw_ke",
+     "ta_in",
+     "te_in",
+     "tg_tj",
+     "th_th",
+     "tr_tr",
+     "uk_ua",
+     "umb_ao",
+     "ur_pk",
+     "uz_uz",
+     "vi_vn",
+     "wo_sn",
+     "xh_za",
+     "yo_ng",
+     "yue_hant_hk",
+     "zu_za",
+ )
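
Taken together, get_dataset validates the language and split, serves a previously built DatasetDict straight from FLEURS_DATASETS_CACHE/<language> when one exists on disk, and otherwise pulls all three splits from the parquet conversion of google/fleurs, re-encodes the audio as 128 kbps mono 16 kHz mp3, drops samples that fail conversion, and caches the result. A minimal usage sketch, assuming datasets, soundfile, pydub, and ffmpeg (needed for the mp3 export) are installed:

    from fleurs_ds import ALL_LANGUAGES, get_dataset

    assert "en_us" in ALL_LANGUAGES

    # The first call builds and caches train/dev/test for en_us; subsequent
    # calls for any split of en_us are served from the on-disk cache.
    ds = get_dataset(language="en_us", split="train")
    print(ds.column_names)  # ['audio', 'text', 'language']
    print(len(ds))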
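
The conversion helper can also be exercised in isolation: it takes one decoded sample dict and returns mp3 bytes plus an is_valid flag rather than raising. A quick standalone check with a synthetic tone; this assumes numpy is available and that the module path is fleurs_ds._get_dataset, as suggested by the relative import in the package's __init__:

    import numpy as np

    from fleurs_ds._get_dataset import _convert_audio_to_mp3_128k

    # One second of a 440 Hz sine at 16 kHz, shaped like a decoded Audio sample.
    t = np.linspace(0, 1, 16000, endpoint=False)
    sample = {
        "audio": {
            "array": (0.1 * np.sin(2 * np.pi * 440 * t)).astype("float32"),
            "sampling_rate": 16000,
        }
    }

    out = _convert_audio_to_mp3_128k(sample)
    assert out["is_valid"] is True
    assert isinstance(out["audio"], bytes)  # 128 kbps mp3 payload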
@@ -0,0 +1,9 @@
+ from datasets import Audio, Features, Value
+
+ features = Features(
+     {
+         "audio": Audio(sampling_rate=16000),
+         "text": Value("string"),
+         "language": Value("string"),
+     }
+ )
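
This shared Features schema is what every split is cast to before caching, so each cached sample decodes to 16 kHz audio plus two strings. A sketch of what that looks like for one sample, assuming a dataset built by get_dataset above and a datasets version that decodes Audio to a dict of array/sampling_rate/path:

    from fleurs_ds import get_dataset

    ds = get_dataset(language="de_de", split="dev")
    sample = ds[0]

    print(sample["audio"]["sampling_rate"])  # 16000
    print(sample["text"])                    # the transcription
    print(sample["language"])                # 'de_de'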
@@ -14,7 +14,7 @@ license = { text = "MIT" }
  name = "fleurs-ds"
  readme = "README.md"
  requires-python = ">=3.11,<4"
- version = "0.0.1"
+ version = "0.1.0"

  [project.urls]
  Homepage = "https://github.com/allen2c/fleurs-ds"
@@ -1 +0,0 @@
- __version__ = "0.0.1"