xinference 0.16.0__py3-none-any.whl → 0.16.2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release.

This version of xinference might be problematic.

Files changed (62)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +48 -0
  3. xinference/client/restful/restful_client.py +19 -0
  4. xinference/constants.py +1 -0
  5. xinference/core/chat_interface.py +5 -1
  6. xinference/core/image_interface.py +5 -1
  7. xinference/core/model.py +106 -16
  8. xinference/core/scheduler.py +1 -1
  9. xinference/core/worker.py +3 -1
  10. xinference/deploy/supervisor.py +0 -4
  11. xinference/model/audio/chattts.py +25 -14
  12. xinference/model/audio/core.py +6 -2
  13. xinference/model/audio/model_spec.json +1 -1
  14. xinference/model/audio/model_spec_modelscope.json +1 -1
  15. xinference/model/core.py +3 -1
  16. xinference/model/embedding/core.py +6 -2
  17. xinference/model/embedding/model_spec.json +1 -1
  18. xinference/model/image/core.py +65 -6
  19. xinference/model/image/model_spec.json +24 -3
  20. xinference/model/image/model_spec_modelscope.json +25 -3
  21. xinference/model/image/ocr/__init__.py +13 -0
  22. xinference/model/image/ocr/got_ocr2.py +79 -0
  23. xinference/model/image/scheduler/flux.py +1 -1
  24. xinference/model/image/stable_diffusion/core.py +2 -3
  25. xinference/model/image/stable_diffusion/mlx.py +221 -0
  26. xinference/model/llm/__init__.py +33 -0
  27. xinference/model/llm/core.py +3 -1
  28. xinference/model/llm/llm_family.json +9 -0
  29. xinference/model/llm/llm_family.py +68 -2
  30. xinference/model/llm/llm_family_modelscope.json +11 -0
  31. xinference/model/llm/llm_family_openmind_hub.json +1359 -0
  32. xinference/model/rerank/core.py +9 -1
  33. xinference/model/utils.py +7 -0
  34. xinference/model/video/core.py +6 -2
  35. xinference/thirdparty/mlx/__init__.py +13 -0
  36. xinference/thirdparty/mlx/flux/__init__.py +15 -0
  37. xinference/thirdparty/mlx/flux/autoencoder.py +357 -0
  38. xinference/thirdparty/mlx/flux/clip.py +154 -0
  39. xinference/thirdparty/mlx/flux/datasets.py +75 -0
  40. xinference/thirdparty/mlx/flux/flux.py +247 -0
  41. xinference/thirdparty/mlx/flux/layers.py +302 -0
  42. xinference/thirdparty/mlx/flux/lora.py +76 -0
  43. xinference/thirdparty/mlx/flux/model.py +134 -0
  44. xinference/thirdparty/mlx/flux/sampler.py +56 -0
  45. xinference/thirdparty/mlx/flux/t5.py +244 -0
  46. xinference/thirdparty/mlx/flux/tokenizers.py +185 -0
  47. xinference/thirdparty/mlx/flux/trainer.py +98 -0
  48. xinference/thirdparty/mlx/flux/utils.py +179 -0
  49. xinference/web/ui/build/asset-manifest.json +3 -3
  50. xinference/web/ui/build/index.html +1 -1
  51. xinference/web/ui/build/static/js/{main.f7da0140.js → main.2f269bb3.js} +3 -3
  52. xinference/web/ui/build/static/js/main.2f269bb3.js.map +1 -0
  53. xinference/web/ui/node_modules/.cache/babel-loader/1f269fb2a368363c1cb2237825f1dba093b6bdd8c44cc05954fd19ec2c1fff03.json +1 -0
  54. {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/METADATA +16 -9
  55. {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/RECORD +60 -42
  56. xinference/web/ui/build/static/js/main.f7da0140.js.map +0 -1
  57. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +0 -1
  58. /xinference/web/ui/build/static/js/{main.f7da0140.js.LICENSE.txt → main.2f269bb3.js.LICENSE.txt} +0 -0
  59. {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/LICENSE +0 -0
  60. {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/WHEEL +0 -0
  61. {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/entry_points.txt +0 -0
  62. {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/top_level.txt +0 -0
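Of the 62 changed files, four are rendered in full below. The headline additions are an MLX-native FLUX image backend for Apple silicon (xinference/model/image/stable_diffusion/mlx.py plus the vendored xinference/thirdparty/mlx/flux package) and GOT-OCR2 support under xinference/model/image/ocr/. As a minimal sketch of how an image model is driven through the existing RESTful client (the endpoint URL and model name are illustrative assumptions, not part of this diff):

    from xinference.client import Client

    # Assumes a locally running xinference server; adjust the endpoint as needed.
    client = Client("http://127.0.0.1:9997")
    model_uid = client.launch_model(model_name="FLUX.1-schnell", model_type="image")
    model = client.get_model(model_uid)
    # Returns an OpenAI-style image list payload (URLs or base64 data).
    result = model.text_to_image("an astronaut riding a horse")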
xinference/thirdparty/mlx/flux/trainer.py (added)
@@ -0,0 +1,98 @@
+ import mlx.core as mx
+ import numpy as np
+ from PIL import Image, ImageFile
+ from tqdm import tqdm
+
+ from .datasets import Dataset
+ from .flux import FluxPipeline
+
+
+ class Trainer:
+
+     def __init__(self, flux: FluxPipeline, dataset: Dataset, args):
+         self.flux = flux
+         self.dataset = dataset
+         self.args = args
+         self.latents = []
+         self.t5_features = []
+         self.clip_features = []
+
+     def _random_crop_resize(self, img):
+         resolution = self.args.resolution
+         width, height = img.size
+
+         a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()
+
+         # Random crop the input image between 0.8 and 1.0 of its original dimensions
+         crop_size = (
+             max((0.8 + 0.2 * a) * width, resolution[0]),
+             max((0.8 + 0.2 * b) * height, resolution[1]),
+         )
+         pan = (width - crop_size[0], height - crop_size[1])
+         img = img.crop(
+             (
+                 pan[0] * c,
+                 pan[1] * d,
+                 crop_size[0] + pan[0] * c,
+                 crop_size[1] + pan[1] * d,
+             )
+         )
+
+         # Fit the largest rectangle with the aspect ratio of resolution
+         # inside the cropped image.
+         width, height = crop_size
+         ratio = resolution[0] / resolution[1]
+         r1 = (height * ratio, height)
+         r2 = (width, width / ratio)
+         r = r1 if r1[0] <= width else r2
+         img = img.crop(
+             (
+                 (width - r[0]) / 2,
+                 (height - r[1]) / 2,
+                 (width + r[0]) / 2,
+                 (height + r[1]) / 2,
+             )
+         )
+
+         # Finally resize the image to resolution
+         img = img.resize(resolution, Image.LANCZOS)
+
+         return mx.array(np.array(img))
+
+     def _encode_image(self, input_img: ImageFile.ImageFile, num_augmentations: int):
+         for i in range(num_augmentations):
+             img = self._random_crop_resize(input_img)
+             img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
+             x_0 = self.flux.ae.encode(img[None])
+             x_0 = x_0.astype(self.flux.dtype)
+             mx.eval(x_0)
+             self.latents.append(x_0)
+
+     def _encode_prompt(self, prompt):
+         t5_tok, clip_tok = self.flux.tokenize([prompt])
+         t5_feat = self.flux.t5(t5_tok)
+         clip_feat = self.flux.clip(clip_tok).pooled_output
+         mx.eval(t5_feat, clip_feat)
+         self.t5_features.append(t5_feat)
+         self.clip_features.append(clip_feat)
+
+     def encode_dataset(self):
+         """Encode the images & prompt in the latent space to prepare for training."""
+         self.flux.ae.eval()
+         for image, prompt in tqdm(self.dataset, desc="encode dataset"):
+             self._encode_image(image, self.args.num_augmentations)
+             self._encode_prompt(prompt)
+
+     def iterate(self, batch_size):
+         xs = mx.concatenate(self.latents)
+         t5 = mx.concatenate(self.t5_features)
+         clip = mx.concatenate(self.clip_features)
+         mx.eval(xs, t5, clip)
+         n_aug = self.args.num_augmentations
+         while True:
+             x_indices = mx.random.permutation(len(self.latents))
+             c_indices = x_indices // n_aug
+             for i in range(0, len(self.latents), batch_size):
+                 x_i = x_indices[i : i + batch_size]
+                 c_i = c_indices[i : i + batch_size]
+                 yield xs[x_i], t5[c_i], clip[c_i]
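The Trainer above does a one-time encoding pass (VAE latents cached per augmentation, T5/CLIP features per prompt); iterate then serves shuffled batches indefinitely, pairing each augmented latent with its prompt's features via x_indices // n_aug. A minimal driver sketch under stated assumptions: only args.resolution and args.num_augmentations are read by Trainer, and the values below are placeholders.

    from types import SimpleNamespace

    def lora_batches(flux, dataset, batch_size=1):
        # Hypothetical driver for the Trainer above; flux is a FluxPipeline,
        # dataset yields (PIL image, prompt) pairs.
        args = SimpleNamespace(resolution=(512, 512), num_augmentations=4)
        trainer = Trainer(flux, dataset, args)
        trainer.encode_dataset()            # one-time pass: latents + text features
        return trainer.iterate(batch_size)  # infinite generator of (x_0, t5, clip)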
xinference/thirdparty/mlx/flux/utils.py (added)
@@ -0,0 +1,179 @@
+ # Copyright © 2024 Apple Inc.
+
+ import json
+ import os
+ from dataclasses import dataclass
+ from typing import Optional
+
+ import mlx.core as mx
+
+ from .autoencoder import AutoEncoder, AutoEncoderParams
+ from .clip import CLIPTextModel, CLIPTextModelConfig
+ from .model import Flux, FluxParams
+ from .t5 import T5Config, T5Encoder
+ from .tokenizers import CLIPTokenizer, T5Tokenizer
+
+
+ @dataclass
+ class ModelSpec:
+     params: FluxParams
+     ae_params: AutoEncoderParams
+     ckpt_path: Optional[str]
+     ae_path: Optional[str]
+     repo_id: Optional[str]
+     repo_flow: Optional[str]
+     repo_ae: Optional[str]
+
+
+ configs = {
+     "flux-dev": ModelSpec(
+         repo_id="black-forest-labs/FLUX.1-dev",
+         repo_flow="flux1-dev.safetensors",
+         repo_ae="ae.safetensors",
+         ckpt_path=os.getenv("FLUX_DEV"),
+         params=FluxParams(
+             in_channels=64,
+             vec_in_dim=768,
+             context_in_dim=4096,
+             hidden_size=3072,
+             mlp_ratio=4.0,
+             num_heads=24,
+             depth=19,
+             depth_single_blocks=38,
+             axes_dim=[16, 56, 56],
+             theta=10_000,
+             qkv_bias=True,
+             guidance_embed=True,
+         ),
+         ae_path=os.getenv("AE"),
+         ae_params=AutoEncoderParams(
+             resolution=256,
+             in_channels=3,
+             ch=128,
+             out_ch=3,
+             ch_mult=[1, 2, 4, 4],
+             num_res_blocks=2,
+             z_channels=16,
+             scale_factor=0.3611,
+             shift_factor=0.1159,
+         ),
+     ),
+     "flux-schnell": ModelSpec(
+         repo_id="black-forest-labs/FLUX.1-schnell",
+         repo_flow="flux1-schnell.safetensors",
+         repo_ae="ae.safetensors",
+         ckpt_path=os.getenv("FLUX_SCHNELL"),
+         params=FluxParams(
+             in_channels=64,
+             vec_in_dim=768,
+             context_in_dim=4096,
+             hidden_size=3072,
+             mlp_ratio=4.0,
+             num_heads=24,
+             depth=19,
+             depth_single_blocks=38,
+             axes_dim=[16, 56, 56],
+             theta=10_000,
+             qkv_bias=True,
+             guidance_embed=False,
+         ),
+         ae_path=os.getenv("AE"),
+         ae_params=AutoEncoderParams(
+             resolution=256,
+             in_channels=3,
+             ch=128,
+             out_ch=3,
+             ch_mult=[1, 2, 4, 4],
+             num_res_blocks=2,
+             z_channels=16,
+             scale_factor=0.3611,
+             shift_factor=0.1159,
+         ),
+     ),
+ }
+
+
+ def load_flow_model(name: str, ckpt_path: str):
+     # Make the model
+     model = Flux(configs[name].params)
+
+     # Load the checkpoint if needed
+     if os.path.isdir(ckpt_path):
+         ckpt_path = os.path.join(ckpt_path, configs[name].repo_flow)
+     weights = mx.load(ckpt_path)
+     weights = model.sanitize(weights)
+     model.load_weights(list(weights.items()))
+
+     return model
+
+
+ def load_ae(name: str, ckpt_path: str):
+     # Make the autoencoder
+     ae = AutoEncoder(configs[name].ae_params)
+
+     # Load the checkpoint if needed
+     ckpt_path = os.path.join(ckpt_path, "ae.safetensors")
+     weights = mx.load(ckpt_path)
+     weights = ae.sanitize(weights)
+     ae.load_weights(list(weights.items()))
+
+     return ae
+
+
+ def load_clip(name: str, ckpt_path: str):
+     config_path = os.path.join(ckpt_path, "text_encoder/config.json")
+     with open(config_path) as f:
+         config = CLIPTextModelConfig.from_dict(json.load(f))
+
+     # Make the clip text encoder
+     clip = CLIPTextModel(config)
+
+     ckpt_path = os.path.join(ckpt_path, "text_encoder/model.safetensors")
+     weights = mx.load(ckpt_path)
+     weights = clip.sanitize(weights)
+     clip.load_weights(list(weights.items()))
+
+     return clip
+
+
+ def load_t5(name: str, ckpt_path: str):
+     config_path = os.path.join(ckpt_path, "text_encoder_2/config.json")
+     with open(config_path) as f:
+         config = T5Config.from_dict(json.load(f))
+
+     # Make the T5 model
+     t5 = T5Encoder(config)
+
+     model_index = os.path.join(ckpt_path, "text_encoder_2/model.safetensors.index.json")
+     weight_files = set()
+     with open(model_index) as f:
+         for _, w in json.load(f)["weight_map"].items():
+             weight_files.add(w)
+     weights = {}
+     for w in weight_files:
+         w = f"text_encoder_2/{w}"
+         w = os.path.join(ckpt_path, w)
+         weights.update(mx.load(w))
+     weights = t5.sanitize(weights)
+     t5.load_weights(list(weights.items()))
+
+     return t5
+
+
+ def load_clip_tokenizer(name: str, ckpt_path: str):
+     vocab_file = os.path.join(ckpt_path, "tokenizer/vocab.json")
+     with open(vocab_file, encoding="utf-8") as f:
+         vocab = json.load(f)
+
+     merges_file = os.path.join(ckpt_path, "tokenizer/merges.txt")
+     with open(merges_file, encoding="utf-8") as f:
+         bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
+     bpe_merges = [tuple(m.split()) for m in bpe_merges]
+     bpe_ranks = dict(map(reversed, enumerate(bpe_merges)))
+
+     return CLIPTokenizer(bpe_ranks, vocab, max_length=77)
+
+
+ def load_t5_tokenizer(name: str, ckpt_path: str, pad: bool = True):
+     model_file = os.path.join(ckpt_path, "tokenizer_2/spiece.model")
+     return T5Tokenizer(model_file, 256 if "schnell" in name else 512)
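The loaders above assume a diffusers-style checkpoint directory (text_encoder/, text_encoder_2/, tokenizer/, tokenizer_2/, plus the flow and autoencoder safetensors at the top level). A sketch of composing them, using only functions defined in this file; the local path is a placeholder:

    ckpt = "/path/to/FLUX.1-schnell"        # placeholder checkpoint directory
    name = "flux-schnell"

    flow = load_flow_model(name, ckpt)      # transformer, from flux1-schnell.safetensors
    ae = load_ae(name, ckpt)                # latent autoencoder, from ae.safetensors
    clip = load_clip(name, ckpt)            # CLIP text encoder
    t5 = load_t5(name, ckpt)                # T5 encoder (sharded safetensors index)
    clip_tok = load_clip_tokenizer(name, ckpt)
    t5_tok = load_t5_tokenizer(name, ckpt)  # pads to 256 tokens for schnell, 512 otherwise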
xinference/web/ui/build/asset-manifest.json
@@ -1,14 +1,14 @@
  {
    "files": {
      "main.css": "./static/css/main.5061c4c3.css",
-     "main.js": "./static/js/main.f7da0140.js",
+     "main.js": "./static/js/main.2f269bb3.js",
      "static/media/icon.webp": "./static/media/icon.4603d52c63041e5dfbfd.webp",
      "index.html": "./index.html",
      "main.5061c4c3.css.map": "./static/css/main.5061c4c3.css.map",
-     "main.f7da0140.js.map": "./static/js/main.f7da0140.js.map"
+     "main.2f269bb3.js.map": "./static/js/main.2f269bb3.js.map"
    },
    "entrypoints": [
      "static/css/main.5061c4c3.css",
-     "static/js/main.f7da0140.js"
+     "static/js/main.2f269bb3.js"
    ]
  }
xinference/web/ui/build/index.html
@@ -1 +1 @@
- <!doctype html><html lang="en"><head><meta charset="utf-8"/><link rel="icon" href="./favicon.svg"/><meta name="viewport" content="width=device-width,initial-scale=1"/><meta name="theme-color" content="#000000"/><meta name="description" content="Web site created using create-react-app"/><link rel="apple-touch-icon" href="./logo192.png"/><link rel="manifest" href="./manifest.json"/><title>Xinference</title><script defer="defer" src="./static/js/main.f7da0140.js"></script><link href="./static/css/main.5061c4c3.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>
+ <!doctype html><html lang="en"><head><meta charset="utf-8"/><link rel="icon" href="./favicon.svg"/><meta name="viewport" content="width=device-width,initial-scale=1"/><meta name="theme-color" content="#000000"/><meta name="description" content="Web site created using create-react-app"/><link rel="apple-touch-icon" href="./logo192.png"/><link rel="manifest" href="./manifest.json"/><title>Xinference</title><script defer="defer" src="./static/js/main.2f269bb3.js"></script><link href="./static/css/main.5061c4c3.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>