xinference 0.14.1.post1__py3-none-any.whl → 0.14.3__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +15 -34
- xinference/client/restful/restful_client.py +2 -2
- xinference/core/chat_interface.py +45 -10
- xinference/core/image_interface.py +9 -0
- xinference/core/model.py +8 -5
- xinference/core/scheduler.py +1 -2
- xinference/core/worker.py +49 -42
- xinference/deploy/cmdline.py +2 -2
- xinference/deploy/test/test_cmdline.py +7 -7
- xinference/model/audio/chattts.py +24 -9
- xinference/model/audio/core.py +8 -2
- xinference/model/audio/fish_speech.py +228 -0
- xinference/model/audio/model_spec.json +8 -0
- xinference/model/embedding/core.py +23 -1
- xinference/model/image/model_spec.json +2 -1
- xinference/model/image/model_spec_modelscope.json +2 -1
- xinference/model/image/stable_diffusion/core.py +49 -1
- xinference/model/llm/__init__.py +26 -27
- xinference/model/llm/{ggml/llamacpp.py → llama_cpp/core.py} +2 -35
- xinference/model/llm/llm_family.json +606 -1266
- xinference/model/llm/llm_family.py +16 -139
- xinference/model/llm/llm_family_modelscope.json +276 -313
- xinference/model/llm/lmdeploy/__init__.py +0 -0
- xinference/model/llm/lmdeploy/core.py +557 -0
- xinference/model/llm/memory.py +9 -9
- xinference/model/llm/sglang/core.py +2 -2
- xinference/model/llm/{pytorch → transformers}/chatglm.py +6 -13
- xinference/model/llm/{pytorch → transformers}/cogvlm2.py +4 -45
- xinference/model/llm/transformers/cogvlm2_video.py +524 -0
- xinference/model/llm/{pytorch → transformers}/core.py +3 -10
- xinference/model/llm/{pytorch → transformers}/glm4v.py +2 -23
- xinference/model/llm/transformers/intern_vl.py +540 -0
- xinference/model/llm/{pytorch → transformers}/internlm2.py +4 -8
- xinference/model/llm/{pytorch → transformers}/minicpmv25.py +2 -23
- xinference/model/llm/{pytorch → transformers}/minicpmv26.py +66 -41
- xinference/model/llm/{pytorch → transformers}/utils.py +1 -2
- xinference/model/llm/{pytorch → transformers}/yi_vl.py +2 -24
- xinference/model/llm/utils.py +85 -70
- xinference/model/llm/vllm/core.py +110 -11
- xinference/model/utils.py +1 -95
- xinference/thirdparty/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
- xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +495 -0
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
- xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
- xinference/thirdparty/fish_speech/tools/file.py +108 -0
- xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
- xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
- xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
- xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
- xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
- xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
- xinference/thirdparty/fish_speech/tools/webui.py +619 -0
- xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
- xinference/thirdparty/internvl/__init__.py +0 -0
- xinference/thirdparty/internvl/conversation.py +393 -0
- xinference/thirdparty/omnilmm/model/utils.py +16 -1
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/main.661c7b0a.js +3 -0
- xinference/web/ui/build/static/js/{main.17ca0398.js.map → main.661c7b0a.js.map} +1 -1
- xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/5391543180fead1eeef5364300301498d58a7d91d62de3841a32768b67f4552f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/714c37ce0ec5b5c591033f02be2f3f491fdd70da3ef568ee4a4f94689a3d5ca2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a797831de0dc74897f4b50b3426555d748f328b4c2cc391de709eadaf6a5f3e3.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e91938976f229ce986b2907e51e1f00540b584ced0a315d498c172d13220739d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +1 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/METADATA +22 -13
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/RECORD +170 -79
- xinference/locale/utils.py +0 -39
- xinference/locale/zh_CN.json +0 -26
- xinference/model/llm/ggml/tools/__init__.py +0 -15
- xinference/model/llm/ggml/tools/convert_ggml_to_gguf.py +0 -498
- xinference/model/llm/ggml/tools/gguf.py +0 -884
- xinference/model/llm/pytorch/__init__.py +0 -13
- xinference/model/llm/pytorch/baichuan.py +0 -81
- xinference/model/llm/pytorch/falcon.py +0 -138
- xinference/model/llm/pytorch/intern_vl.py +0 -352
- xinference/model/llm/pytorch/vicuna.py +0 -69
- xinference/web/ui/build/static/js/main.17ca0398.js +0 -3
- xinference/web/ui/node_modules/.cache/babel-loader/1444c41a4d04494f1cbc2d8c1537df107b451cb569cb2c1fbf5159f3a4841a5f.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/44774c783428f952d8e2e4ad0998a9c5bc16a57cd9c68b7c5ff18aaa5a41d65c.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5262556baf9207738bf6a8ba141ec6599d0a636345c245d61fdf88d3171998cb.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/6450605fac003812485f6251b9f0caafbf2e5bfc3bbe2f000050d9e2fdb8dcd3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/71684495d995c7e266eecc6a0ad8ea0284cc785f80abddf863789c57a6134969.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/80acd1edf31542ab1dcccfad02cb4b38f3325cff847a781fcce97500cfd6f878.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8a9742ddd8ba8546ef42dc14caca443f2b4524fabed7bf269e0eff3b7b64ee7d.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/d06a96a3c9c32e42689094aa3aaad41c8125894e956b8f84a70fadce6e3f65b3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/d93730e2b5d7e8c957b4d0965d2ed1dac9045a649adbd47c220d11f255d4b1e0.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e656dc00b4d8b387f0a81ba8fc558767df1601c66369e2eb86a5ef27cf080572.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f28b83886159d83b84f099b05d607a822dca4dd7f2d8aa6d56fe08bab0b5b086.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f3e02274cb1964e99b1fe69cbb6db233d3d8d7dd05d50ebcdb8e66d50b224b7b.json +0 -1
- /xinference/{locale → model/llm/llama_cpp}/__init__.py +0 -0
- /xinference/model/llm/{ggml → transformers}/__init__.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/compression.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/deepseek_vl.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/llama_2.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/omnilmm.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/qwen_vl.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/tensorizer_utils.py +0 -0
- /xinference/web/ui/build/static/js/{main.17ca0398.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/LICENSE +0 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/WHEEL +0 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/top_level.txt +0 -0
xinference/model/utils.py
CHANGED
@@ -14,13 +14,11 @@
 import json
 import logging
 import os
-import shutil
 from json import JSONDecodeError
 from pathlib import Path
 from typing import Any, Callable, Dict, Optional, Tuple, Union
 
 import huggingface_hub
-from fsspec import AbstractFileSystem
 
 from ..constants import XINFERENCE_CACHE_DIR, XINFERENCE_ENV_MODEL_SRC
 from ..device_utils import get_available_device, is_device_available
@@ -220,12 +218,7 @@ def is_valid_model_uri(model_uri: Optional[str]) -> bool:
     return True
 
 
-def cache_from_uri(
-    model_spec: CacheableModelSpec,
-    self_hosted_storage: bool = False,
-) -> str:
-    from fsspec import AbstractFileSystem, filesystem
-
+def cache_from_uri(model_spec: CacheableModelSpec) -> str:
     cache_dir = os.path.realpath(
         os.path.join(XINFERENCE_CACHE_DIR, model_spec.model_name)
     )
@@ -247,48 +240,6 @@ def cache_from_uri(
         os.makedirs(XINFERENCE_CACHE_DIR, exist_ok=True)
         os.symlink(src_root, cache_dir, target_is_directory=True)
         return cache_dir
-    elif src_scheme in ["s3"]:
-        # use anonymous connection for self-hosted storage.
-        src_fs: AbstractFileSystem = filesystem(src_scheme, anon=self_hosted_storage)
-        local_fs: AbstractFileSystem = filesystem("file")
-
-        files_to_download = []
-        os.makedirs(cache_dir, exist_ok=True)
-
-        for path, _, files in src_fs.walk(model_spec.model_uri):
-            for file in files:
-                src_path = f"{path}/{file}"
-                local_path = src_path.replace(src_root, cache_dir)
-                files_to_download.append((src_path, local_path))
-
-        from concurrent.futures import ThreadPoolExecutor
-
-        failed = False
-        with ThreadPoolExecutor(max_workers=min(len(files_to_download), 4)) as executor:
-            futures = [
-                (
-                    src_path,
-                    executor.submit(
-                        copy_from_src_to_dst, src_fs, src_path, local_fs, local_path
-                    ),
-                )
-                for src_path, local_path in files_to_download
-            ]
-            for src_path, future in futures:
-                if failed:
-                    future.cancel()
-                else:
-                    try:
-                        future.result()
-                    except:
-                        logger.error(f"Download {src_path} failed", exc_info=True)
-                        failed = True
-
-        if failed:
-            logger.warning(f"Removing cache directory: {cache_dir}")
-            shutil.rmtree(cache_dir, ignore_errors=True)
-            raise RuntimeError(f"Failed to download model '{model_spec.model_name}' ")
-        return cache_dir
     else:
         raise ValueError(f"Unsupported URL scheme: {src_scheme}")
 
@@ -346,51 +297,6 @@ def cache(model_spec: CacheableModelSpec, model_description_type: type):
     return cache_dir
 
 
-def copy_from_src_to_dst(
-    _src_fs: "AbstractFileSystem",
-    _src_path: str,
-    dst_fs: "AbstractFileSystem",
-    dst_path: str,
-    max_attempt: int = 3,
-):
-    from tqdm import tqdm
-
-    for attempt in range(max_attempt):
-        logger.info(f"Copy from {_src_path} to {dst_path}, attempt: {attempt}")
-        try:
-            with _src_fs.open(_src_path, "rb") as src_file:
-                file_size = _src_fs.info(_src_path)["size"]
-
-                dst_fs.makedirs(os.path.dirname(dst_path), exist_ok=True)
-                with dst_fs.open(dst_path, "wb") as dst_file:
-                    chunk_size = 1024 * 1024  # 1 MB
-
-                    with tqdm(
-                        total=file_size,
-                        unit="B",
-                        unit_scale=True,
-                        unit_divisor=1024,
-                        desc=_src_path,
-                    ) as pbar:
-                        while True:
-                            chunk = src_file.read(chunk_size)
-                            if not chunk:
-                                break
-                            dst_file.write(chunk)
-                            pbar.update(len(chunk))
-            logger.info(
-                f"Copy from {_src_path} to {dst_path} finished, attempt: {attempt}"
-            )
-            break
-        except:
-            logger.error(
-                f"Failed to copy from {_src_path} to {dst_path} on attempt {attempt + 1}",
-                exc_info=True,
-            )
-            if attempt + 1 == max_attempt:
-                raise
-
-
 def patch_trust_remote_code():
     """sentence-transformers calls transformers without the trust_remote_code=True, some embedding
     models will fail to load, e.g. jina-embeddings-v2-base-en
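Net effect: cache_from_uri now takes only a CacheableModelSpec and handles local URIs by symlinking them into XINFERENCE_CACHE_DIR; the s3:// branch, its self_hosted_storage flag, and the copy_from_src_to_dst helper are removed, along with the now-unused shutil and fsspec imports. For reference, the removed helper's chunked-copy-with-retries pattern reduces to roughly this standalone sketch (hypothetical names; plain binary file objects in place of fsspec filesystems):

import logging

logger = logging.getLogger(__name__)


def copy_with_retries(src, dst, max_attempt=3, chunk_size=1024 * 1024):
    """Copy binary stream src into dst in 1 MB chunks, retrying on failure."""
    for attempt in range(max_attempt):
        try:
            # Restart from the beginning on each attempt.
            src.seek(0)
            dst.seek(0)
            dst.truncate()
            while True:
                chunk = src.read(chunk_size)
                if not chunk:
                    return
                dst.write(chunk)
        except OSError:
            logger.error("Copy failed on attempt %d", attempt + 1, exc_info=True)
            if attempt + 1 == max_attempt:
                raise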
xinference/thirdparty/fish_speech/__init__.py
File without changes
xinference/thirdparty/fish_speech/fish_speech/__init__.py
File without changes

xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py
ADDED
@@ -0,0 +1,113 @@
+from typing import Optional, Union
+
+import lightning.pytorch as pl
+import torch
+from lightning import LightningModule, Trainer
+from lightning.pytorch.callbacks import Callback
+from torch import Tensor, nn
+from torch.utils._foreach_utils import (
+    _group_tensors_by_device_and_dtype,
+    _has_foreach_support,
+)
+
+
+@torch.no_grad()
+def grad_norm(
+    parameters: Union[Tensor, list[Tensor]],
+    norm_type: float = 2.0,
+) -> float:
+    """
+    Returns the norm of the gradients of the given parameters.
+
+    Args:
+        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+            single Tensor that will have gradients normalized
+        norm_type (float): type of the used p-norm.
+
+    Returns:
+        Total norm of the parameter gradients (viewed as a single vector).
+    """  # noqa: E501
+
+    if isinstance(parameters, Tensor):
+        parameters = [parameters]
+
+    grads = [p.grad for p in parameters if p.grad is not None]
+    if len(grads) == 0:
+        return None
+
+    first_device = grads[0].device
+    grouped_grads: dict[
+        tuple[torch.device, torch.dtype], list[list[Tensor]]
+    ] = _group_tensors_by_device_and_dtype(
+        [[g.detach() for g in grads]]
+    )  # type: ignore[assignment]
+
+    norms = []
+    for (device, _), ([grads], _) in grouped_grads.items():
+        if _has_foreach_support(grads, device=device):
+            norms.extend(torch._foreach_norm(grads, norm_type))
+        else:
+            norms.extend([torch.norm(g, norm_type) for g in grads])
+
+    return torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)
+
+
+class GradNormMonitor(Callback):
+    """
+    Callback that computes the gradient norm of the model parameters.
+    """
+
+    def __init__(
+        self,
+        norm_type: float = 2.0,
+        logging_interval: str = "step",
+        sub_module: Optional[Union[str, list[str]]] = None,
+    ) -> None:
+        """
+        Args:
+            norm_type (float): type of the used p-norm.
+            logging_interval (str): "step" or "epoch".
+        """
+        super().__init__()
+
+        self.norm_type = norm_type
+        self.logging_interval = logging_interval
+        self.sub_module = sub_module
+
+    def on_after_backward(self, trainer: Trainer, model: LightningModule) -> None:
+        """
+        Computes the gradient norm of the model parameters and logs it to the logger.
+
+        Args:
+            trainer (Trainer): The trainer object
+            model (LightningModule): The current lightningModule
+        """
+
+        lightning_model = model
+
+        if self.sub_module is None:
+            return self.log_sub_module_grad_norm(lightning_model, model, "")
+
+        sub_modules = self.sub_module
+        if isinstance(sub_modules, str):
+            sub_modules = [sub_modules]
+
+        for sub_module in sub_modules:
+            self.log_sub_module_grad_norm(
+                lightning_model, getattr(model, sub_module), f"/{sub_module}"
+            )
+
+    def log_sub_module_grad_norm(
+        self, lightning_model: LightningModule, model: nn.Module, path: str
+    ) -> None:
+        grad_norm_val = grad_norm(model.parameters(), self.norm_type)
+        if grad_norm_val is None:
+            return
+
+        on_step = self.logging_interval == "step"
+        lightning_model.log(
+            f"train{path}/grad_norm",
+            grad_norm_val,
+            on_step=on_step,
+            on_epoch=not on_step,
+        )
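A brief usage sketch for the new callback (assuming the vendored module is importable under this path; `model` and `datamodule` stand in for your own Lightning objects):

import lightning.pytorch as pl

from xinference.thirdparty.fish_speech.fish_speech.callbacks.grad_norm import (
    GradNormMonitor,
)

# Logs the 2-norm of all gradients under "train/grad_norm" after every backward
# pass; sub_module="decoder" would instead log model.decoder's norm under
# "train/decoder/grad_norm".
trainer = pl.Trainer(
    max_epochs=1,
    callbacks=[GradNormMonitor(norm_type=2.0, logging_interval="step")],
)
# trainer.fit(model, datamodule=datamodule)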
xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py
File without changes
xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py
File without changes
xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py
File without changes

xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py
ADDED
@@ -0,0 +1,53 @@
+import bisect
+import random
+from typing import Iterable
+
+from torch.utils.data import Dataset, IterableDataset
+
+
+class ConcatRepeatDataset(Dataset):
+    datasets: list[Dataset]
+    cumulative_sizes: list[int]
+    repeats: list[int]
+
+    @staticmethod
+    def cumsum(sequence, repeats):
+        r, s = [], 0
+        for dataset, repeat in zip(sequence, repeats):
+            l = len(dataset) * repeat
+            r.append(l + s)
+            s += l
+        return r
+
+    def __init__(self, datasets: Iterable[Dataset], repeats: list[int]):
+        super().__init__()
+
+        self.datasets = list(datasets)
+        self.repeats = repeats
+
+        assert len(self.datasets) > 0, "datasets should not be an empty iterable"
+        assert len(self.datasets) == len(
+            repeats
+        ), "datasets and repeats should have the same length"
+
+        for d in self.datasets:
+            assert not isinstance(
+                d, IterableDataset
+            ), "ConcatRepeatDataset does not support IterableDataset"
+
+        self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)
+
+    def __len__(self):
+        return self.cumulative_sizes[-1]
+
+    def __getitem__(self, idx):
+        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+
+        if dataset_idx == 0:
+            sample_idx = idx
+        else:
+            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+
+        dataset = self.datasets[dataset_idx]
+
+        return dataset[sample_idx % len(dataset)]
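ConcatRepeatDataset concatenates map-style datasets while virtually repeating each one: cumsum records the repeated lengths, bisect_right picks the owning dataset, and the modulo wraps the index back into it. How the index arithmetic plays out (plain lists stand in for map-style datasets; the import path assumes the vendored layout above):

from xinference.thirdparty.fish_speech.fish_speech.datasets.concat_repeat import (
    ConcatRepeatDataset,
)

# A 4-item dataset repeated twice, then a 2-item dataset repeated three times.
a = ["a0", "a1", "a2", "a3"]
b = ["b0", "b1"]
ds = ConcatRepeatDataset([a, b], repeats=[2, 3])

assert len(ds) == 4 * 2 + 2 * 3   # cumulative_sizes == [8, 14]
assert ds[5] == "a1"              # indices 0-7 cycle through a via idx % 4
assert ds[9] == "b1"              # indices 8-13 map to b via (idx - 8) % 2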
xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py
File without changes

xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py
ADDED
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: text-data.proto
+# Protobuf Python Version: 4.25.1
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
+    b'\n\x0ftext-data.proto\x12\ttext_data"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"B\n\x08Sentence\x12\r\n\x05texts\x18\x01 \x03(\t\x12\'\n\tsemantics\x18\x03 \x03(\x0b\x32\x14.text_data.Semantics"P\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12&\n\tsentences\x18\x04 \x03(\x0b\x32\x13.text_data.Sentence"Q\n\x0bSampledData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12$\n\x07samples\x18\x03 \x03(\x0b\x32\x13.text_data.Sentenceb\x06proto3'
+)
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "text_data_pb2", _globals)
+if _descriptor._USE_C_DESCRIPTORS == False:
+    DESCRIPTOR._options = None
+    _globals["_SEMANTICS"]._serialized_start = 30
+    _globals["_SEMANTICS"]._serialized_end = 57
+    _globals["_SENTENCE"]._serialized_start = 59
+    _globals["_SENTENCE"]._serialized_end = 125
+    _globals["_TEXTDATA"]._serialized_start = 127
+    _globals["_TEXTDATA"]._serialized_end = 207
+    _globals["_SAMPLEDDATA"]._serialized_start = 209
+    _globals["_SAMPLEDDATA"]._serialized_end = 290
+# @@protoc_insertion_point(module_scope)
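The serialized descriptor above defines four proto3 messages in package text_data: Semantics (repeated uint32 values), Sentence (repeated string texts, repeated Semantics semantics), and TextData/SampledData (each with a source, a name, and repeated Sentence entries). A construction sketch (import path assumes the vendored layout above):

from xinference.thirdparty.fish_speech.fish_speech.datasets.protos.text_data_pb2 import (
    Semantics,
    Sentence,
    TextData,
)

# Build one TextData record; field numbers follow the descriptor
# (source=1, name=2, sentences=4).
record = TextData(
    source="corpus-a",
    name="sample-0",
    sentences=[
        Sentence(texts=["hello world"], semantics=[Semantics(values=[1, 2, 3])])
    ],
)
payload = record.SerializeToString()  # bytes, ready for the stream helpers below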
xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py
ADDED
@@ -0,0 +1,36 @@
+import struct
+
+from .text_data_pb2 import TextData
+
+
+def read_pb_stream(f):
+    while True:
+        buf = f.read(4)
+        if len(buf) == 0:
+            break
+        size = struct.unpack("I", buf)[0]
+        buf = f.read(size)
+        text_data = TextData()
+        text_data.ParseFromString(buf)
+        yield text_data
+
+
+def write_pb_stream(f, text_data):
+    buf = text_data.SerializeToString()
+    f.write(struct.pack("I", len(buf)))
+    f.write(buf)
+
+
+def pack_pb_stream(text_data):
+    buf = text_data.SerializeToString()
+    return struct.pack("I", len(buf)) + buf
+
+
+def split_pb_stream(f):
+    while True:
+        head = f.read(4)
+        if len(head) == 0:
+            break
+        size = struct.unpack("I", head)[0]
+        buf = f.read(size)
+        yield head + buf
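These helpers implement a simple length-prefixed framing: each record is a 4-byte native-order size (struct format "I") followed by the serialized TextData message. A round-trip sketch, continuing the assumptions above:

import io

from xinference.thirdparty.fish_speech.fish_speech.datasets.protos.text_data_pb2 import (
    TextData,
)
from xinference.thirdparty.fish_speech.fish_speech.datasets.protos.text_data_stream import (
    read_pb_stream,
    write_pb_stream,
)

# Write three framed records into an in-memory stream, then read them back.
buf = io.BytesIO()
for i in range(3):
    write_pb_stream(buf, TextData(source="demo", name=f"sample-{i}"))

buf.seek(0)
names = [record.name for record in read_pb_stream(buf)]
assert names == ["sample-0", "sample-1", "sample-2"]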