xinference 0.9.4__py3-none-any.whl → 0.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (103)
  1. xinference/_version.py +3 -3
  2. xinference/api/oauth2/auth_service.py +47 -18
  3. xinference/api/oauth2/types.py +1 -0
  4. xinference/api/restful_api.py +34 -7
  5. xinference/client/oscar/actor_client.py +4 -3
  6. xinference/client/restful/restful_client.py +20 -4
  7. xinference/conftest.py +13 -2
  8. xinference/core/supervisor.py +48 -1
  9. xinference/core/worker.py +139 -20
  10. xinference/deploy/cmdline.py +119 -20
  11. xinference/model/embedding/core.py +1 -2
  12. xinference/model/llm/__init__.py +4 -6
  13. xinference/model/llm/ggml/llamacpp.py +2 -10
  14. xinference/model/llm/llm_family.json +877 -13
  15. xinference/model/llm/llm_family.py +15 -0
  16. xinference/model/llm/llm_family_modelscope.json +571 -0
  17. xinference/model/llm/pytorch/chatglm.py +2 -0
  18. xinference/model/llm/pytorch/core.py +22 -26
  19. xinference/model/llm/pytorch/deepseek_vl.py +232 -0
  20. xinference/model/llm/pytorch/internlm2.py +2 -0
  21. xinference/model/llm/pytorch/omnilmm.py +153 -0
  22. xinference/model/llm/pytorch/qwen_vl.py +2 -0
  23. xinference/model/llm/pytorch/yi_vl.py +4 -2
  24. xinference/model/llm/utils.py +53 -5
  25. xinference/model/llm/vllm/core.py +54 -6
  26. xinference/model/rerank/core.py +3 -0
  27. xinference/thirdparty/deepseek_vl/__init__.py +31 -0
  28. xinference/thirdparty/deepseek_vl/models/__init__.py +28 -0
  29. xinference/thirdparty/deepseek_vl/models/clip_encoder.py +242 -0
  30. xinference/thirdparty/deepseek_vl/models/image_processing_vlm.py +208 -0
  31. xinference/thirdparty/deepseek_vl/models/modeling_vlm.py +170 -0
  32. xinference/thirdparty/deepseek_vl/models/processing_vlm.py +390 -0
  33. xinference/thirdparty/deepseek_vl/models/projector.py +100 -0
  34. xinference/thirdparty/deepseek_vl/models/sam.py +593 -0
  35. xinference/thirdparty/deepseek_vl/models/siglip_vit.py +681 -0
  36. xinference/thirdparty/deepseek_vl/utils/__init__.py +18 -0
  37. xinference/thirdparty/deepseek_vl/utils/conversation.py +348 -0
  38. xinference/thirdparty/deepseek_vl/utils/io.py +78 -0
  39. xinference/thirdparty/omnilmm/__init__.py +0 -0
  40. xinference/thirdparty/omnilmm/chat.py +216 -0
  41. xinference/thirdparty/omnilmm/constants.py +4 -0
  42. xinference/thirdparty/omnilmm/conversation.py +332 -0
  43. xinference/thirdparty/omnilmm/model/__init__.py +1 -0
  44. xinference/thirdparty/omnilmm/model/omnilmm.py +594 -0
  45. xinference/thirdparty/omnilmm/model/resampler.py +166 -0
  46. xinference/thirdparty/omnilmm/model/utils.py +563 -0
  47. xinference/thirdparty/omnilmm/train/__init__.py +13 -0
  48. xinference/thirdparty/omnilmm/train/train_utils.py +150 -0
  49. xinference/thirdparty/omnilmm/utils.py +134 -0
  50. xinference/types.py +15 -19
  51. xinference/web/ui/build/asset-manifest.json +3 -3
  52. xinference/web/ui/build/index.html +1 -1
  53. xinference/web/ui/build/static/js/main.76ef2b17.js +3 -0
  54. xinference/web/ui/build/static/js/main.76ef2b17.js.map +1 -0
  55. xinference/web/ui/node_modules/.cache/babel-loader/15e2cf8cd8d0989719b6349428ff576f9009ff4c2dcc52378be0bd938e82495e.json +1 -0
  56. xinference/web/ui/node_modules/.cache/babel-loader/35d0e4a317e5582cbb79d901302e9d706520ac53f8a734c2fd8bfde6eb5a4f02.json +1 -0
  57. xinference/web/ui/node_modules/.cache/babel-loader/3c2f277c93c5f1638e08db38df0d0fb4e58d1c5571aea03241a5c04ff4094704.json +1 -0
  58. xinference/web/ui/node_modules/.cache/babel-loader/3fa1f69162f9c6dc0f6a6e21b64d49d6b8e6fa8dfa59a82cf829931c5f97d99f.json +1 -0
  59. xinference/web/ui/node_modules/.cache/babel-loader/44774c783428f952d8e2e4ad0998a9c5bc16a57cd9c68b7c5ff18aaa5a41d65c.json +1 -0
  60. xinference/web/ui/node_modules/.cache/babel-loader/5393569d846332075b93b55656716a34f50e0a8c970be789502d7e6c49755fd7.json +1 -0
  61. xinference/web/ui/node_modules/.cache/babel-loader/59ce49eae0f486af4c5034d4d2f9ca77c3ec3a32ecc560085caf5ef482b5f4c9.json +1 -0
  62. xinference/web/ui/node_modules/.cache/babel-loader/62e257ed9016471035fa1a7da57c9e2a4250974ed566b4d1295873d747c68eb2.json +1 -0
  63. xinference/web/ui/node_modules/.cache/babel-loader/63a4c48f0326d071c7772c46598215c006ae41fd3d4ff3577fe717de66ad6e89.json +1 -0
  64. xinference/web/ui/node_modules/.cache/babel-loader/b9cbcb6d77ba21b22c6950b6fb5b305d23c19cf747f99f7d48b6b046f8f7b1b0.json +1 -0
  65. xinference/web/ui/node_modules/.cache/babel-loader/d06a96a3c9c32e42689094aa3aaad41c8125894e956b8f84a70fadce6e3f65b3.json +1 -0
  66. xinference/web/ui/node_modules/.cache/babel-loader/d076fd56cf3b15ed2433e3744b98c6b4e4410a19903d1db4de5bba0e1a1b3347.json +1 -0
  67. xinference/web/ui/node_modules/.cache/babel-loader/daad8131d91134f6d7aef895a0c9c32e1cb928277cb5aa66c01028126d215be0.json +1 -0
  68. xinference/web/ui/node_modules/.cache/babel-loader/de0299226173b0662b573f49e3992220f6611947073bd66ac079728a8bc8837d.json +1 -0
  69. xinference/web/ui/node_modules/.cache/babel-loader/e606671420d2937102c3c34b4b04056c11736408c1d3347b8cf42dfe61fb394b.json +1 -0
  70. xinference/web/ui/node_modules/.cache/babel-loader/e6eccc9aa641e7da833492e27846dc965f9750281420977dc84654ca6ed221e4.json +1 -0
  71. xinference/web/ui/node_modules/.cache/babel-loader/e9b52d171223bb59fb918316297a051cdfd42dd453e8260fd918e90bc0a4ebdf.json +1 -0
  72. xinference/web/ui/node_modules/.cache/babel-loader/f16aec63602a77bd561d0e67fa00b76469ac54b8033754bba114ec5eb3257964.json +1 -0
  73. {xinference-0.9.4.dist-info → xinference-0.10.1.dist-info}/METADATA +25 -12
  74. {xinference-0.9.4.dist-info → xinference-0.10.1.dist-info}/RECORD +79 -58
  75. xinference/model/llm/ggml/ctransformers.py +0 -281
  76. xinference/model/llm/ggml/ctransformers_util.py +0 -161
  77. xinference/web/ui/build/static/js/main.66b1c4fb.js +0 -3
  78. xinference/web/ui/build/static/js/main.66b1c4fb.js.map +0 -1
  79. xinference/web/ui/node_modules/.cache/babel-loader/0bd70b1ecf307e2681318e864f4692305b6350c8683863007f4caf2f9ac33b6e.json +0 -1
  80. xinference/web/ui/node_modules/.cache/babel-loader/0db651c046ef908f45cde73af0dbea0a797d3e35bb57f4a0863b481502103a64.json +0 -1
  81. xinference/web/ui/node_modules/.cache/babel-loader/18e5d5422e2464abf4a3e6d38164570e2e426e0a921e9a2628bbae81b18da353.json +0 -1
  82. xinference/web/ui/node_modules/.cache/babel-loader/3d93bd9a74a1ab0cec85af40f9baa5f6a8e7384b9e18c409b95a81a7b45bb7e2.json +0 -1
  83. xinference/web/ui/node_modules/.cache/babel-loader/3e055de705e397e1d413d7f429589b1a98dd78ef378b97f0cdb462c5f2487d5e.json +0 -1
  84. xinference/web/ui/node_modules/.cache/babel-loader/4fd24800544873512b540544ae54601240a5bfefd9105ff647855c64f8ad828f.json +0 -1
  85. xinference/web/ui/node_modules/.cache/babel-loader/52aa27272b4b9968f62666262b47661cb1992336a2aff3b13994cc36877b3ec3.json +0 -1
  86. xinference/web/ui/node_modules/.cache/babel-loader/60c4b98d8ea7479fb0c94cfd19c8128f17bd7e27a1e73e6dd9adf6e9d88d18eb.json +0 -1
  87. xinference/web/ui/node_modules/.cache/babel-loader/7e094845f611802b024b57439cbf911038169d06cdf6c34a72a7277f35aa71a4.json +0 -1
  88. xinference/web/ui/node_modules/.cache/babel-loader/95c8cc049fadd23085d8623e1d43d70b614a4e52217676f186a417dca894aa09.json +0 -1
  89. xinference/web/ui/node_modules/.cache/babel-loader/98b7ef307f436affe13d75a4f265b27e828ccc2b10ffae6513abe2681bc11971.json +0 -1
  90. xinference/web/ui/node_modules/.cache/babel-loader/a8070ce4b780b4a044218536e158a9e7192a6c80ff593fdc126fee43f46296b5.json +0 -1
  91. xinference/web/ui/node_modules/.cache/babel-loader/b400cfc9db57fa6c70cd2bad055b73c5079fde0ed37974009d898083f6af8cd8.json +0 -1
  92. xinference/web/ui/node_modules/.cache/babel-loader/bd04667474fd9cac2983b03725c218908a6cc0ee9128a5953cd00d26d4877f60.json +0 -1
  93. xinference/web/ui/node_modules/.cache/babel-loader/c2124cfe036b26befcbd386d1d17743b1a58d0b7a041a17bb67f9924400d63c3.json +0 -1
  94. xinference/web/ui/node_modules/.cache/babel-loader/c230a727b8f68f0e62616a75e14a3d33026dc4164f2e325a9a8072d733850edb.json +0 -1
  95. xinference/web/ui/node_modules/.cache/babel-loader/d44a6eb6106e09082b691a315c9f6ce17fcfe25beb7547810e0d271ce3301cd2.json +0 -1
  96. xinference/web/ui/node_modules/.cache/babel-loader/e1d9b2ae4e1248658704bc6bfc5d6160dcd1a9e771ea4ae8c1fed0aaddeedd29.json +0 -1
  97. xinference/web/ui/node_modules/.cache/babel-loader/fd4a8ae5d192331af1bedd1d2d70efcc569708ee6cc4cb479b225d059482aa81.json +0 -1
  98. xinference/web/ui/node_modules/.cache/babel-loader/fe5db70859503a54cbe71f9637e5a314cda88b1f0eecb733b6e6f837697db1ef.json +0 -1
  99. /xinference/web/ui/build/static/js/{main.66b1c4fb.js.LICENSE.txt → main.76ef2b17.js.LICENSE.txt} +0 -0
  100. {xinference-0.9.4.dist-info → xinference-0.10.1.dist-info}/LICENSE +0 -0
  101. {xinference-0.9.4.dist-info → xinference-0.10.1.dist-info}/WHEEL +0 -0
  102. {xinference-0.9.4.dist-info → xinference-0.10.1.dist-info}/entry_points.txt +0 -0
  103. {xinference-0.9.4.dist-info → xinference-0.10.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/omnilmm/model/utils.py
@@ -0,0 +1,563 @@
+ import base64
+ import os
+ import pickle
+ from io import BytesIO
+
+ import cv2
+ import numpy as np
+ import torch
+ import torch.distributed as dist
+ from PIL import Image
+ from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+ from timm.data.transforms import RandomResizedCropAndInterpolation
+ from torchvision import transforms
+ from transformers import AutoConfig, StoppingCriteria
+
+ try:
+     from timm.data.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
+ except ImportError:
+     OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
+     OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
+
+
+ class KeywordsStoppingCriteria(StoppingCriteria):
+     def __init__(self, keywords, tokenizer, input_ids):
+         self.keywords = keywords
+         self.tokenizer = tokenizer
+         self.start_len = None
+         self.input_ids = input_ids
+
+     def __call__(
+         self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
+     ) -> bool:
+         if self.start_len is None:
+             self.start_len = self.input_ids.shape[1]
+         else:
+             outputs = self.tokenizer.batch_decode(
+                 output_ids[:, self.start_len :], skip_special_tokens=True
+             )[0]
+             for keyword in self.keywords:
+                 if keyword in outputs:
+                     return True
+         return False
+
+
+ def auto_upgrade(config):
+     cfg = AutoConfig.from_pretrained(config)
+     if "llava" in config and cfg.model_type != "llava":
+         print(
+             "You are using newer LLaVA code base, while the checkpoint of v0 is from older code base."
+         )
+         print(
+             "You must upgrade the checkpoint to the new code base (this can be done automatically)."
+         )
+         confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
+         if confirm.lower() in ["y", "yes"]:
+             print("Upgrading checkpoint...")
+             assert len(cfg.architectures) == 1
+             setattr(cfg.__class__, "model_type", "llava")
+             cfg.architectures[0] = "LlavaLlamaForCausalLM"
+             cfg.save_pretrained(config)
+             print("Checkpoint upgraded.")
+         else:
+             print("Checkpoint upgrade aborted.")
+             exit(1)
+
+
+ # aug functions
+
+
+ def identity_func(img):
+     return img
+
+
+ def autocontrast_func(img, cutoff=0):
+     """
+     same output as PIL.ImageOps.autocontrast
+     """
+     n_bins = 256
+
+     def tune_channel(ch):
+         n = ch.size
+         cut = cutoff * n // 100
+         if cut == 0:
+             high, low = ch.max(), ch.min()
+         else:
+             hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
+             low = np.argwhere(np.cumsum(hist) > cut)
+             low = 0 if low.shape[0] == 0 else low[0]
+             high = np.argwhere(np.cumsum(hist[::-1]) > cut)
+             high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
+         if high <= low:
+             table = np.arange(n_bins)
+         else:
+             scale = (n_bins - 1) / (high - low)
+             table = np.arange(n_bins) * scale - low * scale
+             table[table < 0] = 0
+             table[table > n_bins - 1] = n_bins - 1
+         table = table.clip(0, 255).astype(np.uint8)
+         return table[ch]
+
+     channels = [tune_channel(ch) for ch in cv2.split(img)]
+     out = cv2.merge(channels)
+     return out
+
+
+ def equalize_func(img):
+     """
+     same output as PIL.ImageOps.equalize
+     PIL's implementation is different from cv2.equalize
+     """
+     n_bins = 256
+
+     def tune_channel(ch):
+         hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
+         non_zero_hist = hist[hist != 0].reshape(-1)
+         step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
+         if step == 0:
+             return ch
+         n = np.empty_like(hist)
+         n[0] = step // 2
+         n[1:] = hist[:-1]
+         table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
+         return table[ch]
+
+     channels = [tune_channel(ch) for ch in cv2.split(img)]
+     out = cv2.merge(channels)
+     return out
+
+
+ def rotate_func(img, degree, fill=(0, 0, 0)):
+     """
+     like PIL, rotate by degree, not radians
+     """
+     H, W = img.shape[0], img.shape[1]
+     center = W / 2, H / 2
+     M = cv2.getRotationMatrix2D(center, degree, 1)
+     out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
+     return out
+
+
+ def solarize_func(img, thresh=128):
+     """
+     same output as PIL.ImageOps.solarize
+     """
+     table = np.array([el if el < thresh else 255 - el for el in range(256)])
+     table = table.clip(0, 255).astype(np.uint8)
+     out = table[img]
+     return out
+
+
+ def color_func(img, factor):
+     """
+     same output as PIL.ImageEnhance.Color
+     """
+     # implementation according to PIL definition, quite slow
+     # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
+     # out = blend(degenerate, img, factor)
+     # M = (
+     #     np.eye(3) * factor
+     #     + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
+     # )[np.newaxis, np.newaxis, :]
+     M = np.float32(
+         [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]
+     ) * factor + np.float32([[0.114], [0.587], [0.299]])
+     out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
+     return out
+
+
+ def contrast_func(img, factor):
+     """
+     same output as PIL.ImageEnhance.Contrast
+     """
+     mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
+     table = (
+         np.array([(el - mean) * factor + mean for el in range(256)])
+         .clip(0, 255)
+         .astype(np.uint8)
+     )
+     out = table[img]
+     return out
+
+
+ def brightness_func(img, factor):
+     """
+     same output as PIL.ImageEnhance.Brightness
+     """
+     table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
+     out = table[img]
+     return out
+
+
+ def sharpness_func(img, factor):
+     """
+     The differences between this result and PIL's are all on the 4 boundaries;
+     the center areas are the same
+     """
+     kernel = np.ones((3, 3), dtype=np.float32)
+     kernel[1][1] = 5
+     kernel /= 13
+     degenerate = cv2.filter2D(img, -1, kernel)
+     if factor == 0.0:
+         out = degenerate
+     elif factor == 1.0:
+         out = img
+     else:
+         out = img.astype(np.float32)
+         degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
+         out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
+         out = out.astype(np.uint8)
+     return out
+
+
+ def shear_x_func(img, factor, fill=(0, 0, 0)):
+     H, W = img.shape[0], img.shape[1]
+     M = np.float32([[1, factor, 0], [0, 1, 0]])
+     out = cv2.warpAffine(
+         img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
+     ).astype(np.uint8)
+     return out
+
+
+ def translate_x_func(img, offset, fill=(0, 0, 0)):
+     """
+     same output as PIL.Image.transform
+     """
+     H, W = img.shape[0], img.shape[1]
+     M = np.float32([[1, 0, -offset], [0, 1, 0]])
+     out = cv2.warpAffine(
+         img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
+     ).astype(np.uint8)
+     return out
+
+
+ def translate_y_func(img, offset, fill=(0, 0, 0)):
+     """
+     same output as PIL.Image.transform
+     """
+     H, W = img.shape[0], img.shape[1]
+     M = np.float32([[1, 0, 0], [0, 1, -offset]])
+     out = cv2.warpAffine(
+         img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
+     ).astype(np.uint8)
+     return out
+
+
+ def posterize_func(img, bits):
+     """
+     same output as PIL.ImageOps.posterize
+     """
+     out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
+     return out
+
+
+ def shear_y_func(img, factor, fill=(0, 0, 0)):
+     H, W = img.shape[0], img.shape[1]
+     M = np.float32([[1, 0, 0], [factor, 1, 0]])
+     out = cv2.warpAffine(
+         img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
+     ).astype(np.uint8)
+     return out
+
+
+ def cutout_func(img, pad_size, replace=(0, 0, 0)):
+     replace = np.array(replace, dtype=np.uint8)
+     H, W = img.shape[0], img.shape[1]
+     rh, rw = np.random.random(2)
+     pad_size = pad_size // 2
+     ch, cw = int(rh * H), int(rw * W)
+     x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
+     y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
+     out = img.copy()
+     out[x1:x2, y1:y2, :] = replace
+     return out
+
+
+ # level to args
+ def enhance_level_to_args(MAX_LEVEL):
+     def level_to_args(level):
+         return ((level / MAX_LEVEL) * 1.8 + 0.1,)
+
+     return level_to_args
+
+
+ def shear_level_to_args(MAX_LEVEL, replace_value):
+     def level_to_args(level):
+         level = (level / MAX_LEVEL) * 0.3
+         if np.random.random() > 0.5:
+             level = -level
+         return (level, replace_value)
+
+     return level_to_args
+
+
+ def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
+     def level_to_args(level):
+         level = (level / MAX_LEVEL) * float(translate_const)
+         if np.random.random() > 0.5:
+             level = -level
+         return (level, replace_value)
+
+     return level_to_args
+
+
+ def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
+     def level_to_args(level):
+         level = int((level / MAX_LEVEL) * cutout_const)
+         return (level, replace_value)
+
+     return level_to_args
+
+
+ def solarize_level_to_args(MAX_LEVEL):
+     def level_to_args(level):
+         level = int((level / MAX_LEVEL) * 256)
+         return (level,)
+
+     return level_to_args
+
+
+ def none_level_to_args(level):
+     return ()
+
+
+ def posterize_level_to_args(MAX_LEVEL):
+     def level_to_args(level):
+         level = int((level / MAX_LEVEL) * 4)
+         return (level,)
+
+     return level_to_args
+
+
+ def rotate_level_to_args(MAX_LEVEL, replace_value):
+     def level_to_args(level):
+         level = (level / MAX_LEVEL) * 30
+         if np.random.random() < 0.5:
+             level = -level
+         return (level, replace_value)
+
+     return level_to_args
+
+
+ func_dict = {
+     "Identity": identity_func,
+     "AutoContrast": autocontrast_func,
+     "Equalize": equalize_func,
+     "Rotate": rotate_func,
+     "Solarize": solarize_func,
+     "Color": color_func,
+     "Contrast": contrast_func,
+     "Brightness": brightness_func,
+     "Sharpness": sharpness_func,
+     "ShearX": shear_x_func,
+     "TranslateX": translate_x_func,
+     "TranslateY": translate_y_func,
+     "Posterize": posterize_func,
+     "ShearY": shear_y_func,
+ }
+
+ translate_const = 10
+ MAX_LEVEL = 10
+ replace_value = (128, 128, 128)
+ arg_dict = {
+     "Identity": none_level_to_args,
+     "AutoContrast": none_level_to_args,
+     "Equalize": none_level_to_args,
+     "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value),
+     "Solarize": solarize_level_to_args(MAX_LEVEL),
+     "Color": enhance_level_to_args(MAX_LEVEL),
+     "Contrast": enhance_level_to_args(MAX_LEVEL),
+     "Brightness": enhance_level_to_args(MAX_LEVEL),
+     "Sharpness": enhance_level_to_args(MAX_LEVEL),
+     "ShearX": shear_level_to_args(MAX_LEVEL, replace_value),
+     "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
+     "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
+     "Posterize": posterize_level_to_args(MAX_LEVEL),
+     "ShearY": shear_level_to_args(MAX_LEVEL, replace_value),
+ }
+
+
+ class RandomAugment(object):
+     def __init__(self, N=2, M=10, isPIL=False, augs=[]):
+         self.N = N
+         self.M = M
+         self.isPIL = isPIL
+         if augs:
+             self.augs = augs
+         else:
+             self.augs = list(arg_dict.keys())
+
+     def get_random_ops(self):
+         sampled_ops = np.random.choice(self.augs, self.N)
+         return [(op, 0.5, self.M) for op in sampled_ops]
+
+     def __call__(self, img):
+         if self.isPIL:
+             img = np.array(img)
+         ops = self.get_random_ops()
+         for name, prob, level in ops:
+             if np.random.random() > prob:
+                 continue
+             args = arg_dict[name](level)
+             img = func_dict[name](img, *args)
+         return img
+
+
+ def build_transform(
+     is_train,
+     randaug=True,
+     input_size=224,
+     interpolation="bicubic",
+     std_mode="IMAGENET_INCEPTION",
+ ):
+     if std_mode == "IMAGENET_INCEPTION":
+         mean = IMAGENET_INCEPTION_MEAN
+         std = IMAGENET_INCEPTION_STD
+     elif std_mode == "OPENAI_CLIP":
+         mean = OPENAI_CLIP_MEAN
+         std = OPENAI_CLIP_STD
+     else:
+         raise NotImplementedError
+
+     if is_train:
+         crop_scale = float(os.environ.get("TRAIN_CROP_SCALE", 0.9999))
+         t = [
+             RandomResizedCropAndInterpolation(
+                 input_size, scale=(crop_scale, 1.0), interpolation="bicubic"
+             ),
+             # transforms.RandomHorizontalFlip(),
+         ]
+         if randaug and os.environ.get("TRAIN_DO_AUG", "False") == "True":
+             print(f"@@@@@ Do random aug during training", flush=True)
+             t.append(
+                 RandomAugment(
+                     2,
+                     7,
+                     isPIL=True,
+                     augs=[
+                         "Identity",
+                         "AutoContrast",
+                         "Equalize",
+                         "Brightness",
+                         "Sharpness",
+                         "ShearX",
+                         "ShearY",
+                         "TranslateX",
+                         "TranslateY",
+                         "Rotate",
+                     ],
+                 )
+             )
+         else:
+             print(f"@@@@@ Skip random aug during training", flush=True)
+         t += [
+             transforms.ToTensor(),
+             transforms.Normalize(mean=mean, std=std),
+         ]
+         t = transforms.Compose(t)
+     else:
+         t = transforms.Compose(
+             [
+                 transforms.Resize(
+                     (input_size, input_size),
+                     interpolation=transforms.InterpolationMode.BICUBIC,
+                 ),
+                 transforms.ToTensor(),
+                 transforms.Normalize(mean=mean, std=std),
+             ]
+         )
+
+     return t
+
+
+ def img2b64(img_path):
+     img = Image.open(img_path)  # path to file
+     img_buffer = BytesIO()
+     img.save(img_buffer, format=img.format)
+     byte_data = img_buffer.getvalue()
+     base64_str = base64.b64encode(byte_data)  # bytes
+     base64_str = base64_str.decode("utf-8")  # str
+     return base64_str
+
+
+ def str2b64(str):
+     return base64.b64encode(str.encode("utf-8")).decode("utf-8")
+
+
+ def b642str(b64):
+     return base64.b64decode(b64).decode("utf-8")
+
+
+ def is_dist_avail_and_initialized():
+     if not dist.is_available():
+         return False
+     if not dist.is_initialized():
+         return False
+     return True
+
+
+ def get_world_size():
+     if not is_dist_avail_and_initialized():
+         return 1
+     return dist.get_world_size()
+
+
+ def get_rank():
+     if not is_dist_avail_and_initialized():
+         return 0
+     return dist.get_rank()
+
+
+ def all_gather(data):
+     """
+     Run all_gather on arbitrary picklable data (not necessarily tensors)
+     Args:
+         data: any picklable object
+     Returns:
+         list[data]: list of data gathered from each rank
+     """
+     world_size = get_world_size()
+     if world_size == 1:
+         return [data]
+
+     # serialized to a Tensor
+     buffer = pickle.dumps(data)
+     storage = torch.ByteStorage.from_buffer(buffer)
+     tensor = torch.ByteTensor(storage).to("cuda")
+
+     # obtain Tensor size of each rank
+     local_size = torch.LongTensor([tensor.numel()]).to("cuda")
+     size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
+     dist.all_gather(size_list, local_size)
+     size_list = [int(size.item()) for size in size_list]
+     max_size = max(size_list)
+
+     # receiving Tensor from all ranks
+     # we pad the tensor because torch all_gather does not support
+     # gathering tensors of different shapes
+     tensor_list = []
+     for _ in size_list:
+         tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
+     if local_size != max_size:
+         padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
+         tensor = torch.cat((tensor, padding), dim=0)
+     dist.all_gather(tensor_list, tensor)
+
+     data_list = []
+     for size, tensor in zip(size_list, tensor_list):
+         buffer = tensor.cpu().numpy().tobytes()[:size]
+         data_list.append(pickle.loads(buffer))
+
+     return data_list
+
+
+ def mean(lst):
+     return sum(lst) / len(lst)
+
+
+ def stop_gradient_by_name(name: str):
+     def apply_fn(module):
+         if hasattr(module, name):
+             getattr(module, name).requires_grad_(False)
+
+     return apply_fn
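
The hunk above vendors OmniLMM's image-preprocessing and generation utilities. As a rough sketch of how the pieces compose (this is an illustration, not code from the diff; it assumes the wheel is installed and that the hunk lands at xinference/thirdparty/omnilmm/model/utils.py, as inferred from the file list):

import numpy as np
from PIL import Image

# Module path inferred from entry 46 in the "Files changed" list above.
from xinference.thirdparty.omnilmm.model.utils import RandomAugment, build_transform

# A random 256x256 RGB image stands in for real input.
img = Image.fromarray(np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8))

# Eval-mode pipeline: bicubic resize to input_size, ToTensor, CLIP normalization.
eval_tf = build_transform(is_train=False, input_size=224, std_mode="OPENAI_CLIP")
tensor = eval_tf(img)  # torch.Tensor of shape (3, 224, 224)

# RandomAugment on its own: sample N=2 ops at magnitude M=7, each applied with
# probability 0.5; isPIL=True converts the PIL input to a numpy array first.
aug = RandomAugment(N=2, M=7, isPIL=True, augs=["Identity", "Brightness", "Rotate"])
out = aug(img)  # numpy uint8 array of shape (256, 256, 3)

KeywordsStoppingCriteria implements the standard transformers StoppingCriteria interface, so it can be passed to model.generate inside a StoppingCriteriaList; the distributed helpers (all_gather, get_rank, get_world_size) are conventional torch.distributed wrappers that fall back to single-process behavior when no process group is initialized.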
xinference/thirdparty/omnilmm/train/__init__.py
@@ -0,0 +1,13 @@
+ # Copyright 2022-2023 XProbe Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.