xinference 0.14.2__py3-none-any.whl → 0.14.4__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.

Potentially problematic release.


This version of xinference might be problematic.

Files changed (191)
  1. xinference/_version.py +3 -3
  2. xinference/core/chat_interface.py +1 -1
  3. xinference/core/image_interface.py +9 -0
  4. xinference/core/model.py +4 -1
  5. xinference/core/worker.py +60 -44
  6. xinference/model/audio/chattts.py +25 -9
  7. xinference/model/audio/core.py +8 -2
  8. xinference/model/audio/cosyvoice.py +4 -3
  9. xinference/model/audio/custom.py +4 -5
  10. xinference/model/audio/fish_speech.py +228 -0
  11. xinference/model/audio/model_spec.json +8 -0
  12. xinference/model/embedding/core.py +25 -1
  13. xinference/model/embedding/custom.py +4 -5
  14. xinference/model/flexible/core.py +5 -1
  15. xinference/model/image/custom.py +4 -5
  16. xinference/model/image/model_spec.json +2 -1
  17. xinference/model/image/model_spec_modelscope.json +2 -1
  18. xinference/model/image/stable_diffusion/core.py +66 -3
  19. xinference/model/llm/__init__.py +6 -0
  20. xinference/model/llm/llm_family.json +54 -9
  21. xinference/model/llm/llm_family.py +7 -6
  22. xinference/model/llm/llm_family_modelscope.json +56 -10
  23. xinference/model/llm/lmdeploy/__init__.py +0 -0
  24. xinference/model/llm/lmdeploy/core.py +557 -0
  25. xinference/model/llm/sglang/core.py +7 -1
  26. xinference/model/llm/transformers/cogvlm2.py +4 -45
  27. xinference/model/llm/transformers/cogvlm2_video.py +524 -0
  28. xinference/model/llm/transformers/core.py +3 -0
  29. xinference/model/llm/transformers/glm4v.py +2 -23
  30. xinference/model/llm/transformers/intern_vl.py +94 -11
  31. xinference/model/llm/transformers/minicpmv25.py +2 -23
  32. xinference/model/llm/transformers/minicpmv26.py +2 -22
  33. xinference/model/llm/transformers/yi_vl.py +2 -24
  34. xinference/model/llm/utils.py +13 -1
  35. xinference/model/llm/vllm/core.py +1 -34
  36. xinference/model/rerank/custom.py +4 -5
  37. xinference/model/utils.py +41 -1
  38. xinference/model/video/core.py +3 -1
  39. xinference/model/video/diffusers.py +41 -38
  40. xinference/model/video/model_spec.json +24 -1
  41. xinference/model/video/model_spec_modelscope.json +25 -1
  42. xinference/thirdparty/fish_speech/__init__.py +0 -0
  43. xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
  44. xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
  45. xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
  46. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  47. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  48. xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
  49. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  50. xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
  51. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  52. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
  53. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
  54. xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
  55. xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
  56. xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
  57. xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
  58. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  59. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
  60. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
  61. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
  62. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
  63. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
  64. xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
  65. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  66. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
  67. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
  68. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
  69. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
  70. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
  71. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
  72. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  73. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
  74. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
  75. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
  76. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
  77. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
  78. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
  79. xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
  80. xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
  81. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
  82. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
  83. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
  84. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
  85. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
  86. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
  87. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
  88. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
  89. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
  90. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
  91. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
  92. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
  93. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
  94. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
  95. xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
  96. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
  97. xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
  98. xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
  99. xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
  100. xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
  101. xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
  102. xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
  103. xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
  104. xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
  105. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  107. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
  108. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
  109. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  110. xinference/thirdparty/fish_speech/tools/api.py +495 -0
  111. xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
  112. xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
  113. xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
  114. xinference/thirdparty/fish_speech/tools/file.py +108 -0
  115. xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
  116. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  117. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
  118. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
  122. xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
  123. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
  124. xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
  126. xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
  127. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
  128. xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
  129. xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
  130. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  131. xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
  132. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
  133. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
  134. xinference/thirdparty/fish_speech/tools/webui.py +619 -0
  135. xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
  136. xinference/thirdparty/matcha/__init__.py +0 -0
  137. xinference/thirdparty/matcha/app.py +357 -0
  138. xinference/thirdparty/matcha/cli.py +419 -0
  139. xinference/thirdparty/matcha/data/__init__.py +0 -0
  140. xinference/thirdparty/matcha/data/components/__init__.py +0 -0
  141. xinference/thirdparty/matcha/data/text_mel_datamodule.py +274 -0
  142. xinference/thirdparty/matcha/hifigan/__init__.py +0 -0
  143. xinference/thirdparty/matcha/hifigan/config.py +28 -0
  144. xinference/thirdparty/matcha/hifigan/denoiser.py +64 -0
  145. xinference/thirdparty/matcha/hifigan/env.py +17 -0
  146. xinference/thirdparty/matcha/hifigan/meldataset.py +217 -0
  147. xinference/thirdparty/matcha/hifigan/models.py +368 -0
  148. xinference/thirdparty/matcha/hifigan/xutils.py +60 -0
  149. xinference/thirdparty/matcha/models/__init__.py +0 -0
  150. xinference/thirdparty/matcha/models/baselightningmodule.py +210 -0
  151. xinference/thirdparty/matcha/models/components/__init__.py +0 -0
  152. xinference/thirdparty/matcha/models/components/decoder.py +443 -0
  153. xinference/thirdparty/matcha/models/components/flow_matching.py +132 -0
  154. xinference/thirdparty/matcha/models/components/text_encoder.py +410 -0
  155. xinference/thirdparty/matcha/models/components/transformer.py +316 -0
  156. xinference/thirdparty/matcha/models/matcha_tts.py +244 -0
  157. xinference/thirdparty/matcha/onnx/__init__.py +0 -0
  158. xinference/thirdparty/matcha/onnx/export.py +181 -0
  159. xinference/thirdparty/matcha/onnx/infer.py +168 -0
  160. xinference/thirdparty/matcha/text/__init__.py +53 -0
  161. xinference/thirdparty/matcha/text/cleaners.py +121 -0
  162. xinference/thirdparty/matcha/text/numbers.py +71 -0
  163. xinference/thirdparty/matcha/text/symbols.py +17 -0
  164. xinference/thirdparty/matcha/train.py +122 -0
  165. xinference/thirdparty/matcha/utils/__init__.py +5 -0
  166. xinference/thirdparty/matcha/utils/audio.py +82 -0
  167. xinference/thirdparty/matcha/utils/generate_data_statistics.py +112 -0
  168. xinference/thirdparty/matcha/utils/get_durations_from_trained_model.py +195 -0
  169. xinference/thirdparty/matcha/utils/instantiators.py +56 -0
  170. xinference/thirdparty/matcha/utils/logging_utils.py +53 -0
  171. xinference/thirdparty/matcha/utils/model.py +90 -0
  172. xinference/thirdparty/matcha/utils/monotonic_align/__init__.py +22 -0
  173. xinference/thirdparty/matcha/utils/monotonic_align/core.pyx +47 -0
  174. xinference/thirdparty/matcha/utils/monotonic_align/setup.py +7 -0
  175. xinference/thirdparty/matcha/utils/pylogger.py +21 -0
  176. xinference/thirdparty/matcha/utils/rich_utils.py +101 -0
  177. xinference/thirdparty/matcha/utils/utils.py +259 -0
  178. xinference/web/ui/build/asset-manifest.json +3 -3
  179. xinference/web/ui/build/index.html +1 -1
  180. xinference/web/ui/build/static/js/{main.ffc26121.js → main.661c7b0a.js} +3 -3
  181. xinference/web/ui/build/static/js/main.661c7b0a.js.map +1 -0
  182. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
  183. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/METADATA +31 -11
  184. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/RECORD +189 -49
  185. xinference/web/ui/build/static/js/main.ffc26121.js.map +0 -1
  186. xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
  187. /xinference/web/ui/build/static/js/{main.ffc26121.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
  188. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/LICENSE +0 -0
  189. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/WHEEL +0 -0
  190. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/entry_points.txt +0 -0
  191. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/top_level.txt +0 -0
xinference/model/video/model_spec.json
@@ -6,6 +6,29 @@
     "model_revision": "4bbfb1de622b80bc1b77b6e9aced75f816be0e38",
     "model_ability": [
       "text2video"
-    ]
+    ],
+    "default_model_config": {
+      "scheduler": "CogVideoXDDIMScheduler",
+      "torch_dtype": "float16"
+    },
+    "default_generate_config": {
+      "guidance_scale": 6
+    }
+  },
+  {
+    "model_name": "CogVideoX-5b",
+    "model_family": "CogVideoX",
+    "model_id": "THUDM/CogVideoX-5b",
+    "model_revision": "8d6ea3f817438460b25595a120f109b88d5fdfad",
+    "model_ability": [
+      "text2video"
+    ],
+    "default_model_config": {
+      "scheduler": "CogVideoXDPMScheduler",
+      "torch_dtype": "bfloat16"
+    },
+    "default_generate_config": {
+      "guidance_scale": 7
+    }
   }
 ]
xinference/model/video/model_spec_modelscope.json
@@ -7,6 +7,30 @@
     "model_revision": "master",
     "model_ability": [
       "text2video"
-    ]
+    ],
+    "default_model_config": {
+      "scheduler": "CogVideoXDDIMScheduler",
+      "torch_dtype": "float16"
+    },
+    "default_generate_config": {
+      "guidance_scale": 6
+    }
+  },
+  {
+    "model_name": "CogVideoX-5b",
+    "model_family": "CogVideoX",
+    "model_hub": "modelscope",
+    "model_id": "ZhipuAI/CogVideoX-5b",
+    "model_revision": "master",
+    "model_ability": [
+      "text2video"
+    ],
+    "default_model_config": {
+      "scheduler": "CogVideoXDPMScheduler",
+      "torch_dtype": "bfloat16"
+    },
+    "default_generate_config": {
+      "guidance_scale": 7
+    }
   }
 ]
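
Both specs register CogVideoX-5b as a built-in text-to-video model. As a rough usage sketch (the endpoint and prompt are illustrative, a running Xinference server is assumed, and the video handle is assumed to expose a text_to_video call), the new entry could be exercised through the Python client like this:

    # Illustrative sketch only: assumes a running Xinference server on the
    # default endpoint and enough GPU memory for CogVideoX-5b.
    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")
    model_uid = client.launch_model(model_name="CogVideoX-5b", model_type="video")
    model = client.get_model(model_uid)

    # "text2video" is the ability declared in the spec; the scheduler and
    # guidance_scale defaults come from default_model_config / default_generate_config above.
    video = model.text_to_video("a panda playing guitar in a bamboo forest")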
xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py
@@ -0,0 +1,3 @@
+from .grad_norm import GradNormMonitor
+
+__all__ = ["GradNormMonitor"]
xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py
@@ -0,0 +1,113 @@
+from typing import Optional, Union
+
+import lightning.pytorch as pl
+import torch
+from lightning import LightningModule, Trainer
+from lightning.pytorch.callbacks import Callback
+from torch import Tensor, nn
+from torch.utils._foreach_utils import (
+    _group_tensors_by_device_and_dtype,
+    _has_foreach_support,
+)
+
+
+@torch.no_grad()
+def grad_norm(
+    parameters: Union[Tensor, list[Tensor]],
+    norm_type: float = 2.0,
+) -> float:
+    """
+    Returns the norm of the gradients of the given parameters.
+
+    Args:
+        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+            single Tensor that will have gradients normalized
+        norm_type (float): type of the used p-norm.
+
+    Returns:
+        Total norm of the parameter gradients (viewed as a single vector).
+    """ # noqa: E501
+
+    if isinstance(parameters, Tensor):
+        parameters = [parameters]
+
+    grads = [p.grad for p in parameters if p.grad is not None]
+    if len(grads) == 0:
+        return None
+
+    first_device = grads[0].device
+    grouped_grads: dict[
+        tuple[torch.device, torch.dtype], list[list[Tensor]]
+    ] = _group_tensors_by_device_and_dtype(
+        [[g.detach() for g in grads]]
+    )  # type: ignore[assignment]
+
+    norms = []
+    for (device, _), ([grads], _) in grouped_grads.items():
+        if _has_foreach_support(grads, device=device):
+            norms.extend(torch._foreach_norm(grads, norm_type))
+        else:
+            norms.extend([torch.norm(g, norm_type) for g in grads])
+
+    return torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)
+
+
+class GradNormMonitor(Callback):
+    """
+    Callback that computes the gradient norm of the model parameters.
+    """
+
+    def __init__(
+        self,
+        norm_type: float = 2.0,
+        logging_interval: str = "step",
+        sub_module: Optional[Union[str, list[str]]] = None,
+    ) -> None:
+        """
+        Args:
+            norm_type (float): type of the used p-norm.
+            logging_interval (str): "step" or "epoch".
+        """
+        super().__init__()
+
+        self.norm_type = norm_type
+        self.logging_interval = logging_interval
+        self.sub_module = sub_module
+
+    def on_after_backward(self, trainer: Trainer, model: LightningModule) -> None:
+        """
+        Computes the gradient norm of the model parameters and logs it to the logger.
+
+        Args:
+            trainer (Trainer): The trainer object
+            model (LightningModule): The current lightningModule
+        """
+
+        lightning_model = model
+
+        if self.sub_module is None:
+            return self.log_sub_module_grad_norm(lightning_model, model, "")
+
+        sub_modules = self.sub_module
+        if isinstance(sub_modules, str):
+            sub_modules = [sub_modules]
+
+        for sub_module in sub_modules:
+            self.log_sub_module_grad_norm(
+                lightning_model, getattr(model, sub_module), f"/{sub_module}"
+            )
+
+    def log_sub_module_grad_norm(
+        self, lightning_model: LightningModule, model: nn.Module, path: str
+    ) -> None:
+        grad_norm_val = grad_norm(model.parameters(), self.norm_type)
+        if grad_norm_val is None:
+            return
+
+        on_step = self.logging_interval == "step"
+        lightning_model.log(
+            f"train{path}/grad_norm",
+            grad_norm_val,
+            on_step=on_step,
+            on_epoch=not on_step,
+        )
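
The vendored callback plugs into a standard Lightning training loop. A minimal sketch (the module and dataloader names are placeholders, not part of xinference):

    # Minimal sketch, assuming a LightningModule `my_module` and a DataLoader
    # `train_loader`; both are placeholders used only for illustration.
    import lightning.pytorch as pl
    from xinference.thirdparty.fish_speech.fish_speech.callbacks import GradNormMonitor

    trainer = pl.Trainer(
        max_steps=100,
        callbacks=[GradNormMonitor(norm_type=2.0, logging_interval="step")],
    )
    trainer.fit(my_module, train_loader)  # logs "train/grad_norm" after every backward pass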
xinference/thirdparty/fish_speech/fish_speech/conversation.py
@@ -0,0 +1,2 @@
+SEMANTIC_TOKEN = "<|semantic|>"
+CODEBOOK_PAD_TOKEN_ID = 0
xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py
@@ -0,0 +1,53 @@
+import bisect
+import random
+from typing import Iterable
+
+from torch.utils.data import Dataset, IterableDataset
+
+
+class ConcatRepeatDataset(Dataset):
+    datasets: list[Dataset]
+    cumulative_sizes: list[int]
+    repeats: list[int]
+
+    @staticmethod
+    def cumsum(sequence, repeats):
+        r, s = [], 0
+        for dataset, repeat in zip(sequence, repeats):
+            l = len(dataset) * repeat
+            r.append(l + s)
+            s += l
+        return r
+
+    def __init__(self, datasets: Iterable[Dataset], repeats: list[int]):
+        super().__init__()
+
+        self.datasets = list(datasets)
+        self.repeats = repeats
+
+        assert len(self.datasets) > 0, "datasets should not be an empty iterable"
+        assert len(self.datasets) == len(
+            repeats
+        ), "datasets and repeats should have the same length"
+
+        for d in self.datasets:
+            assert not isinstance(
+                d, IterableDataset
+            ), "ConcatRepeatDataset does not support IterableDataset"
+
+        self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)
+
+    def __len__(self):
+        return self.cumulative_sizes[-1]
+
+    def __getitem__(self, idx):
+        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+
+        if dataset_idx == 0:
+            sample_idx = idx
+        else:
+            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+
+        dataset = self.datasets[dataset_idx]
+
+        return dataset[sample_idx % len(dataset)]
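
The class concatenates map-style datasets while virtually repeating each one. An illustrative sketch with toy datasets (not from the package):

    # Illustrative sketch: oversample a small dataset 3x relative to a larger one.
    import torch
    from torch.utils.data import TensorDataset
    from xinference.thirdparty.fish_speech.fish_speech.datasets.concat_repeat import (
        ConcatRepeatDataset,
    )

    small = TensorDataset(torch.arange(10))
    large = TensorDataset(torch.arange(100))
    mixed = ConcatRepeatDataset([small, large], repeats=[3, 1])

    assert len(mixed) == 3 * 10 + 100   # 130 virtual samples per epoch
    assert mixed[35][0] == large[5][0]  # index 35 falls past the 30 repeated "small" slots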
xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: text-data.proto
+# Protobuf Python Version: 4.25.1
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
+    b'\n\x0ftext-data.proto\x12\ttext_data"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"B\n\x08Sentence\x12\r\n\x05texts\x18\x01 \x03(\t\x12\'\n\tsemantics\x18\x03 \x03(\x0b\x32\x14.text_data.Semantics"P\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12&\n\tsentences\x18\x04 \x03(\x0b\x32\x13.text_data.Sentence"Q\n\x0bSampledData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12$\n\x07samples\x18\x03 \x03(\x0b\x32\x13.text_data.Sentenceb\x06proto3'
+)
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "text_data_pb2", _globals)
+if _descriptor._USE_C_DESCRIPTORS == False:
+    DESCRIPTOR._options = None
+    _globals["_SEMANTICS"]._serialized_start = 30
+    _globals["_SEMANTICS"]._serialized_end = 57
+    _globals["_SENTENCE"]._serialized_start = 59
+    _globals["_SENTENCE"]._serialized_end = 125
+    _globals["_TEXTDATA"]._serialized_start = 127
+    _globals["_TEXTDATA"]._serialized_end = 207
+    _globals["_SAMPLEDDATA"]._serialized_start = 209
+    _globals["_SAMPLEDDATA"]._serialized_end = 290
+# @@protoc_insertion_point(module_scope)
xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py
@@ -0,0 +1,36 @@
+import struct
+
+from .text_data_pb2 import TextData
+
+
+def read_pb_stream(f):
+    while True:
+        buf = f.read(4)
+        if len(buf) == 0:
+            break
+        size = struct.unpack("I", buf)[0]
+        buf = f.read(size)
+        text_data = TextData()
+        text_data.ParseFromString(buf)
+        yield text_data
+
+
+def write_pb_stream(f, text_data):
+    buf = text_data.SerializeToString()
+    f.write(struct.pack("I", len(buf)))
+    f.write(buf)
+
+
+def pack_pb_stream(text_data):
+    buf = text_data.SerializeToString()
+    return struct.pack("I", len(buf)) + buf
+
+
+def split_pb_stream(f):
+    while True:
+        head = f.read(4)
+        if len(head) == 0:
+            break
+        size = struct.unpack("I", head)[0]
+        buf = f.read(size)
+        yield head + buf
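
These helpers frame each TextData message with a 4-byte (native-order) length prefix. A small round-trip sketch using an in-memory buffer (the "demo"/"sample-*" values are made up for illustration):

    # Round-trip sketch: write two length-prefixed TextData records, then read them back.
    import io
    from xinference.thirdparty.fish_speech.fish_speech.datasets.protos.text_data_pb2 import TextData
    from xinference.thirdparty.fish_speech.fish_speech.datasets.protos.text_data_stream import (
        read_pb_stream,
        write_pb_stream,
    )

    buf = io.BytesIO()
    write_pb_stream(buf, TextData(source="demo", name="sample-0"))
    write_pb_stream(buf, TextData(source="demo", name="sample-1"))

    buf.seek(0)
    for record in read_pb_stream(buf):  # yields the two TextData messages in order
        print(record.source, record.name)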