xinference 0.13.0__py3-none-any.whl → 0.13.2__py3-none-any.whl

This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (70)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +123 -3
  3. xinference/client/restful/restful_client.py +131 -2
  4. xinference/core/model.py +93 -24
  5. xinference/core/supervisor.py +132 -15
  6. xinference/core/worker.py +165 -8
  7. xinference/deploy/cmdline.py +5 -0
  8. xinference/model/audio/chattts.py +46 -14
  9. xinference/model/audio/core.py +23 -15
  10. xinference/model/core.py +12 -3
  11. xinference/model/embedding/core.py +25 -16
  12. xinference/model/flexible/__init__.py +40 -0
  13. xinference/model/flexible/core.py +228 -0
  14. xinference/model/flexible/launchers/__init__.py +15 -0
  15. xinference/model/flexible/launchers/transformers_launcher.py +63 -0
  16. xinference/model/flexible/utils.py +33 -0
  17. xinference/model/image/core.py +21 -14
  18. xinference/model/image/custom.py +1 -1
  19. xinference/model/image/model_spec.json +14 -0
  20. xinference/model/image/stable_diffusion/core.py +43 -6
  21. xinference/model/llm/__init__.py +0 -2
  22. xinference/model/llm/core.py +3 -2
  23. xinference/model/llm/ggml/llamacpp.py +1 -10
  24. xinference/model/llm/llm_family.json +292 -36
  25. xinference/model/llm/llm_family.py +97 -52
  26. xinference/model/llm/llm_family_modelscope.json +220 -27
  27. xinference/model/llm/pytorch/core.py +0 -80
  28. xinference/model/llm/sglang/core.py +7 -2
  29. xinference/model/llm/utils.py +4 -2
  30. xinference/model/llm/vllm/core.py +3 -0
  31. xinference/model/rerank/core.py +24 -25
  32. xinference/types.py +0 -1
  33. xinference/web/ui/build/asset-manifest.json +3 -3
  34. xinference/web/ui/build/index.html +1 -1
  35. xinference/web/ui/build/static/js/{main.0fb6f3ab.js → main.95c1d652.js} +3 -3
  36. xinference/web/ui/build/static/js/main.95c1d652.js.map +1 -0
  37. xinference/web/ui/node_modules/.cache/babel-loader/07ce9e632e6aff24d7aa3ad8e48224433bbfeb0d633fca723453f1fcae0c9f1c.json +1 -0
  38. xinference/web/ui/node_modules/.cache/babel-loader/40f17338fc75ae095de7d2b4d8eae0d5ca0193a7e2bcece4ee745b22a7a2f4b7.json +1 -0
  39. xinference/web/ui/node_modules/.cache/babel-loader/5262556baf9207738bf6a8ba141ec6599d0a636345c245d61fdf88d3171998cb.json +1 -0
  40. xinference/web/ui/node_modules/.cache/babel-loader/709711edada3f1596b309d571285fd31f1c364d66f4425bc28723d0088cc351a.json +1 -0
  41. xinference/web/ui/node_modules/.cache/babel-loader/70fa8c07463a5fe57c68bf92502910105a8f647371836fe8c3a7408246ca7ba0.json +1 -0
  42. xinference/web/ui/node_modules/.cache/babel-loader/f3e02274cb1964e99b1fe69cbb6db233d3d8d7dd05d50ebcdb8e66d50b224b7b.json +1 -0
  43. {xinference-0.13.0.dist-info → xinference-0.13.2.dist-info}/METADATA +9 -11
  44. {xinference-0.13.0.dist-info → xinference-0.13.2.dist-info}/RECORD +49 -58
  45. xinference/model/llm/ggml/chatglm.py +0 -457
  46. xinference/thirdparty/ChatTTS/__init__.py +0 -1
  47. xinference/thirdparty/ChatTTS/core.py +0 -200
  48. xinference/thirdparty/ChatTTS/experimental/__init__.py +0 -0
  49. xinference/thirdparty/ChatTTS/experimental/llm.py +0 -40
  50. xinference/thirdparty/ChatTTS/infer/__init__.py +0 -0
  51. xinference/thirdparty/ChatTTS/infer/api.py +0 -125
  52. xinference/thirdparty/ChatTTS/model/__init__.py +0 -0
  53. xinference/thirdparty/ChatTTS/model/dvae.py +0 -155
  54. xinference/thirdparty/ChatTTS/model/gpt.py +0 -265
  55. xinference/thirdparty/ChatTTS/utils/__init__.py +0 -0
  56. xinference/thirdparty/ChatTTS/utils/gpu_utils.py +0 -23
  57. xinference/thirdparty/ChatTTS/utils/infer_utils.py +0 -141
  58. xinference/thirdparty/ChatTTS/utils/io_utils.py +0 -14
  59. xinference/web/ui/build/static/js/main.0fb6f3ab.js.map +0 -1
  60. xinference/web/ui/node_modules/.cache/babel-loader/0f6b391abec76271137faad13a3793fe7acc1024e8cd2269c147b653ecd3a73b.json +0 -1
  61. xinference/web/ui/node_modules/.cache/babel-loader/30a0c79d8025d6441eb75b2df5bc2750a14f30119c869ef02570d294dff65c2f.json +0 -1
  62. xinference/web/ui/node_modules/.cache/babel-loader/40486e655c3c5801f087e2cf206c0b5511aaa0dfdba78046b7181bf9c17e54c5.json +0 -1
  63. xinference/web/ui/node_modules/.cache/babel-loader/b5507cd57f16a3a230aa0128e39fe103e928de139ea29e2679e4c64dcbba3b3a.json +0 -1
  64. xinference/web/ui/node_modules/.cache/babel-loader/d779b915f83f9c7b5a72515b6932fdd114f1822cef90ae01cc0d12bca59abc2d.json +0 -1
  65. xinference/web/ui/node_modules/.cache/babel-loader/d87824cb266194447a9c0c69ebab2d507bfc3e3148976173760d18c035e9dd26.json +0 -1
  66. /xinference/web/ui/build/static/js/{main.0fb6f3ab.js.LICENSE.txt → main.95c1d652.js.LICENSE.txt} +0 -0
  67. {xinference-0.13.0.dist-info → xinference-0.13.2.dist-info}/LICENSE +0 -0
  68. {xinference-0.13.0.dist-info → xinference-0.13.2.dist-info}/WHEEL +0 -0
  69. {xinference-0.13.0.dist-info → xinference-0.13.2.dist-info}/entry_points.txt +0 -0
  70. {xinference-0.13.0.dist-info → xinference-0.13.2.dist-info}/top_level.txt +0 -0
xinference/thirdparty/ChatTTS/experimental/llm.py
@@ -1,40 +0,0 @@
-
- from openai import OpenAI
-
- prompt_dict = {
-     'kimi': [ {"role": "system", "content": "你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。"},
-         {"role": "user", "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。"},
-         {"role": "assistant", "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。"},],
-     'deepseek': [
-         {"role": "system", "content": "You are a helpful assistant"},
-         {"role": "user", "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。"},
-         {"role": "assistant", "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。"},],
-     'deepseek_TN': [
-         {"role": "system", "content": "You are a helpful assistant"},
-         {"role": "user", "content": "你好,现在我们在处理TTS的文本输入,下面将会给你输入一段文本,请你将其中的阿拉伯数字等等转为文字表达,并且输出的文本里仅包含逗号和句号这两个标点符号"},
-         {"role": "assistant", "content": "好的,我现在对TTS的文本输入进行处理。这一般叫做text normalization。下面请输入"},
-         {"role": "user", "content": "We paid $123 for this desk."},
-         {"role": "assistant", "content": "We paid one hundred and twenty three dollars for this desk."},
-         {"role": "user", "content": "详询请拨打010-724654"},
-         {"role": "assistant", "content": "详询请拨打零幺零,七二四六五四"},
-         {"role": "user", "content": "罗森宣布将于7月24日退市,在华门店超6000家!"},
-         {"role": "assistant", "content": "罗森宣布将于七月二十四日退市,在华门店超过六千家。"},
-     ],
- }
-
- class llm_api:
-     def __init__(self, api_key, base_url, model):
-         self.client = OpenAI(
-             api_key = api_key,
-             base_url = base_url,
-         )
-         self.model = model
-     def call(self, user_question, temperature = 0.3, prompt_version='kimi', **kwargs):
-
-         completion = self.client.chat.completions.create(
-             model = self.model,
-             messages = prompt_dict[prompt_version]+[{"role": "user", "content": user_question},],
-             temperature = temperature,
-             **kwargs
-         )
-         return completion.choices[0].message.content
File without changes
xinference/thirdparty/ChatTTS/infer/api.py
@@ -1,125 +0,0 @@
-
- import torch
- import torch.nn.functional as F
- from transformers.generation import TopKLogitsWarper, TopPLogitsWarper
- from ..utils.infer_utils import CustomRepetitionPenaltyLogitsProcessorRepeat
-
- def infer_code(
-     models,
-     text,
-     spk_emb = None,
-     top_P = 0.7,
-     top_K = 20,
-     temperature = 0.3,
-     repetition_penalty = 1.05,
-     max_new_token = 2048,
-     **kwargs
- ):
-
-     device = next(models['gpt'].parameters()).device
-
-     if not isinstance(text, list):
-         text = [text]
-
-     if not isinstance(temperature, list):
-         temperature = [temperature] * models['gpt'].num_vq
-
-     if spk_emb is not None:
-         text = [f'[Stts][spk_emb]{i}[Ptts]' for i in text]
-     else:
-         text = [f'[Stts][empty_spk]{i}[Ptts]' for i in text]
-
-     text_token = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True).to(device)
-     input_ids = text_token['input_ids'][...,None].expand(-1, -1, models['gpt'].num_vq)
-     text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=device)
-
-     inputs = {
-         'input_ids': input_ids,
-         'text_mask': text_mask,
-         'attention_mask': text_token['attention_mask'],
-     }
-
-     emb = models['gpt'].get_emb(**inputs)
-     if spk_emb is not None:
-         emb[inputs['input_ids'][..., 0] == models['tokenizer'].convert_tokens_to_ids('[spk_emb]')] = \
-             F.normalize(spk_emb.to(device).to(emb.dtype)[None].expand(len(text), -1), p=2.0, dim=1, eps=1e-12)
-
-     num_code = models['gpt'].emb_code[0].num_embeddings - 1
-
-     LogitsWarpers = []
-     if top_P is not None:
-         LogitsWarpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
-     if top_K is not None:
-         LogitsWarpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))
-
-     LogitsProcessors = []
-     if repetition_penalty is not None and repetition_penalty != 1:
-         LogitsProcessors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(\
-             repetition_penalty, num_code, 16))
-
-     result = models['gpt'].generate(
-         emb, inputs['input_ids'],
-         temperature = torch.tensor(temperature, device=device),
-         attention_mask = inputs['attention_mask'],
-         LogitsWarpers = LogitsWarpers,
-         LogitsProcessors = LogitsProcessors,
-         eos_token = num_code,
-         max_new_token = max_new_token,
-         infer_text = False,
-         **kwargs
-     )
-
-     return result
-
-
- def refine_text(
-     models,
-     text,
-     top_P = 0.7,
-     top_K = 20,
-     temperature = 0.7,
-     repetition_penalty = 1.0,
-     max_new_token = 384,
-     prompt = '',
-     **kwargs
- ):
-
-     device = next(models['gpt'].parameters()).device
-
-     if not isinstance(text, list):
-         text = [text]
-
-     assert len(text), 'text should not be empty'
-
-     text = [f"[Sbreak]{i}[Pbreak]{prompt}" for i in text]
-     text_token = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True).to(device)
-     text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=device)
-
-     inputs = {
-         'input_ids': text_token['input_ids'][...,None].expand(-1, -1, models['gpt'].num_vq),
-         'text_mask': text_mask,
-         'attention_mask': text_token['attention_mask'],
-     }
-
-     LogitsWarpers = []
-     if top_P is not None:
-         LogitsWarpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
-     if top_K is not None:
-         LogitsWarpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))
-
-     LogitsProcessors = []
-     if repetition_penalty is not None and repetition_penalty != 1:
-         LogitsProcessors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(repetition_penalty, len(models['tokenizer']), 16))
-
-     result = models['gpt'].generate(
-         models['gpt'].get_emb(**inputs), inputs['input_ids'],
-         temperature = torch.tensor([temperature,], device=device),
-         attention_mask = inputs['attention_mask'],
-         LogitsWarpers = LogitsWarpers,
-         LogitsProcessors = LogitsProcessors,
-         eos_token = torch.tensor(models['tokenizer'].convert_tokens_to_ids('[Ebreak]'), device=device)[None],
-         max_new_token = max_new_token,
-         infer_text = True,
-         **kwargs
-     )
-     return result
File without changes
xinference/thirdparty/ChatTTS/model/dvae.py
@@ -1,155 +0,0 @@
- import math
- from einops import rearrange
- from vector_quantize_pytorch import GroupedResidualFSQ
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
- class ConvNeXtBlock(nn.Module):
-     def __init__(
-         self,
-         dim: int,
-         intermediate_dim: int,
-         kernel, dilation,
-         layer_scale_init_value: float = 1e-6,
-     ):
-         # ConvNeXt Block copied from Vocos.
-         super().__init__()
-         self.dwconv = nn.Conv1d(dim, dim,
-             kernel_size=kernel, padding=dilation*(kernel//2),
-             dilation=dilation, groups=dim
-         ) # depthwise conv
-
-         self.norm = nn.LayerNorm(dim, eps=1e-6)
-         self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
-         self.act = nn.GELU()
-         self.pwconv2 = nn.Linear(intermediate_dim, dim)
-         self.gamma = (
-             nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
-             if layer_scale_init_value > 0
-             else None
-         )
-
-     def forward(self, x: torch.Tensor, cond = None) -> torch.Tensor:
-         residual = x
-         x = self.dwconv(x)
-         x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
-         x = self.norm(x)
-         x = self.pwconv1(x)
-         x = self.act(x)
-         x = self.pwconv2(x)
-         if self.gamma is not None:
-             x = self.gamma * x
-         x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
-
-         x = residual + x
-         return x
-
-
-
- class GFSQ(nn.Module):
-
-     def __init__(self,
-         dim, levels, G, R, eps=1e-5, transpose = True
-     ):
-         super(GFSQ, self).__init__()
-         self.quantizer = GroupedResidualFSQ(
-             dim=dim,
-             levels=levels,
-             num_quantizers=R,
-             groups=G,
-         )
-         self.n_ind = math.prod(levels)
-         self.eps = eps
-         self.transpose = transpose
-         self.G = G
-         self.R = R
-
-     def _embed(self, x):
-         if self.transpose:
-             x = x.transpose(1,2)
-         x = rearrange(
-             x, "b t (g r) -> g b t r", g = self.G, r = self.R,
-         )
-         feat = self.quantizer.get_output_from_indices(x)
-         return feat.transpose(1,2) if self.transpose else feat
-
-     def forward(self, x,):
-         if self.transpose:
-             x = x.transpose(1,2)
-         feat, ind = self.quantizer(x)
-         ind = rearrange(
-             ind, "g b t r ->b t (g r)",
-         )
-         embed_onehot = F.one_hot(ind.long(), self.n_ind).to(x.dtype)
-         e_mean = torch.mean(embed_onehot, dim=[0,1])
-         e_mean = e_mean / (e_mean.sum(dim=1) + self.eps).unsqueeze(1)
-         perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + self.eps), dim=1))
-
-         return (
-             torch.zeros(perplexity.shape, dtype=x.dtype, device=x.device),
-             feat.transpose(1,2) if self.transpose else feat,
-             perplexity,
-             None,
-             ind.transpose(1,2) if self.transpose else ind,
-         )
-
- class DVAEDecoder(nn.Module):
-     def __init__(self, idim, odim,
-         n_layer = 12, bn_dim = 64, hidden = 256,
-         kernel = 7, dilation = 2, up = False
-     ):
-         super().__init__()
-         self.up = up
-         self.conv_in = nn.Sequential(
-             nn.Conv1d(idim, bn_dim, 3, 1, 1), nn.GELU(),
-             nn.Conv1d(bn_dim, hidden, 3, 1, 1)
-         )
-         self.decoder_block = nn.ModuleList([
-             ConvNeXtBlock(hidden, hidden* 4, kernel, dilation,)
-             for _ in range(n_layer)])
-         self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False)
-
-     def forward(self, input, conditioning=None):
-         # B, T, C
-         x = input.transpose(1, 2)
-         x = self.conv_in(x)
-         for f in self.decoder_block:
-             x = f(x, conditioning)
-
-         x = self.conv_out(x)
-         return x.transpose(1, 2)
-
-
- class DVAE(nn.Module):
-     def __init__(
-         self, decoder_config, vq_config, dim=512
-     ):
-         super().__init__()
-         self.register_buffer('coef', torch.randn(1, 100, 1))
-
-         self.decoder = DVAEDecoder(**decoder_config)
-         self.out_conv = nn.Conv1d(dim, 100, 3, 1, 1, bias=False)
-         if vq_config is not None:
-             self.vq_layer = GFSQ(**vq_config)
-         else:
-             self.vq_layer = None
-
-     def forward(self, inp):
-
-         if self.vq_layer is not None:
-             vq_feats = self.vq_layer._embed(inp)
-         else:
-             vq_feats = inp.detach().clone()
-
-         temp = torch.chunk(vq_feats, 2, dim=1) # flatten trick :)
-         temp = torch.stack(temp, -1)
-         vq_feats = temp.reshape(*temp.shape[:2], -1)
-
-         vq_feats = vq_feats.transpose(1, 2)
-         dec_out = self.decoder(input=vq_feats)
-         dec_out = self.out_conv(dec_out.transpose(1, 2))
-         mel = dec_out * self.coef
-
-         return mel
xinference/thirdparty/ChatTTS/model/gpt.py
@@ -1,265 +0,0 @@
- import os
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
- import logging
- from tqdm import tqdm
- from einops import rearrange
- from transformers.cache_utils import Cache
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torch.nn.utils.parametrize as P
- from torch.nn.utils.parametrizations import weight_norm
- from transformers import LlamaModel, LlamaConfig
-
-
- class LlamaMLP(nn.Module):
-     def __init__(self, hidden_size, intermediate_size):
-         super().__init__()
-         self.hidden_size = hidden_size
-         self.intermediate_size = intermediate_size
-         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
-         self.act_fn = F.silu
-
-     def forward(self, x):
-         down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-         return down_proj
-
-
- class GPT_warpper(nn.Module):
-     def __init__(
-         self,
-         gpt_config,
-         num_audio_tokens,
-         num_text_tokens,
-         num_vq=4,
-         **kwargs,
-     ):
-         super().__init__()
-
-         self.logger = logging.getLogger(__name__)
-         self.gpt = self.build_model(gpt_config)
-         self.model_dim = self.gpt.config.hidden_size
-
-         self.num_vq = num_vq
-         self.emb_code = nn.ModuleList([nn.Embedding(num_audio_tokens, self.model_dim) for i in range(self.num_vq)])
-         self.emb_text = nn.Embedding(num_text_tokens, self.model_dim)
-         self.head_text = weight_norm(nn.Linear(self.model_dim, num_text_tokens, bias=False), name='weight')
-         self.head_code = nn.ModuleList([weight_norm(nn.Linear(self.model_dim, num_audio_tokens, bias=False), name='weight') for i in range(self.num_vq)])
-
-     def build_model(self, config):
-
-         configuration = LlamaConfig(**config)
-         model = LlamaModel(configuration)
-         del model.embed_tokens
-
-         return model
-
-     def get_emb(self, input_ids, text_mask, **kwargs):
-
-         emb_text = self.emb_text(input_ids[text_mask][:, 0])
-
-         emb_code = [self.emb_code[i](input_ids[~text_mask][:, i]) for i in range(self.num_vq)]
-         emb_code = torch.stack(emb_code, 2).sum(2)
-
-         emb = torch.zeros((input_ids.shape[:-1])+(emb_text.shape[-1],), device=emb_text.device, dtype=emb_text.dtype)
-         emb[text_mask] = emb_text
-         emb[~text_mask] = emb_code.to(emb.dtype)
-
-         return emb
-
-     def prepare_inputs_for_generation(
-         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
-     ):
-         # With static cache, the `past_key_values` is None
-         # TODO joao: standardize interface for the different Cache classes and remove of this if
-         has_static_cache = False
-         if past_key_values is None:
-             past_key_values = getattr(self.gpt.layers[0].self_attn, "past_key_value", None)
-             has_static_cache = past_key_values is not None
-
-         past_length = 0
-         if past_key_values is not None:
-             if isinstance(past_key_values, Cache):
-                 past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
-                 max_cache_length = (
-                     torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
-                     if past_key_values.get_max_length() is not None
-                     else None
-                 )
-                 cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
-             # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
-             else:
-                 cache_length = past_length = past_key_values[0][0].shape[2]
-                 max_cache_length = None
-
-             # Keep only the unprocessed tokens:
-             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-             # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
-             # input)
-             if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
-                 input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
-             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
-             # input_ids based on the past_length.
-             elif past_length < input_ids.shape[1]:
-                 input_ids = input_ids[:, past_length:]
-             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
-
-             # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
-             if (
-                 max_cache_length is not None
-                 and attention_mask is not None
-                 and cache_length + input_ids.shape[1] > max_cache_length
-             ):
-                 attention_mask = attention_mask[:, -max_cache_length:]
-
-         position_ids = kwargs.get("position_ids", None)
-         if attention_mask is not None and position_ids is None:
-             # create position_ids on the fly for batch generation
-             position_ids = attention_mask.long().cumsum(-1) - 1
-             position_ids.masked_fill_(attention_mask == 0, 1)
-             if past_key_values:
-                 position_ids = position_ids[:, -input_ids.shape[1] :]
-
-         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-         if inputs_embeds is not None and past_key_values is None:
-             model_inputs = {"inputs_embeds": inputs_embeds}
-         else:
-             # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
-             # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
-             # TODO: use `next_tokens` directly instead.
-             model_inputs = {"input_ids": input_ids.contiguous()}
-
-         input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
-         if cache_position is None:
-             cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
-         else:
-             cache_position = cache_position[-input_length:]
-
-         if has_static_cache:
-             past_key_values = None
-
-         model_inputs.update(
-             {
-                 "position_ids": position_ids,
-                 "cache_position": cache_position,
-                 "past_key_values": past_key_values,
-                 "use_cache": kwargs.get("use_cache"),
-                 "attention_mask": attention_mask,
-             }
-         )
-         return model_inputs
-
-     def generate(
-         self,
-         emb,
-         inputs_ids,
-         temperature,
-         eos_token,
-         attention_mask = None,
-         max_new_token = 2048,
-         min_new_token = 0,
-         LogitsWarpers = [],
-         LogitsProcessors = [],
-         infer_text=False,
-         return_attn=False,
-         return_hidden=False,
-     ):
-
-         with torch.no_grad():
-
-             attentions = []
-             hiddens = []
-
-             start_idx, end_idx = inputs_ids.shape[1], torch.zeros(inputs_ids.shape[0], device=inputs_ids.device, dtype=torch.long)
-             finish = torch.zeros(inputs_ids.shape[0], device=inputs_ids.device).bool()
-
-             temperature = temperature[None].expand(inputs_ids.shape[0], -1)
-             temperature = rearrange(temperature, "b n -> (b n) 1")
-
-             attention_mask_cache = torch.ones((inputs_ids.shape[0], inputs_ids.shape[1]+max_new_token,), dtype=torch.bool, device=inputs_ids.device)
-             if attention_mask is not None:
-                 attention_mask_cache[:, :attention_mask.shape[1]] = attention_mask
-
-             for i in tqdm(range(max_new_token)):
-
-                 model_input = self.prepare_inputs_for_generation(inputs_ids,
-                     outputs.past_key_values if i!=0 else None,
-                     attention_mask_cache[:, :inputs_ids.shape[1]], use_cache=True)
-
-                 if i == 0:
-                     model_input['inputs_embeds'] = emb
-                 else:
-                     if infer_text:
-                         model_input['inputs_embeds'] = self.emb_text(model_input['input_ids'][:,:,0])
-                     else:
-                         code_emb = [self.emb_code[i](model_input['input_ids'][:,:,i]) for i in range(self.num_vq)]
-                         model_input['inputs_embeds'] = torch.stack(code_emb, 3).sum(3)
-
-                 model_input['input_ids'] = None
-                 outputs = self.gpt.forward(**model_input, output_attentions=return_attn)
-                 attentions.append(outputs.attentions)
-                 hidden_states = outputs[0] # 🐻
-                 if return_hidden:
-                     hiddens.append(hidden_states[:, -1])
-
-                 with P.cached():
-                     if infer_text:
-                         logits = self.head_text(hidden_states)
-                     else:
-                         logits = torch.stack([self.head_code[i](hidden_states) for i in range(self.num_vq)], 3)
-
-                 logits = logits[:, -1].float()
-
-                 if not infer_text:
-                     logits = rearrange(logits, "b c n -> (b n) c")
-                     logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c")
-                 else:
-                     logits_token = inputs_ids[:, start_idx:, 0]
-
-                 logits = logits / temperature
-
-                 for logitsProcessors in LogitsProcessors:
-                     logits = logitsProcessors(logits_token, logits)
-
-                 for logitsWarpers in LogitsWarpers:
-                     logits = logitsWarpers(logits_token, logits)
-
-                 if i < min_new_token:
-                     logits[:, eos_token] = -torch.inf
-
-                 scores = F.softmax(logits, dim=-1)
-
-                 idx_next = torch.multinomial(scores, num_samples=1)
-
-                 if not infer_text:
-                     idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
-                     finish = finish | (idx_next == eos_token).any(1)
-                     inputs_ids = torch.cat([inputs_ids, idx_next.unsqueeze(1)], 1)
-                 else:
-                     finish = finish | (idx_next == eos_token).any(1)
-                     inputs_ids = torch.cat([inputs_ids, idx_next.unsqueeze(-1).expand(-1, -1, self.num_vq)], 1)
-
-                 end_idx = end_idx + (~finish).int()
-
-                 if finish.all():
-                     break
-
-             inputs_ids = [inputs_ids[idx, start_idx: start_idx+i] for idx, i in enumerate(end_idx.int())]
-             inputs_ids = [i[:, 0] for i in inputs_ids] if infer_text else inputs_ids
-
-             if return_hidden:
-                 hiddens = torch.stack(hiddens, 1)
-                 hiddens = [hiddens[idx, :i] for idx, i in enumerate(end_idx.int())]
-
-             if not finish.all():
-                 self.logger.warn(f'Incomplete result. hit max_new_token: {max_new_token}')
-
-             return {
-                 'ids': inputs_ids,
-                 'attentions': attentions,
-                 'hiddens':hiddens,
-             }
File without changes
xinference/thirdparty/ChatTTS/utils/gpu_utils.py
@@ -1,23 +0,0 @@
-
- import torch
- import logging
-
- def select_device(min_memory = 2048):
-     logger = logging.getLogger(__name__)
-     if torch.cuda.is_available():
-         available_gpus = []
-         for i in range(torch.cuda.device_count()):
-             props = torch.cuda.get_device_properties(i)
-             free_memory = props.total_memory - torch.cuda.memory_reserved(i)
-             available_gpus.append((i, free_memory))
-         selected_gpu, max_free_memory = max(available_gpus, key=lambda x: x[1])
-         device = torch.device(f'cuda:{selected_gpu}')
-         free_memory_mb = max_free_memory / (1024 * 1024)
-         if free_memory_mb < min_memory:
-             logger.log(logging.WARNING, f'GPU {selected_gpu} has {round(free_memory_mb, 2)} MB memory left.')
-             device = torch.device('cpu')
-     else:
-         logger.log(logging.WARNING, f'No GPU found, use CPU instead')
-         device = torch.device('cpu')
-
-     return device