xinference 0.14.2__py3-none-any.whl → 0.14.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/core/chat_interface.py +1 -1
- xinference/core/image_interface.py +9 -0
- xinference/core/model.py +4 -1
- xinference/core/worker.py +60 -44
- xinference/model/audio/chattts.py +25 -9
- xinference/model/audio/core.py +8 -2
- xinference/model/audio/cosyvoice.py +4 -3
- xinference/model/audio/custom.py +4 -5
- xinference/model/audio/fish_speech.py +228 -0
- xinference/model/audio/model_spec.json +8 -0
- xinference/model/embedding/core.py +25 -1
- xinference/model/embedding/custom.py +4 -5
- xinference/model/flexible/core.py +5 -1
- xinference/model/image/custom.py +4 -5
- xinference/model/image/model_spec.json +2 -1
- xinference/model/image/model_spec_modelscope.json +2 -1
- xinference/model/image/stable_diffusion/core.py +66 -3
- xinference/model/llm/__init__.py +6 -0
- xinference/model/llm/llm_family.json +54 -9
- xinference/model/llm/llm_family.py +7 -6
- xinference/model/llm/llm_family_modelscope.json +56 -10
- xinference/model/llm/lmdeploy/__init__.py +0 -0
- xinference/model/llm/lmdeploy/core.py +557 -0
- xinference/model/llm/sglang/core.py +7 -1
- xinference/model/llm/transformers/cogvlm2.py +4 -45
- xinference/model/llm/transformers/cogvlm2_video.py +524 -0
- xinference/model/llm/transformers/core.py +3 -0
- xinference/model/llm/transformers/glm4v.py +2 -23
- xinference/model/llm/transformers/intern_vl.py +94 -11
- xinference/model/llm/transformers/minicpmv25.py +2 -23
- xinference/model/llm/transformers/minicpmv26.py +2 -22
- xinference/model/llm/transformers/yi_vl.py +2 -24
- xinference/model/llm/utils.py +13 -1
- xinference/model/llm/vllm/core.py +1 -34
- xinference/model/rerank/custom.py +4 -5
- xinference/model/utils.py +41 -1
- xinference/model/video/core.py +3 -1
- xinference/model/video/diffusers.py +41 -38
- xinference/model/video/model_spec.json +24 -1
- xinference/model/video/model_spec_modelscope.json +25 -1
- xinference/thirdparty/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
- xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +495 -0
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
- xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
- xinference/thirdparty/fish_speech/tools/file.py +108 -0
- xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
- xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
- xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
- xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
- xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
- xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
- xinference/thirdparty/fish_speech/tools/webui.py +619 -0
- xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
- xinference/thirdparty/matcha/__init__.py +0 -0
- xinference/thirdparty/matcha/app.py +357 -0
- xinference/thirdparty/matcha/cli.py +419 -0
- xinference/thirdparty/matcha/data/__init__.py +0 -0
- xinference/thirdparty/matcha/data/components/__init__.py +0 -0
- xinference/thirdparty/matcha/data/text_mel_datamodule.py +274 -0
- xinference/thirdparty/matcha/hifigan/__init__.py +0 -0
- xinference/thirdparty/matcha/hifigan/config.py +28 -0
- xinference/thirdparty/matcha/hifigan/denoiser.py +64 -0
- xinference/thirdparty/matcha/hifigan/env.py +17 -0
- xinference/thirdparty/matcha/hifigan/meldataset.py +217 -0
- xinference/thirdparty/matcha/hifigan/models.py +368 -0
- xinference/thirdparty/matcha/hifigan/xutils.py +60 -0
- xinference/thirdparty/matcha/models/__init__.py +0 -0
- xinference/thirdparty/matcha/models/baselightningmodule.py +210 -0
- xinference/thirdparty/matcha/models/components/__init__.py +0 -0
- xinference/thirdparty/matcha/models/components/decoder.py +443 -0
- xinference/thirdparty/matcha/models/components/flow_matching.py +132 -0
- xinference/thirdparty/matcha/models/components/text_encoder.py +410 -0
- xinference/thirdparty/matcha/models/components/transformer.py +316 -0
- xinference/thirdparty/matcha/models/matcha_tts.py +244 -0
- xinference/thirdparty/matcha/onnx/__init__.py +0 -0
- xinference/thirdparty/matcha/onnx/export.py +181 -0
- xinference/thirdparty/matcha/onnx/infer.py +168 -0
- xinference/thirdparty/matcha/text/__init__.py +53 -0
- xinference/thirdparty/matcha/text/cleaners.py +121 -0
- xinference/thirdparty/matcha/text/numbers.py +71 -0
- xinference/thirdparty/matcha/text/symbols.py +17 -0
- xinference/thirdparty/matcha/train.py +122 -0
- xinference/thirdparty/matcha/utils/__init__.py +5 -0
- xinference/thirdparty/matcha/utils/audio.py +82 -0
- xinference/thirdparty/matcha/utils/generate_data_statistics.py +112 -0
- xinference/thirdparty/matcha/utils/get_durations_from_trained_model.py +195 -0
- xinference/thirdparty/matcha/utils/instantiators.py +56 -0
- xinference/thirdparty/matcha/utils/logging_utils.py +53 -0
- xinference/thirdparty/matcha/utils/model.py +90 -0
- xinference/thirdparty/matcha/utils/monotonic_align/__init__.py +22 -0
- xinference/thirdparty/matcha/utils/monotonic_align/core.pyx +47 -0
- xinference/thirdparty/matcha/utils/monotonic_align/setup.py +7 -0
- xinference/thirdparty/matcha/utils/pylogger.py +21 -0
- xinference/thirdparty/matcha/utils/rich_utils.py +101 -0
- xinference/thirdparty/matcha/utils/utils.py +259 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.ffc26121.js → main.661c7b0a.js} +3 -3
- xinference/web/ui/build/static/js/main.661c7b0a.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/METADATA +31 -11
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/RECORD +189 -49
- xinference/web/ui/build/static/js/main.ffc26121.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
- /xinference/web/ui/build/static/js/{main.ffc26121.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/LICENSE +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/WHEEL +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/top_level.txt +0 -0
|
@@ -6,6 +6,29 @@
|
|
|
6
6
|
"model_revision": "4bbfb1de622b80bc1b77b6e9aced75f816be0e38",
|
|
7
7
|
"model_ability": [
|
|
8
8
|
"text2video"
|
|
9
|
-
]
|
|
9
|
+
],
|
|
10
|
+
"default_model_config": {
|
|
11
|
+
"scheduler": "CogVideoXDDIMScheduler",
|
|
12
|
+
"torch_dtype": "float16"
|
|
13
|
+
},
|
|
14
|
+
"default_generate_config": {
|
|
15
|
+
"guidance_scale": 6
|
|
16
|
+
}
|
|
17
|
+
},
|
|
18
|
+
{
|
|
19
|
+
"model_name": "CogVideoX-5b",
|
|
20
|
+
"model_family": "CogVideoX",
|
|
21
|
+
"model_id": "THUDM/CogVideoX-5b",
|
|
22
|
+
"model_revision": "8d6ea3f817438460b25595a120f109b88d5fdfad",
|
|
23
|
+
"model_ability": [
|
|
24
|
+
"text2video"
|
|
25
|
+
],
|
|
26
|
+
"default_model_config": {
|
|
27
|
+
"scheduler": "CogVideoXDPMScheduler",
|
|
28
|
+
"torch_dtype": "bfloat16"
|
|
29
|
+
},
|
|
30
|
+
"default_generate_config": {
|
|
31
|
+
"guidance_scale": 7
|
|
32
|
+
}
|
|
10
33
|
}
|
|
11
34
|
]
|
|
@@ -7,6 +7,30 @@
|
|
|
7
7
|
"model_revision": "master",
|
|
8
8
|
"model_ability": [
|
|
9
9
|
"text2video"
|
|
10
|
-
]
|
|
10
|
+
],
|
|
11
|
+
"default_model_config": {
|
|
12
|
+
"scheduler": "CogVideoXDDIMScheduler",
|
|
13
|
+
"torch_dtype": "float16"
|
|
14
|
+
},
|
|
15
|
+
"default_generate_config": {
|
|
16
|
+
"guidance_scale": 6
|
|
17
|
+
}
|
|
18
|
+
},
|
|
19
|
+
{
|
|
20
|
+
"model_name": "CogVideoX-5b",
|
|
21
|
+
"model_family": "CogVideoX",
|
|
22
|
+
"model_hub": "modelscope",
|
|
23
|
+
"model_id": "ZhipuAI/CogVideoX-5b",
|
|
24
|
+
"model_revision": "master",
|
|
25
|
+
"model_ability": [
|
|
26
|
+
"text2video"
|
|
27
|
+
],
|
|
28
|
+
"default_model_config": {
|
|
29
|
+
"scheduler": "CogVideoXDPMScheduler",
|
|
30
|
+
"torch_dtype": "bfloat16"
|
|
31
|
+
},
|
|
32
|
+
"default_generate_config": {
|
|
33
|
+
"guidance_scale": 7
|
|
34
|
+
}
|
|
11
35
|
}
|
|
12
36
|
]
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
from typing import Optional, Union
|
|
2
|
+
|
|
3
|
+
import lightning.pytorch as pl
|
|
4
|
+
import torch
|
|
5
|
+
from lightning import LightningModule, Trainer
|
|
6
|
+
from lightning.pytorch.callbacks import Callback
|
|
7
|
+
from torch import Tensor, nn
|
|
8
|
+
from torch.utils._foreach_utils import (
|
|
9
|
+
_group_tensors_by_device_and_dtype,
|
|
10
|
+
_has_foreach_support,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@torch.no_grad()
|
|
15
|
+
def grad_norm(
|
|
16
|
+
parameters: Union[Tensor, list[Tensor]],
|
|
17
|
+
norm_type: float = 2.0,
|
|
18
|
+
) -> float:
|
|
19
|
+
"""
|
|
20
|
+
Returns the norm of the gradients of the given parameters.
|
|
21
|
+
|
|
22
|
+
Args:
|
|
23
|
+
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
|
|
24
|
+
single Tensor that will have gradients normalized
|
|
25
|
+
norm_type (float): type of the used p-norm.
|
|
26
|
+
|
|
27
|
+
Returns:
|
|
28
|
+
Total norm of the parameter gradients (viewed as a single vector).
|
|
29
|
+
""" # noqa: E501
|
|
30
|
+
|
|
31
|
+
if isinstance(parameters, Tensor):
|
|
32
|
+
parameters = [parameters]
|
|
33
|
+
|
|
34
|
+
grads = [p.grad for p in parameters if p.grad is not None]
|
|
35
|
+
if len(grads) == 0:
|
|
36
|
+
return None
|
|
37
|
+
|
|
38
|
+
first_device = grads[0].device
|
|
39
|
+
grouped_grads: dict[
|
|
40
|
+
tuple[torch.device, torch.dtype], list[list[Tensor]]
|
|
41
|
+
] = _group_tensors_by_device_and_dtype(
|
|
42
|
+
[[g.detach() for g in grads]]
|
|
43
|
+
) # type: ignore[assignment]
|
|
44
|
+
|
|
45
|
+
norms = []
|
|
46
|
+
for (device, _), ([grads], _) in grouped_grads.items():
|
|
47
|
+
if _has_foreach_support(grads, device=device):
|
|
48
|
+
norms.extend(torch._foreach_norm(grads, norm_type))
|
|
49
|
+
else:
|
|
50
|
+
norms.extend([torch.norm(g, norm_type) for g in grads])
|
|
51
|
+
|
|
52
|
+
return torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class GradNormMonitor(Callback):
    """
    Lightning callback that measures and logs the gradient norm of a model
    (or of selected sub-modules) after each backward pass.
    """

    def __init__(
        self,
        norm_type: float = 2.0,
        logging_interval: str = "step",
        sub_module: Optional[Union[str, list[str]]] = None,
    ) -> None:
        """
        Args:
            norm_type (float): type of the used p-norm.
            logging_interval (str): "step" or "epoch".
        """
        super().__init__()

        self.norm_type = norm_type
        self.logging_interval = logging_interval
        self.sub_module = sub_module

    def on_after_backward(self, trainer: Trainer, model: LightningModule) -> None:
        """
        Compute the gradient norm of the model parameters and log it.

        Args:
            trainer (Trainer): The trainer object
            model (LightningModule): The current lightningModule
        """
        # No sub-module filter: log the norm over the whole model.
        if self.sub_module is None:
            return self.log_sub_module_grad_norm(model, model, "")

        # Accept either a single name or a list of names.
        targets = (
            [self.sub_module] if isinstance(self.sub_module, str) else self.sub_module
        )
        for sub_module in targets:
            self.log_sub_module_grad_norm(
                model, getattr(model, sub_module), f"/{sub_module}"
            )

    def log_sub_module_grad_norm(
        self, lightning_model: LightningModule, model: nn.Module, path: str
    ) -> None:
        # Skip logging entirely when no parameter holds a gradient.
        value = grad_norm(model.parameters(), self.norm_type)
        if value is None:
            return

        step_wise = self.logging_interval == "step"
        lightning_model.log(
            f"train{path}/grad_norm",
            value,
            on_step=step_wise,
            on_epoch=not step_wise,
        )
File without changes
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
import bisect
|
|
2
|
+
import random
|
|
3
|
+
from typing import Iterable
|
|
4
|
+
|
|
5
|
+
from torch.utils.data import Dataset, IterableDataset
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class ConcatRepeatDataset(Dataset):
    """Concatenate several map-style datasets, repeating each one a fixed
    number of times.

    The index space is laid out as dataset 0 repeated ``repeats[0]`` times,
    followed by dataset 1 repeated ``repeats[1]`` times, and so on.
    """

    datasets: list[Dataset]
    cumulative_sizes: list[int]
    repeats: list[int]

    @staticmethod
    def cumsum(sequence, repeats):
        # Running totals of len(dataset) * repeat, one entry per dataset.
        totals, running = [], 0
        for dataset, repeat in zip(sequence, repeats):
            running += len(dataset) * repeat
            totals.append(running)
        return totals

    def __init__(self, datasets: Iterable[Dataset], repeats: list[int]):
        super().__init__()

        self.datasets = list(datasets)
        self.repeats = repeats

        assert len(self.datasets) > 0, "datasets should not be an empty iterable"
        assert len(self.datasets) == len(
            repeats
        ), "datasets and repeats should have the same length"

        for d in self.datasets:
            assert not isinstance(
                d, IterableDataset
            ), "ConcatRepeatDataset does not support IterableDataset"

        self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)

    def __len__(self):
        # Total virtual length, including repeats.
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        # Locate which dataset the virtual index falls into, then fold the
        # repeated span back onto the underlying dataset's real length.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        offset = self.cumulative_sizes[dataset_idx - 1] if dataset_idx else 0
        chosen = self.datasets[dataset_idx]
        return chosen[(idx - offset) % len(chosen)]
File without changes
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: text-data.proto
# Protobuf Python Version: 4.25.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder

# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


# NOTE(review): machine-generated module — regenerate from text-data.proto
# rather than editing by hand. The serialized file descriptor below declares
# four messages in package ``text_data``: Semantics, Sentence, TextData and
# SampledData (the names are visible inside the byte string).
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
    b'\n\x0ftext-data.proto\x12\ttext_data"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"B\n\x08Sentence\x12\r\n\x05texts\x18\x01 \x03(\t\x12\'\n\tsemantics\x18\x03 \x03(\x0b\x32\x14.text_data.Semantics"P\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12&\n\tsentences\x18\x04 \x03(\x0b\x32\x13.text_data.Sentence"Q\n\x0bSampledData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12$\n\x07samples\x18\x03 \x03(\x0b\x32\x13.text_data.Sentenceb\x06proto3'
)

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "text_data_pb2", _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
    DESCRIPTOR._options = None
    # Byte offsets of each message within the serialized file descriptor,
    # used by the pure-Python descriptor implementation.
    _globals["_SEMANTICS"]._serialized_start = 30
    _globals["_SEMANTICS"]._serialized_end = 57
    _globals["_SENTENCE"]._serialized_start = 59
    _globals["_SENTENCE"]._serialized_end = 125
    _globals["_TEXTDATA"]._serialized_start = 127
    _globals["_TEXTDATA"]._serialized_end = 207
    _globals["_SAMPLEDDATA"]._serialized_start = 209
    _globals["_SAMPLEDDATA"]._serialized_end = 290
# @@protoc_insertion_point(module_scope)
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
import struct
|
|
2
|
+
|
|
3
|
+
from .text_data_pb2 import TextData
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def read_pb_stream(f):
    """Yield each ``TextData`` message from a length-prefixed binary stream.

    Each record is stored as a 4-byte unsigned-int length header (native byte
    order, ``struct`` format ``"I"``) followed by the serialized message.
    Iteration stops cleanly at end of file.
    """
    while True:
        header = f.read(4)
        if not header:
            break
        payload = f.read(struct.unpack("I", header)[0])
        record = TextData()
        record.ParseFromString(payload)
        yield record
|
|
18
|
+
def write_pb_stream(f, text_data):
    """Write *text_data* to *f* as a 4-byte length header followed by the body."""
    payload = text_data.SerializeToString()
    header = struct.pack("I", len(payload))
    f.write(header)
    f.write(payload)
|
|
24
|
+
def pack_pb_stream(text_data):
    """Return *text_data* serialized, with its 4-byte length header prepended."""
    payload = text_data.SerializeToString()
    header = struct.pack("I", len(payload))
    return header + payload
|
|
29
|
+
def split_pb_stream(f):
    """Yield raw length-prefixed frames (header + body) from stream *f*.

    Frames are passed through without parsing, so the output of one record
    can be forwarded or concatenated as-is.
    """
    # An empty read means end of stream; the walrus keeps the loop compact.
    while head := f.read(4):
        size = struct.unpack("I", head)[0]
        yield head + f.read(size)