xinference 0.15.4__py3-none-any.whl → 0.16.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Note: this release of xinference has been marked as potentially problematic.
- xinference/__init__.py +0 -4
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +48 -0
- xinference/client/restful/restful_client.py +19 -0
- xinference/constants.py +4 -4
- xinference/core/chat_interface.py +5 -1
- xinference/core/image_interface.py +5 -1
- xinference/core/model.py +195 -34
- xinference/core/scheduler.py +10 -7
- xinference/core/utils.py +9 -0
- xinference/model/__init__.py +4 -0
- xinference/model/audio/chattts.py +25 -14
- xinference/model/audio/model_spec.json +1 -1
- xinference/model/audio/model_spec_modelscope.json +1 -1
- xinference/model/embedding/model_spec.json +1 -1
- xinference/model/image/core.py +59 -4
- xinference/model/image/model_spec.json +24 -3
- xinference/model/image/model_spec_modelscope.json +25 -3
- xinference/model/image/ocr/__init__.py +13 -0
- xinference/model/image/ocr/got_ocr2.py +76 -0
- xinference/model/image/scheduler/__init__.py +13 -0
- xinference/model/image/scheduler/flux.py +533 -0
- xinference/model/image/stable_diffusion/core.py +8 -34
- xinference/model/image/stable_diffusion/mlx.py +221 -0
- xinference/model/image/utils.py +39 -3
- xinference/model/llm/__init__.py +2 -0
- xinference/model/llm/llm_family.json +178 -1
- xinference/model/llm/llm_family_modelscope.json +119 -0
- xinference/model/llm/transformers/chatglm.py +104 -0
- xinference/model/llm/transformers/core.py +37 -111
- xinference/model/llm/transformers/deepseek_v2.py +0 -226
- xinference/model/llm/transformers/internlm2.py +3 -95
- xinference/model/llm/transformers/opt.py +68 -0
- xinference/model/llm/transformers/utils.py +4 -284
- xinference/model/llm/utils.py +2 -2
- xinference/model/llm/vllm/core.py +16 -1
- xinference/thirdparty/mlx/__init__.py +13 -0
- xinference/thirdparty/mlx/flux/__init__.py +15 -0
- xinference/thirdparty/mlx/flux/autoencoder.py +357 -0
- xinference/thirdparty/mlx/flux/clip.py +154 -0
- xinference/thirdparty/mlx/flux/datasets.py +75 -0
- xinference/thirdparty/mlx/flux/flux.py +247 -0
- xinference/thirdparty/mlx/flux/layers.py +302 -0
- xinference/thirdparty/mlx/flux/lora.py +76 -0
- xinference/thirdparty/mlx/flux/model.py +134 -0
- xinference/thirdparty/mlx/flux/sampler.py +56 -0
- xinference/thirdparty/mlx/flux/t5.py +244 -0
- xinference/thirdparty/mlx/flux/tokenizers.py +185 -0
- xinference/thirdparty/mlx/flux/trainer.py +98 -0
- xinference/thirdparty/mlx/flux/utils.py +179 -0
- xinference/utils.py +2 -3
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.e51a356d.js → main.b76aeeb7.js} +3 -3
- xinference/web/ui/build/static/js/main.b76aeeb7.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/32ea2c04cf0bba2761b4883d2c40cc259952c94d2d6bb774e510963ca37aac0a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/331312668fa8bd3d7401818f4a25fa98135d7f61371cd6bfff78b18cf4fbdd92.json +1 -0
- {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/METADATA +49 -10
- {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/RECORD +64 -44
- xinference/web/ui/build/static/js/main.e51a356d.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/4385c1095eefbff0a8ec3b2964ba6e5a66a05ab31be721483ca2f43e2a91f6ff.json +0 -1
- /xinference/web/ui/build/static/js/{main.e51a356d.js.LICENSE.txt → main.b76aeeb7.js.LICENSE.txt} +0 -0
- {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/LICENSE +0 -0
- {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/WHEEL +0 -0
- {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/top_level.txt +0 -0
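Most of the new code in this release backs an MLX-based FLUX image pipeline (xinference/model/image/stable_diffusion/mlx.py plus the vendored xinference/thirdparty/mlx/flux package), a GOT-OCR2 image model, and a dedicated FLUX request scheduler. As rough orientation only, the sketch below shows how an image model is usually driven through the existing RESTful client; the endpoint, model name, and generation parameters are placeholders and may differ for the new backends.

# Hypothetical usage sketch, not part of this diff: launching an image model
# through the pre-existing xinference RESTful client API.
from xinference.client import Client

client = Client("http://127.0.0.1:9997")      # placeholder endpoint of a running server
model_uid = client.launch_model(
    model_name="FLUX.1-schnell",              # placeholder; any registered image model
    model_type="image",
)
model = client.get_model(model_uid)

# text_to_image follows the existing image API; the new MLX backend may expose
# additional or slightly different parameters.
result = model.text_to_image(prompt="a watercolor fox", n=1, size="1024*1024")
print(result["data"][0]["url"])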
xinference/thirdparty/mlx/flux/trainer.py
ADDED
@@ -0,0 +1,98 @@
+import mlx.core as mx
+import numpy as np
+from PIL import Image, ImageFile
+from tqdm import tqdm
+
+from .datasets import Dataset
+from .flux import FluxPipeline
+
+
+class Trainer:
+
+    def __init__(self, flux: FluxPipeline, dataset: Dataset, args):
+        self.flux = flux
+        self.dataset = dataset
+        self.args = args
+        self.latents = []
+        self.t5_features = []
+        self.clip_features = []
+
+    def _random_crop_resize(self, img):
+        resolution = self.args.resolution
+        width, height = img.size
+
+        a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()
+
+        # Random crop the input image between 0.8 to 1.0 of its original dimensions
+        crop_size = (
+            max((0.8 + 0.2 * a) * width, resolution[0]),
+            max((0.8 + 0.2 * b) * height, resolution[1]),
+        )
+        pan = (width - crop_size[0], height - crop_size[1])
+        img = img.crop(
+            (
+                pan[0] * c,
+                pan[1] * d,
+                crop_size[0] + pan[0] * c,
+                crop_size[1] + pan[1] * d,
+            )
+        )
+
+        # Fit the largest rectangle with the ratio of resolution in the image
+        # rectangle.
+        width, height = crop_size
+        ratio = resolution[0] / resolution[1]
+        r1 = (height * ratio, height)
+        r2 = (width, width / ratio)
+        r = r1 if r1[0] <= width else r2
+        img = img.crop(
+            (
+                (width - r[0]) / 2,
+                (height - r[1]) / 2,
+                (width + r[0]) / 2,
+                (height + r[1]) / 2,
+            )
+        )
+
+        # Finally resize the image to resolution
+        img = img.resize(resolution, Image.LANCZOS)
+
+        return mx.array(np.array(img))
+
+    def _encode_image(self, input_img: ImageFile.ImageFile, num_augmentations: int):
+        for i in range(num_augmentations):
+            img = self._random_crop_resize(input_img)
+            img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
+            x_0 = self.flux.ae.encode(img[None])
+            x_0 = x_0.astype(self.flux.dtype)
+            mx.eval(x_0)
+            self.latents.append(x_0)
+
+    def _encode_prompt(self, prompt):
+        t5_tok, clip_tok = self.flux.tokenize([prompt])
+        t5_feat = self.flux.t5(t5_tok)
+        clip_feat = self.flux.clip(clip_tok).pooled_output
+        mx.eval(t5_feat, clip_feat)
+        self.t5_features.append(t5_feat)
+        self.clip_features.append(clip_feat)
+
+    def encode_dataset(self):
+        """Encode the images & prompt in the latent space to prepare for training."""
+        self.flux.ae.eval()
+        for image, prompt in tqdm(self.dataset, desc="encode dataset"):
+            self._encode_image(image, self.args.num_augmentations)
+            self._encode_prompt(prompt)
+
+    def iterate(self, batch_size):
+        xs = mx.concatenate(self.latents)
+        t5 = mx.concatenate(self.t5_features)
+        clip = mx.concatenate(self.clip_features)
+        mx.eval(xs, t5, clip)
+        n_aug = self.args.num_augmentations
+        while True:
+            x_indices = mx.random.permutation(len(self.latents))
+            c_indices = x_indices // n_aug
+            for i in range(0, len(self.latents), batch_size):
+                x_i = x_indices[i : i + batch_size]
+                c_i = c_indices[i : i + batch_size]
+                yield xs[x_i], t5[c_i], clip[c_i]
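The Trainer above pre-encodes every training image and prompt into VAE latents and T5/CLIP features, then serves shuffled batches indefinitely from iterate(). A minimal sketch of how it is meant to be driven; the FluxPipeline and Dataset constructors are assumptions inferred from the imports, while the args attributes mirror exactly what the class reads (resolution, num_augmentations):

# Sketch only: consuming the Trainer defined in trainer.py. Constructor
# arguments for FluxPipeline/Dataset are assumptions; encode_dataset() and
# iterate() are used exactly as defined above.
from types import SimpleNamespace

from xinference.thirdparty.mlx.flux.datasets import Dataset
from xinference.thirdparty.mlx.flux.flux import FluxPipeline
from xinference.thirdparty.mlx.flux.trainer import Trainer

args = SimpleNamespace(resolution=(512, 512), num_augmentations=4)
flux = FluxPipeline("flux-schnell")         # assumed: pipeline keyed by config name
dataset = Dataset("./training_images")      # assumed: image/prompt pairs on disk

trainer = Trainer(flux, dataset, args)
trainer.encode_dataset()                    # one-time encoding into latent space

# iterate() is an endless generator of (latents, t5_features, clip_features)
for step, (x, t5, clip) in zip(range(1000), trainer.iterate(batch_size=1)):
    pass                                    # a LoRA optimization step would go here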
xinference/thirdparty/mlx/flux/utils.py
ADDED
@@ -0,0 +1,179 @@
+# Copyright © 2024 Apple Inc.
+
+import json
+import os
+from dataclasses import dataclass
+from typing import Optional
+
+import mlx.core as mx
+
+from .autoencoder import AutoEncoder, AutoEncoderParams
+from .clip import CLIPTextModel, CLIPTextModelConfig
+from .model import Flux, FluxParams
+from .t5 import T5Config, T5Encoder
+from .tokenizers import CLIPTokenizer, T5Tokenizer
+
+
+@dataclass
+class ModelSpec:
+    params: FluxParams
+    ae_params: AutoEncoderParams
+    ckpt_path: Optional[str]
+    ae_path: Optional[str]
+    repo_id: Optional[str]
+    repo_flow: Optional[str]
+    repo_ae: Optional[str]
+
+
+configs = {
+    "flux-dev": ModelSpec(
+        repo_id="black-forest-labs/FLUX.1-dev",
+        repo_flow="flux1-dev.safetensors",
+        repo_ae="ae.safetensors",
+        ckpt_path=os.getenv("FLUX_DEV"),
+        params=FluxParams(
+            in_channels=64,
+            vec_in_dim=768,
+            context_in_dim=4096,
+            hidden_size=3072,
+            mlp_ratio=4.0,
+            num_heads=24,
+            depth=19,
+            depth_single_blocks=38,
+            axes_dim=[16, 56, 56],
+            theta=10_000,
+            qkv_bias=True,
+            guidance_embed=True,
+        ),
+        ae_path=os.getenv("AE"),
+        ae_params=AutoEncoderParams(
+            resolution=256,
+            in_channels=3,
+            ch=128,
+            out_ch=3,
+            ch_mult=[1, 2, 4, 4],
+            num_res_blocks=2,
+            z_channels=16,
+            scale_factor=0.3611,
+            shift_factor=0.1159,
+        ),
+    ),
+    "flux-schnell": ModelSpec(
+        repo_id="black-forest-labs/FLUX.1-schnell",
+        repo_flow="flux1-schnell.safetensors",
+        repo_ae="ae.safetensors",
+        ckpt_path=os.getenv("FLUX_SCHNELL"),
+        params=FluxParams(
+            in_channels=64,
+            vec_in_dim=768,
+            context_in_dim=4096,
+            hidden_size=3072,
+            mlp_ratio=4.0,
+            num_heads=24,
+            depth=19,
+            depth_single_blocks=38,
+            axes_dim=[16, 56, 56],
+            theta=10_000,
+            qkv_bias=True,
+            guidance_embed=False,
+        ),
+        ae_path=os.getenv("AE"),
+        ae_params=AutoEncoderParams(
+            resolution=256,
+            in_channels=3,
+            ch=128,
+            out_ch=3,
+            ch_mult=[1, 2, 4, 4],
+            num_res_blocks=2,
+            z_channels=16,
+            scale_factor=0.3611,
+            shift_factor=0.1159,
+        ),
+    ),
+}
+
+
+def load_flow_model(name: str, ckpt_path: str):
+    # Make the model
+    model = Flux(configs[name].params)
+
+    # Load the checkpoint if needed
+    if os.path.isdir(ckpt_path):
+        ckpt_path = os.path.join(ckpt_path, configs[name].repo_flow)
+    weights = mx.load(ckpt_path)
+    weights = model.sanitize(weights)
+    model.load_weights(list(weights.items()))
+
+    return model
+
+
+def load_ae(name: str, ckpt_path: str):
+    # Make the autoencoder
+    ae = AutoEncoder(configs[name].ae_params)
+
+    # Load the checkpoint if needed
+    ckpt_path = os.path.join(ckpt_path, "ae.safetensors")
+    weights = mx.load(ckpt_path)
+    weights = ae.sanitize(weights)
+    ae.load_weights(list(weights.items()))
+
+    return ae
+
+
+def load_clip(name: str, ckpt_path: str):
+    config_path = os.path.join(ckpt_path, "text_encoder/config.json")
+    with open(config_path) as f:
+        config = CLIPTextModelConfig.from_dict(json.load(f))
+
+    # Make the clip text encoder
+    clip = CLIPTextModel(config)
+
+    ckpt_path = os.path.join(ckpt_path, "text_encoder/model.safetensors")
+    weights = mx.load(ckpt_path)
+    weights = clip.sanitize(weights)
+    clip.load_weights(list(weights.items()))
+
+    return clip
+
+
+def load_t5(name: str, ckpt_path: str):
+    config_path = os.path.join(ckpt_path, "text_encoder_2/config.json")
+    with open(config_path) as f:
+        config = T5Config.from_dict(json.load(f))
+
+    # Make the T5 model
+    t5 = T5Encoder(config)
+
+    model_index = os.path.join(ckpt_path, "text_encoder_2/model.safetensors.index.json")
+    weight_files = set()
+    with open(model_index) as f:
+        for _, w in json.load(f)["weight_map"].items():
+            weight_files.add(w)
+    weights = {}
+    for w in weight_files:
+        w = f"text_encoder_2/{w}"
+        w = os.path.join(ckpt_path, w)
+        weights.update(mx.load(w))
+    weights = t5.sanitize(weights)
+    t5.load_weights(list(weights.items()))
+
+    return t5
+
+
+def load_clip_tokenizer(name: str, ckpt_path: str):
+    vocab_file = os.path.join(ckpt_path, "tokenizer/vocab.json")
+    with open(vocab_file, encoding="utf-8") as f:
+        vocab = json.load(f)
+
+    merges_file = os.path.join(ckpt_path, "tokenizer/merges.txt")
+    with open(merges_file, encoding="utf-8") as f:
+        bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
+    bpe_merges = [tuple(m.split()) for m in bpe_merges]
+    bpe_ranks = dict(map(reversed, enumerate(bpe_merges)))
+
+    return CLIPTokenizer(bpe_ranks, vocab, max_length=77)
+
+
+def load_t5_tokenizer(name: str, ckpt_path: str, pad: bool = True):
+    model_file = os.path.join(ckpt_path, "tokenizer_2/spiece.model")
+    return T5Tokenizer(model_file, 256 if "schnell" in name else 512)
xinference/utils.py
CHANGED

xinference/web/ui/build/asset-manifest.json
CHANGED
@@ -1,14 +1,14 @@
 {
   "files": {
     "main.css": "./static/css/main.5061c4c3.css",
-    "main.js": "./static/js/main.e51a356d.js",
+    "main.js": "./static/js/main.b76aeeb7.js",
     "static/media/icon.webp": "./static/media/icon.4603d52c63041e5dfbfd.webp",
     "index.html": "./index.html",
     "main.5061c4c3.css.map": "./static/css/main.5061c4c3.css.map",
-    "main.e51a356d.js.map": "./static/js/main.e51a356d.js.map"
+    "main.b76aeeb7.js.map": "./static/js/main.b76aeeb7.js.map"
   },
   "entrypoints": [
     "static/css/main.5061c4c3.css",
-    "static/js/main.e51a356d.js"
+    "static/js/main.b76aeeb7.js"
   ]
 }

xinference/web/ui/build/index.html
CHANGED
@@ -1 +1 @@
-<!doctype html><html lang="en"><head><meta charset="utf-8"/><link rel="icon" href="./favicon.svg"/><meta name="viewport" content="width=device-width,initial-scale=1"/><meta name="theme-color" content="#000000"/><meta name="description" content="Web site created using create-react-app"/><link rel="apple-touch-icon" href="./logo192.png"/><link rel="manifest" href="./manifest.json"/><title>Xinference</title><script defer="defer" src="./static/js/main.e51a356d.js"></script><link href="./static/css/main.5061c4c3.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>
+<!doctype html><html lang="en"><head><meta charset="utf-8"/><link rel="icon" href="./favicon.svg"/><meta name="viewport" content="width=device-width,initial-scale=1"/><meta name="theme-color" content="#000000"/><meta name="description" content="Web site created using create-react-app"/><link rel="apple-touch-icon" href="./logo192.png"/><link rel="manifest" href="./manifest.json"/><title>Xinference</title><script defer="defer" src="./static/js/main.b76aeeb7.js"></script><link href="./static/css/main.5061c4c3.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>