xinference 1.6.0.post1__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic.

Files changed (124)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +79 -2
  3. xinference/client/restful/restful_client.py +65 -3
  4. xinference/conftest.py +0 -7
  5. xinference/core/media_interface.py +132 -8
  6. xinference/core/model.py +44 -6
  7. xinference/core/scheduler.py +1 -10
  8. xinference/core/supervisor.py +8 -17
  9. xinference/core/worker.py +5 -27
  10. xinference/deploy/cmdline.py +6 -2
  11. xinference/model/audio/chattts.py +24 -39
  12. xinference/model/audio/cosyvoice.py +18 -30
  13. xinference/model/audio/funasr.py +42 -0
  14. xinference/model/audio/model_spec.json +71 -1
  15. xinference/model/audio/model_spec_modelscope.json +76 -2
  16. xinference/model/audio/utils.py +75 -0
  17. xinference/model/core.py +1 -0
  18. xinference/model/embedding/__init__.py +74 -18
  19. xinference/model/embedding/core.py +98 -589
  20. xinference/model/embedding/embed_family.py +133 -0
  21. xinference/{thirdparty/omnilmm/train → model/embedding/flag}/__init__.py +1 -1
  22. xinference/model/embedding/flag/core.py +282 -0
  23. xinference/model/embedding/model_spec.json +24 -0
  24. xinference/model/embedding/model_spec_modelscope.json +24 -0
  25. xinference/model/embedding/sentence_transformers/__init__.py +13 -0
  26. xinference/model/embedding/sentence_transformers/core.py +399 -0
  27. xinference/model/embedding/vllm/core.py +95 -0
  28. xinference/model/image/model_spec.json +30 -3
  29. xinference/model/image/model_spec_modelscope.json +41 -2
  30. xinference/model/image/stable_diffusion/core.py +144 -53
  31. xinference/model/llm/__init__.py +6 -54
  32. xinference/model/llm/core.py +19 -5
  33. xinference/model/llm/llama_cpp/core.py +59 -3
  34. xinference/model/llm/llama_cpp/memory.py +457 -0
  35. xinference/model/llm/llm_family.json +247 -402
  36. xinference/model/llm/llm_family.py +88 -16
  37. xinference/model/llm/llm_family_modelscope.json +260 -421
  38. xinference/model/llm/llm_family_openmind_hub.json +0 -34
  39. xinference/model/llm/sglang/core.py +8 -0
  40. xinference/model/llm/transformers/__init__.py +27 -6
  41. xinference/model/llm/transformers/chatglm.py +4 -2
  42. xinference/model/llm/transformers/core.py +49 -28
  43. xinference/model/llm/transformers/deepseek_v2.py +6 -49
  44. xinference/model/llm/transformers/gemma3.py +119 -164
  45. xinference/model/llm/transformers/multimodal/__init__.py +13 -0
  46. xinference/model/llm/transformers/{cogagent.py → multimodal/cogagent.py} +58 -95
  47. xinference/model/llm/transformers/multimodal/core.py +205 -0
  48. xinference/model/llm/transformers/{deepseek_vl2.py → multimodal/deepseek_vl2.py} +59 -120
  49. xinference/model/llm/transformers/multimodal/gemma3.py +117 -0
  50. xinference/model/llm/transformers/{glm4v.py → multimodal/glm4v.py} +57 -93
  51. xinference/model/llm/transformers/multimodal/intern_vl.py +412 -0
  52. xinference/model/llm/transformers/{minicpmv26.py → multimodal/minicpmv26.py} +55 -102
  53. xinference/model/llm/transformers/{ovis2.py → multimodal/ovis2.py} +114 -175
  54. xinference/model/llm/transformers/{qwen-omni.py → multimodal/qwen-omni.py} +82 -167
  55. xinference/model/llm/transformers/multimodal/qwen2_audio.py +131 -0
  56. xinference/model/llm/transformers/{qwen2_vl.py → multimodal/qwen2_vl.py} +224 -256
  57. xinference/model/llm/transformers/opt.py +4 -2
  58. xinference/model/llm/transformers/utils.py +6 -37
  59. xinference/model/llm/utils.py +11 -0
  60. xinference/model/llm/vllm/core.py +7 -0
  61. xinference/model/rerank/core.py +91 -3
  62. xinference/model/rerank/model_spec.json +24 -0
  63. xinference/model/rerank/model_spec_modelscope.json +24 -0
  64. xinference/model/rerank/utils.py +20 -2
  65. xinference/model/utils.py +38 -1
  66. xinference/model/video/diffusers.py +65 -3
  67. xinference/model/video/model_spec.json +31 -4
  68. xinference/model/video/model_spec_modelscope.json +32 -4
  69. xinference/web/ui/build/asset-manifest.json +6 -6
  70. xinference/web/ui/build/index.html +1 -1
  71. xinference/web/ui/build/static/css/main.013f296b.css +2 -0
  72. xinference/web/ui/build/static/css/main.013f296b.css.map +1 -0
  73. xinference/web/ui/build/static/js/main.8a9e3ba0.js +3 -0
  74. xinference/web/ui/build/static/js/main.8a9e3ba0.js.map +1 -0
  75. xinference/web/ui/node_modules/.cache/babel-loader/34cfbfb7836e136ba3261cfd411cc554bf99ba24b35dcceebeaa4f008cb3c9dc.json +1 -0
  76. xinference/web/ui/node_modules/.cache/babel-loader/55b9fb40b57fa926e8f05f31c2f96467e76e5ad62f033dca97c03f9e8c4eb4fe.json +1 -0
  77. xinference/web/ui/node_modules/.cache/babel-loader/567e49df411efb24425d289bb484758cb57067ca54f8b5c67fe4505f698deb96.json +1 -0
  78. xinference/web/ui/node_modules/.cache/babel-loader/6595880facebca7ceace6f17cf21c3a5a9219a2f52fb0ba9f3cf1131eddbcf6b.json +1 -0
  79. xinference/web/ui/node_modules/.cache/babel-loader/aa998bc2d9c11853add6b8a2e08f50327f56d8824ccaaec92d6dde1b305f0d85.json +1 -0
  80. xinference/web/ui/node_modules/.cache/babel-loader/c748246b1d7bcebc16153be69f37e955bb2145526c47dd425aeeff70d3004dbc.json +1 -0
  81. xinference/web/ui/node_modules/.cache/babel-loader/e31234e95d60a5a7883fbcd70de2475dc1c88c90705df1a530abb68f86f80a51.json +1 -0
  82. xinference/web/ui/src/locales/en.json +21 -8
  83. xinference/web/ui/src/locales/ja.json +224 -0
  84. xinference/web/ui/src/locales/ko.json +224 -0
  85. xinference/web/ui/src/locales/zh.json +21 -8
  86. {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/METADATA +14 -11
  87. {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/RECORD +93 -100
  88. {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/WHEEL +1 -1
  89. xinference/model/llm/transformers/cogvlm2.py +0 -442
  90. xinference/model/llm/transformers/cogvlm2_video.py +0 -333
  91. xinference/model/llm/transformers/deepseek_vl.py +0 -280
  92. xinference/model/llm/transformers/glm_edge_v.py +0 -213
  93. xinference/model/llm/transformers/intern_vl.py +0 -526
  94. xinference/model/llm/transformers/internlm2.py +0 -94
  95. xinference/model/llm/transformers/minicpmv25.py +0 -193
  96. xinference/model/llm/transformers/omnilmm.py +0 -132
  97. xinference/model/llm/transformers/qwen2_audio.py +0 -179
  98. xinference/model/llm/transformers/qwen_vl.py +0 -360
  99. xinference/thirdparty/omnilmm/LICENSE +0 -201
  100. xinference/thirdparty/omnilmm/chat.py +0 -218
  101. xinference/thirdparty/omnilmm/constants.py +0 -4
  102. xinference/thirdparty/omnilmm/conversation.py +0 -332
  103. xinference/thirdparty/omnilmm/model/__init__.py +0 -1
  104. xinference/thirdparty/omnilmm/model/omnilmm.py +0 -595
  105. xinference/thirdparty/omnilmm/model/resampler.py +0 -166
  106. xinference/thirdparty/omnilmm/model/utils.py +0 -578
  107. xinference/thirdparty/omnilmm/train/train_utils.py +0 -150
  108. xinference/thirdparty/omnilmm/utils.py +0 -134
  109. xinference/web/ui/build/static/css/main.337afe76.css +0 -2
  110. xinference/web/ui/build/static/css/main.337afe76.css.map +0 -1
  111. xinference/web/ui/build/static/js/main.ae579a97.js +0 -3
  112. xinference/web/ui/build/static/js/main.ae579a97.js.map +0 -1
  113. xinference/web/ui/node_modules/.cache/babel-loader/12e02ee790dbf57ead09a241a93bb5f893393aa36628ca741d44390e836a103f.json +0 -1
  114. xinference/web/ui/node_modules/.cache/babel-loader/2fdc61dcb6a9d1fbcb44be592d0e87d8c3f21297a7327559ef5345665f8343f7.json +0 -1
  115. xinference/web/ui/node_modules/.cache/babel-loader/3d596a3e8dd6430d7ce81d164e32c31f8d47cfa5f725c328a298754d78563e14.json +0 -1
  116. xinference/web/ui/node_modules/.cache/babel-loader/5c08e2cd07809ed3e41486b16652253404cbb63a3ff8d0366ee50f57e2413cea.json +0 -1
  117. xinference/web/ui/node_modules/.cache/babel-loader/8472e58a31720892d534f3febda31f746b25ec4aa60787eef34217b074e67965.json +0 -1
  118. xinference/web/ui/node_modules/.cache/babel-loader/dc249829767b8abcbc3677e0b07b6d3ecbfdfe6d08cfe23a665eb33373a9aa9d.json +0 -1
  119. xinference/web/ui/node_modules/.cache/babel-loader/f91af913d7f91c410719ab13136aaed3aaf0f8dda06652f25c42cb5231587398.json +0 -1
  120. /xinference/{thirdparty/omnilmm → model/embedding/vllm}/__init__.py +0 -0
  121. /xinference/web/ui/build/static/js/{main.ae579a97.js.LICENSE.txt → main.8a9e3ba0.js.LICENSE.txt} +0 -0
  122. {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/entry_points.txt +0 -0
  123. {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/licenses/LICENSE +0 -0
  124. {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/top_level.txt +0 -0
xinference/thirdparty/omnilmm/model/utils.py (deleted)
@@ -1,578 +0,0 @@
- import base64
- import os
- import pickle
- from io import BytesIO
-
- import numpy as np
- import torch
- import torch.distributed as dist
- from PIL import Image
- from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
- from timm.data.transforms import RandomResizedCropAndInterpolation
- from torchvision import transforms
- from transformers import AutoConfig, StoppingCriteria
-
- try:
-     from timm.data.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
- except ImportError:
-     OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
-     OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
-
-
- class KeywordsStoppingCriteria(StoppingCriteria):
-     def __init__(self, keywords, tokenizer, input_ids):
-         self.keywords = keywords
-         self.tokenizer = tokenizer
-         self.start_len = None
-         self.input_ids = input_ids
-
-     def __call__(
-         self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
-     ) -> bool:
-         if self.start_len is None:
-             self.start_len = self.input_ids.shape[1]
-         else:
-             outputs = self.tokenizer.batch_decode(
-                 output_ids[:, self.start_len :], skip_special_tokens=True
-             )[0]
-             for keyword in self.keywords:
-                 if keyword in outputs:
-                     return True
-         return False
-
-
- def auto_upgrade(config):
-     cfg = AutoConfig.from_pretrained(config)
-     if "llava" in config and cfg.model_type != "llava":
-         print(
-             "You are using newer LLaVA code base, while the checkpoint of v0 is from older code base."
-         )
-         print(
-             "You must upgrade the checkpoint to the new code base (this can be done automatically)."
-         )
-         confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
-         if confirm.lower() in ["y", "yes"]:
-             print("Upgrading checkpoint...")
-             assert len(cfg.architectures) == 1
-             setattr(cfg.__class__, "model_type", "llava")
-             cfg.architectures[0] = "LlavaLlamaForCausalLM"
-             cfg.save_pretrained(config)
-             print("Checkpoint upgraded.")
-         else:
-             print("Checkpoint upgrade aborted.")
-             exit(1)
-
-
- # aug functions
-
-
- def identity_func(img):
-     return img
-
-
- def autocontrast_func(img, cutoff=0):
-     """
-     same output as PIL.ImageOps.autocontrast
-     """
-     import cv2
-
-     n_bins = 256
-
-     def tune_channel(ch):
-         n = ch.size
-         cut = cutoff * n // 100
-         if cut == 0:
-             high, low = ch.max(), ch.min()
-         else:
-             hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
-             low = np.argwhere(np.cumsum(hist) > cut)
-             low = 0 if low.shape[0] == 0 else low[0]
-             high = np.argwhere(np.cumsum(hist[::-1]) > cut)
-             high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
-         if high <= low:
-             table = np.arange(n_bins)
-         else:
-             scale = (n_bins - 1) / (high - low)
-             table = np.arange(n_bins) * scale - low * scale
-             table[table < 0] = 0
-             table[table > n_bins - 1] = n_bins - 1
-         table = table.clip(0, 255).astype(np.uint8)
-         return table[ch]
-
-     channels = [tune_channel(ch) for ch in cv2.split(img)]
-     out = cv2.merge(channels)
-     return out
-
-
- def equalize_func(img):
-     """
-     same output as PIL.ImageOps.equalize
-     PIL's implementation is different from cv2.equalize
-     """
-     import cv2
-
-     n_bins = 256
-
-     def tune_channel(ch):
-         hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
-         non_zero_hist = hist[hist != 0].reshape(-1)
-         step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
-         if step == 0:
-             return ch
-         n = np.empty_like(hist)
-         n[0] = step // 2
-         n[1:] = hist[:-1]
-         table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
-         return table[ch]
-
-     channels = [tune_channel(ch) for ch in cv2.split(img)]
-     out = cv2.merge(channels)
-     return out
-
-
- def rotate_func(img, degree, fill=(0, 0, 0)):
-     """
-     like PIL, rotate by degree, not radians
-     """
-     import cv2
-
-     H, W = img.shape[0], img.shape[1]
-     center = W / 2, H / 2
-     M = cv2.getRotationMatrix2D(center, degree, 1)
-     out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
-     return out
-
-
- def solarize_func(img, thresh=128):
-     """
-     same output as PIL.ImageOps.posterize
-     """
-     table = np.array([el if el < thresh else 255 - el for el in range(256)])
-     table = table.clip(0, 255).astype(np.uint8)
-     out = table[img]
-     return out
-
-
- def color_func(img, factor):
-     """
-     same output as PIL.ImageEnhance.Color
-     """
-     # implementation according to PIL definition, quite slow
-     # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
-     # out = blend(degenerate, img, factor)
-     # M = (
-     #     np.eye(3) * factor
-     #     + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
-     # )[np.newaxis, np.newaxis, :]
-     M = np.float32(
-         [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]
-     ) * factor + np.float32([[0.114], [0.587], [0.299]])
-     out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
-     return out
-
-
- def contrast_func(img, factor):
-     """
-     same output as PIL.ImageEnhance.Contrast
-     """
-     mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
-     table = (
-         np.array([(el - mean) * factor + mean for el in range(256)])
-         .clip(0, 255)
-         .astype(np.uint8)
-     )
-     out = table[img]
-     return out
-
-
- def brightness_func(img, factor):
-     """
-     same output as PIL.ImageEnhance.Contrast
-     """
-     table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
-     out = table[img]
-     return out
-
-
- def sharpness_func(img, factor):
-     """
-     The differences the this result and PIL are all on the 4 boundaries, the center
-     areas are same
-     """
-     import cv2
-
-     kernel = np.ones((3, 3), dtype=np.float32)
-     kernel[1][1] = 5
-     kernel /= 13
-     degenerate = cv2.filter2D(img, -1, kernel)
-     if factor == 0.0:
-         out = degenerate
-     elif factor == 1.0:
-         out = img
-     else:
-         out = img.astype(np.float32)
-         degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
-         out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
-         out = out.astype(np.uint8)
-     return out
-
-
- def shear_x_func(img, factor, fill=(0, 0, 0)):
-     import cv2
-
-     H, W = img.shape[0], img.shape[1]
-     M = np.float32([[1, factor, 0], [0, 1, 0]])
-     out = cv2.warpAffine(
-         img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
-     ).astype(np.uint8)
-     return out
-
-
- def translate_x_func(img, offset, fill=(0, 0, 0)):
-     """
-     same output as PIL.Image.transform
-     """
-     import cv2
-
-     H, W = img.shape[0], img.shape[1]
-     M = np.float32([[1, 0, -offset], [0, 1, 0]])
-     out = cv2.warpAffine(
-         img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
-     ).astype(np.uint8)
-     return out
-
-
- def translate_y_func(img, offset, fill=(0, 0, 0)):
-     """
-     same output as PIL.Image.transform
-     """
-     import cv2
-
-     H, W = img.shape[0], img.shape[1]
-     M = np.float32([[1, 0, 0], [0, 1, -offset]])
-     out = cv2.warpAffine(
-         img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
-     ).astype(np.uint8)
-     return out
-
-
- def posterize_func(img, bits):
-     """
-     same output as PIL.ImageOps.posterize
-     """
-     out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
-     return out
-
-
- def shear_y_func(img, factor, fill=(0, 0, 0)):
-     import cv2
-
-     H, W = img.shape[0], img.shape[1]
-     M = np.float32([[1, 0, 0], [factor, 1, 0]])
-     out = cv2.warpAffine(
-         img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
-     ).astype(np.uint8)
-     return out
-
-
- def cutout_func(img, pad_size, replace=(0, 0, 0)):
-     replace = np.array(replace, dtype=np.uint8)
-     H, W = img.shape[0], img.shape[1]
-     rh, rw = np.random.random(2)
-     pad_size = pad_size // 2
-     ch, cw = int(rh * H), int(rw * W)
-     x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
-     y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
-     out = img.copy()
-     out[x1:x2, y1:y2, :] = replace
-     return out
-
-
- # level to args
- def enhance_level_to_args(MAX_LEVEL):
-     def level_to_args(level):
-         return ((level / MAX_LEVEL) * 1.8 + 0.1,)
-
-     return level_to_args
-
-
- def shear_level_to_args(MAX_LEVEL, replace_value):
-     def level_to_args(level):
-         level = (level / MAX_LEVEL) * 0.3
-         if np.random.random() > 0.5:
-             level = -level
-         return (level, replace_value)
-
-     return level_to_args
-
-
- def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
-     def level_to_args(level):
-         level = (level / MAX_LEVEL) * float(translate_const)
-         if np.random.random() > 0.5:
-             level = -level
-         return (level, replace_value)
-
-     return level_to_args
-
-
- def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
-     def level_to_args(level):
-         level = int((level / MAX_LEVEL) * cutout_const)
-         return (level, replace_value)
-
-     return level_to_args
-
-
- def solarize_level_to_args(MAX_LEVEL):
-     def level_to_args(level):
-         level = int((level / MAX_LEVEL) * 256)
-         return (level,)
-
-     return level_to_args
-
-
- def none_level_to_args(level):
-     return ()
-
-
- def posterize_level_to_args(MAX_LEVEL):
-     def level_to_args(level):
-         level = int((level / MAX_LEVEL) * 4)
-         return (level,)
-
-     return level_to_args
-
-
- def rotate_level_to_args(MAX_LEVEL, replace_value):
-     def level_to_args(level):
-         level = (level / MAX_LEVEL) * 30
-         if np.random.random() < 0.5:
-             level = -level
-         return (level, replace_value)
-
-     return level_to_args
-
-
- func_dict = {
-     "Identity": identity_func,
-     "AutoContrast": autocontrast_func,
-     "Equalize": equalize_func,
-     "Rotate": rotate_func,
-     "Solarize": solarize_func,
-     "Color": color_func,
-     "Contrast": contrast_func,
-     "Brightness": brightness_func,
-     "Sharpness": sharpness_func,
-     "ShearX": shear_x_func,
-     "TranslateX": translate_x_func,
-     "TranslateY": translate_y_func,
-     "Posterize": posterize_func,
-     "ShearY": shear_y_func,
- }
-
- translate_const = 10
- MAX_LEVEL = 10
- replace_value = (128, 128, 128)
- arg_dict = {
-     "Identity": none_level_to_args,
-     "AutoContrast": none_level_to_args,
-     "Equalize": none_level_to_args,
-     "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value),
-     "Solarize": solarize_level_to_args(MAX_LEVEL),
-     "Color": enhance_level_to_args(MAX_LEVEL),
-     "Contrast": enhance_level_to_args(MAX_LEVEL),
-     "Brightness": enhance_level_to_args(MAX_LEVEL),
-     "Sharpness": enhance_level_to_args(MAX_LEVEL),
-     "ShearX": shear_level_to_args(MAX_LEVEL, replace_value),
-     "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
-     "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
-     "Posterize": posterize_level_to_args(MAX_LEVEL),
-     "ShearY": shear_level_to_args(MAX_LEVEL, replace_value),
- }
-
-
- class RandomAugment(object):
-     def __init__(self, N=2, M=10, isPIL=False, augs=[]):
-         self.N = N
-         self.M = M
-         self.isPIL = isPIL
-         if augs:
-             self.augs = augs
-         else:
-             self.augs = list(arg_dict.keys())
-
-     def get_random_ops(self):
-         sampled_ops = np.random.choice(self.augs, self.N)
-         return [(op, 0.5, self.M) for op in sampled_ops]
-
-     def __call__(self, img):
-         if self.isPIL:
-             img = np.array(img)
-         ops = self.get_random_ops()
-         for name, prob, level in ops:
-             if np.random.random() > prob:
-                 continue
-             args = arg_dict[name](level)
-             img = func_dict[name](img, *args)
-         return img
-
-
- def build_transform(
-     is_train,
-     randaug=True,
-     input_size=224,
-     interpolation="bicubic",
-     std_mode="IMAGENET_INCEPTION",
- ):
-     if std_mode == "IMAGENET_INCEPTION":
-         mean = IMAGENET_INCEPTION_MEAN
-         std = IMAGENET_INCEPTION_STD
-     elif std_mode == "OPENAI_CLIP":
-         mean = OPENAI_CLIP_MEAN
-         std = OPENAI_CLIP_STD
-     else:
-         raise NotImplementedError
-
-     if is_train:
-         crop_scale = float(os.environ.get("TRAIN_CROP_SCALE", 0.9999))
-         t = [
-             RandomResizedCropAndInterpolation(
-                 input_size, scale=(crop_scale, 1.0), interpolation="bicubic"
-             ),
-             # transforms.RandomHorizontalFlip(),
-         ]
-         if randaug and os.environ.get("TRAIN_DO_AUG", "False") == "True":
-             print(f"@@@@@ Do random aug during training", flush=True)
-             t.append(
-                 RandomAugment(
-                     2,
-                     7,
-                     isPIL=True,
-                     augs=[
-                         "Identity",
-                         "AutoContrast",
-                         "Equalize",
-                         "Brightness",
-                         "Sharpness",
-                         "ShearX",
-                         "ShearY",
-                         "TranslateX",
-                         "TranslateY",
-                         "Rotate",
-                     ],
-                 )
-             )
-         else:
-             print(f"@@@@@ Skip random aug during training", flush=True)
-         t += [
-             transforms.ToTensor(),
-             transforms.Normalize(mean=mean, std=std),
-         ]
-         t = transforms.Compose(t)
-     else:
-         t = transforms.Compose(
-             [
-                 transforms.Resize(
-                     (input_size, input_size),
-                     interpolation=transforms.InterpolationMode.BICUBIC,
-                 ),
-                 transforms.ToTensor(),
-                 transforms.Normalize(mean=mean, std=std),
-             ]
-         )
-
-     return t
-
-
- def img2b64(img_path):
-     img = Image.open(img_path) # path to file
-     img_buffer = BytesIO()
-     img.save(img_buffer, format=img.format)
-     byte_data = img_buffer.getvalue()
-     base64_str = base64.b64encode(byte_data) # bytes
-     base64_str = base64_str.decode("utf-8") # str
-     return base64_str
-
-
- def str2b64(str):
-     return base64.b64encode(str.encode("utf-8")).decode("utf-8")
-
-
- def b642str(b64):
-     return base64.b64decode(b64).decode("utf-8")
-
-
- def is_dist_avail_and_initialized():
-     if not dist.is_available():
-         return False
-     if not dist.is_initialized():
-         return False
-     return True
-
-
- def get_world_size():
-     if not is_dist_avail_and_initialized():
-         return 1
-     return dist.get_world_size()
-
-
- def get_rank():
-     if not is_dist_avail_and_initialized():
-         return 0
-     return dist.get_rank()
-
-
- def all_gather(data):
-     """
-     Run all_gather on arbitrary picklable data (not necessarily tensors)
-     Args:
-         data: any picklable object
-     Returns:
-         list[data]: list of data gathered from each rank
-     """
-     world_size = get_world_size()
-     if world_size == 1:
-         return [data]
-
-     # serialized to a Tensor
-     buffer = pickle.dumps(data)
-     storage = torch.ByteStorage.from_buffer(buffer)
-     tensor = torch.ByteTensor(storage).to("cuda")
-
-     # obtain Tensor size of each rank
-     local_size = torch.LongTensor([tensor.numel()]).to("cuda")
-     size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
-     dist.all_gather(size_list, local_size)
-     size_list = [int(size.item()) for size in size_list]
-     max_size = max(size_list)
-
-     # receiving Tensor from all ranks
-     # we pad the tensor because torch all_gather does not support
-     # gathering tensors of different shapes
-     tensor_list = []
-     for _ in size_list:
-         tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
-     if local_size != max_size:
-         padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
-         tensor = torch.cat((tensor, padding), dim=0)
-     dist.all_gather(tensor_list, tensor)
-
-     data_list = []
-     for size, tensor in zip(size_list, tensor_list):
-         buffer = tensor.cpu().numpy().tobytes()[:size]
-         data_list.append(pickle.loads(buffer))
-
-     return data_list
-
-
- def mean(lst):
-     return sum(lst) / len(lst)
-
-
- def stop_gradient_by_name(name: str):
-     def apply_fn(module):
-         if hasattr(module, name):
-             getattr(module, name).requires_grad_(False)
-
-     return apply_fn