xinference 0.16.0__py3-none-any.whl → 0.16.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their respective registries. It is provided for informational purposes only.

Note: this version of xinference has been flagged as potentially problematic.

Files changed (50)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +48 -0
  3. xinference/client/restful/restful_client.py +19 -0
  4. xinference/core/chat_interface.py +5 -1
  5. xinference/core/image_interface.py +5 -1
  6. xinference/core/model.py +106 -16
  7. xinference/core/scheduler.py +1 -1
  8. xinference/deploy/supervisor.py +0 -4
  9. xinference/model/audio/chattts.py +25 -14
  10. xinference/model/audio/model_spec.json +1 -1
  11. xinference/model/audio/model_spec_modelscope.json +1 -1
  12. xinference/model/embedding/model_spec.json +1 -1
  13. xinference/model/image/core.py +59 -4
  14. xinference/model/image/model_spec.json +24 -3
  15. xinference/model/image/model_spec_modelscope.json +25 -3
  16. xinference/model/image/ocr/__init__.py +13 -0
  17. xinference/model/image/ocr/got_ocr2.py +76 -0
  18. xinference/model/image/scheduler/flux.py +1 -1
  19. xinference/model/image/stable_diffusion/core.py +2 -3
  20. xinference/model/image/stable_diffusion/mlx.py +221 -0
  21. xinference/model/llm/llm_family.json +9 -0
  22. xinference/model/llm/llm_family_modelscope.json +11 -0
  23. xinference/thirdparty/mlx/__init__.py +13 -0
  24. xinference/thirdparty/mlx/flux/__init__.py +15 -0
  25. xinference/thirdparty/mlx/flux/autoencoder.py +357 -0
  26. xinference/thirdparty/mlx/flux/clip.py +154 -0
  27. xinference/thirdparty/mlx/flux/datasets.py +75 -0
  28. xinference/thirdparty/mlx/flux/flux.py +247 -0
  29. xinference/thirdparty/mlx/flux/layers.py +302 -0
  30. xinference/thirdparty/mlx/flux/lora.py +76 -0
  31. xinference/thirdparty/mlx/flux/model.py +134 -0
  32. xinference/thirdparty/mlx/flux/sampler.py +56 -0
  33. xinference/thirdparty/mlx/flux/t5.py +244 -0
  34. xinference/thirdparty/mlx/flux/tokenizers.py +185 -0
  35. xinference/thirdparty/mlx/flux/trainer.py +98 -0
  36. xinference/thirdparty/mlx/flux/utils.py +179 -0
  37. xinference/web/ui/build/asset-manifest.json +3 -3
  38. xinference/web/ui/build/index.html +1 -1
  39. xinference/web/ui/build/static/js/{main.f7da0140.js → main.b76aeeb7.js} +3 -3
  40. xinference/web/ui/build/static/js/main.b76aeeb7.js.map +1 -0
  41. xinference/web/ui/node_modules/.cache/babel-loader/32ea2c04cf0bba2761b4883d2c40cc259952c94d2d6bb774e510963ca37aac0a.json +1 -0
  42. {xinference-0.16.0.dist-info → xinference-0.16.1.dist-info}/METADATA +15 -8
  43. {xinference-0.16.0.dist-info → xinference-0.16.1.dist-info}/RECORD +48 -31
  44. xinference/web/ui/build/static/js/main.f7da0140.js.map +0 -1
  45. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +0 -1
  46. /xinference/web/ui/build/static/js/{main.f7da0140.js.LICENSE.txt → main.b76aeeb7.js.LICENSE.txt} +0 -0
  47. {xinference-0.16.0.dist-info → xinference-0.16.1.dist-info}/LICENSE +0 -0
  48. {xinference-0.16.0.dist-info → xinference-0.16.1.dist-info}/WHEEL +0 -0
  49. {xinference-0.16.0.dist-info → xinference-0.16.1.dist-info}/entry_points.txt +0 -0
  50. {xinference-0.16.0.dist-info → xinference-0.16.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/mlx/flux/autoencoder.py
@@ -0,0 +1,357 @@
+ # Copyright © 2024 Apple Inc.
+
+ from dataclasses import dataclass
+ from typing import List
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ from mlx.nn.layers.upsample import upsample_nearest
+
+
+ @dataclass
+ class AutoEncoderParams:
+     resolution: int
+     in_channels: int
+     ch: int
+     out_ch: int
+     ch_mult: List[int]
+     num_res_blocks: int
+     z_channels: int
+     scale_factor: float
+     shift_factor: float
+
+
+ class AttnBlock(nn.Module):
+     def __init__(self, in_channels: int):
+         super().__init__()
+         self.in_channels = in_channels
+
+         self.norm = nn.GroupNorm(
+             num_groups=32,
+             dims=in_channels,
+             eps=1e-6,
+             affine=True,
+             pytorch_compatible=True,
+         )
+         self.q = nn.Linear(in_channels, in_channels)
+         self.k = nn.Linear(in_channels, in_channels)
+         self.v = nn.Linear(in_channels, in_channels)
+         self.proj_out = nn.Linear(in_channels, in_channels)
+
+     def __call__(self, x: mx.array) -> mx.array:
+         B, H, W, C = x.shape
+
+         y = x.reshape(B, 1, -1, C)
+         y = self.norm(y)
+         q = self.q(y)
+         k = self.k(y)
+         v = self.v(y)
+         y = mx.fast.scaled_dot_product_attention(q, k, v, scale=C ** (-0.5))
+         y = self.proj_out(y)
+
+         return x + y.reshape(B, H, W, C)
+
+
+ class ResnetBlock(nn.Module):
+     def __init__(self, in_channels: int, out_channels: int):
+         super().__init__()
+         self.in_channels = in_channels
+         out_channels = in_channels if out_channels is None else out_channels
+         self.out_channels = out_channels
+
+         self.norm1 = nn.GroupNorm(
+             num_groups=32,
+             dims=in_channels,
+             eps=1e-6,
+             affine=True,
+             pytorch_compatible=True,
+         )
+         self.conv1 = nn.Conv2d(
+             in_channels, out_channels, kernel_size=3, stride=1, padding=1
+         )
+         self.norm2 = nn.GroupNorm(
+             num_groups=32,
+             dims=out_channels,
+             eps=1e-6,
+             affine=True,
+             pytorch_compatible=True,
+         )
+         self.conv2 = nn.Conv2d(
+             out_channels, out_channels, kernel_size=3, stride=1, padding=1
+         )
+         if self.in_channels != self.out_channels:
+             self.nin_shortcut = nn.Linear(in_channels, out_channels)
+
+     def __call__(self, x):
+         h = x
+         h = self.norm1(h)
+         h = nn.silu(h)
+         h = self.conv1(h)
+
+         h = self.norm2(h)
+         h = nn.silu(h)
+         h = self.conv2(h)
+
+         if self.in_channels != self.out_channels:
+             x = self.nin_shortcut(x)
+
+         return x + h
+
+
+ class Downsample(nn.Module):
+     def __init__(self, in_channels: int):
+         super().__init__()
+         self.conv = nn.Conv2d(
+             in_channels, in_channels, kernel_size=3, stride=2, padding=0
+         )
+
+     def __call__(self, x: mx.array):
+         x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
+         x = self.conv(x)
+         return x
+
+
+ class Upsample(nn.Module):
+     def __init__(self, in_channels: int):
+         super().__init__()
+         self.conv = nn.Conv2d(
+             in_channels, in_channels, kernel_size=3, stride=1, padding=1
+         )
+
+     def __call__(self, x: mx.array):
+         x = upsample_nearest(x, (2, 2))
+         x = self.conv(x)
+         return x
+
+
+ class Encoder(nn.Module):
+     def __init__(
+         self,
+         resolution: int,
+         in_channels: int,
+         ch: int,
+         ch_mult: list[int],
+         num_res_blocks: int,
+         z_channels: int,
+     ):
+         super().__init__()
+         self.ch = ch
+         self.num_resolutions = len(ch_mult)
+         self.num_res_blocks = num_res_blocks
+         self.resolution = resolution
+         self.in_channels = in_channels
+         # downsampling
+         self.conv_in = nn.Conv2d(
+             in_channels, self.ch, kernel_size=3, stride=1, padding=1
+         )
+
+         curr_res = resolution
+         in_ch_mult = (1,) + tuple(ch_mult)
+         self.in_ch_mult = in_ch_mult
+         self.down = []
+         block_in = self.ch
+         for i_level in range(self.num_resolutions):
+             block = []
+             attn = []  # TODO: Remove the attn, nobody appends anything to it
+             block_in = ch * in_ch_mult[i_level]
+             block_out = ch * ch_mult[i_level]
+             for _ in range(self.num_res_blocks):
+                 block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
+                 block_in = block_out
+             down = {}
+             down["block"] = block
+             down["attn"] = attn
+             if i_level != self.num_resolutions - 1:
+                 down["downsample"] = Downsample(block_in)
+                 curr_res = curr_res // 2
+             self.down.append(down)
+
+         # middle
+         self.mid = {}
+         self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
+         self.mid["attn_1"] = AttnBlock(block_in)
+         self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
+
+         # end
+         self.norm_out = nn.GroupNorm(
+             num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True
+         )
+         self.conv_out = nn.Conv2d(
+             block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
+         )
+
+     def __call__(self, x: mx.array):
+         hs = [self.conv_in(x)]
+         for i_level in range(self.num_resolutions):
+             for i_block in range(self.num_res_blocks):
+                 h = self.down[i_level]["block"][i_block](hs[-1])
+
+                 # TODO: Remove the attn
+                 if len(self.down[i_level]["attn"]) > 0:
+                     h = self.down[i_level]["attn"][i_block](h)
+
+                 hs.append(h)
+
+             if i_level != self.num_resolutions - 1:
+                 hs.append(self.down[i_level]["downsample"](hs[-1]))
+
+         # middle
+         h = hs[-1]
+         h = self.mid["block_1"](h)
+         h = self.mid["attn_1"](h)
+         h = self.mid["block_2"](h)
+
+         # end
+         h = self.norm_out(h)
+         h = nn.silu(h)
+         h = self.conv_out(h)
+
+         return h
+
+
+ class Decoder(nn.Module):
+     def __init__(
+         self,
+         ch: int,
+         out_ch: int,
+         ch_mult: list[int],
+         num_res_blocks: int,
+         in_channels: int,
+         resolution: int,
+         z_channels: int,
+     ):
+         super().__init__()
+         self.ch = ch
+         self.num_resolutions = len(ch_mult)
+         self.num_res_blocks = num_res_blocks
+         self.resolution = resolution
+         self.in_channels = in_channels
+         self.ffactor = 2 ** (self.num_resolutions - 1)
+
+         # compute in_ch_mult, block_in and curr_res at lowest res
+         block_in = ch * ch_mult[self.num_resolutions - 1]
+         curr_res = resolution // 2 ** (self.num_resolutions - 1)
+         self.z_shape = (1, z_channels, curr_res, curr_res)
+
+         # z to block_in
+         self.conv_in = nn.Conv2d(
+             z_channels, block_in, kernel_size=3, stride=1, padding=1
+         )
+
+         # middle
+         self.mid = {}
+         self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
+         self.mid["attn_1"] = AttnBlock(block_in)
+         self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
+
+         # upsampling
+         self.up = []
+         for i_level in reversed(range(self.num_resolutions)):
+             block = []
+             attn = []  # TODO: Remove the attn, nobody appends anything to it
+
+             block_out = ch * ch_mult[i_level]
+             for _ in range(self.num_res_blocks + 1):
+                 block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
+                 block_in = block_out
+             up = {}
+             up["block"] = block
+             up["attn"] = attn
+             if i_level != 0:
+                 up["upsample"] = Upsample(block_in)
+                 curr_res = curr_res * 2
+             self.up.insert(0, up)  # prepend to get consistent order
+
+         # end
+         self.norm_out = nn.GroupNorm(
+             num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True
+         )
+         self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
+
+     def __call__(self, z: mx.array):
+         # z to block_in
+         h = self.conv_in(z)
+
+         # middle
+         h = self.mid["block_1"](h)
+         h = self.mid["attn_1"](h)
+         h = self.mid["block_2"](h)
+
+         # upsampling
+         for i_level in reversed(range(self.num_resolutions)):
+             for i_block in range(self.num_res_blocks + 1):
+                 h = self.up[i_level]["block"][i_block](h)
+
+                 # TODO: Remove the attn
+                 if len(self.up[i_level]["attn"]) > 0:
+                     h = self.up[i_level]["attn"][i_block](h)
+
+             if i_level != 0:
+                 h = self.up[i_level]["upsample"](h)
+
+         # end
+         h = self.norm_out(h)
+         h = nn.silu(h)
+         h = self.conv_out(h)
+
+         return h
+
+
+ class DiagonalGaussian(nn.Module):
+     def __call__(self, z: mx.array):
+         mean, logvar = mx.split(z, 2, axis=-1)
+         if self.training:
+             std = mx.exp(0.5 * logvar)
+             # Sample noise with the shape of mean/std (half of z's channels);
+             # sampling with z's full shape would fail to broadcast against std.
+             eps = mx.random.normal(shape=mean.shape, dtype=z.dtype)
+             return mean + std * eps
+         else:
+             return mean
+
+
+ class AutoEncoder(nn.Module):
+     def __init__(self, params: AutoEncoderParams):
+         super().__init__()
+         self.encoder = Encoder(
+             resolution=params.resolution,
+             in_channels=params.in_channels,
+             ch=params.ch,
+             ch_mult=params.ch_mult,
+             num_res_blocks=params.num_res_blocks,
+             z_channels=params.z_channels,
+         )
+         self.decoder = Decoder(
+             resolution=params.resolution,
+             in_channels=params.in_channels,
+             ch=params.ch,
+             out_ch=params.out_ch,
+             ch_mult=params.ch_mult,
+             num_res_blocks=params.num_res_blocks,
+             z_channels=params.z_channels,
+         )
+         self.reg = DiagonalGaussian()
+
+         self.scale_factor = params.scale_factor
+         self.shift_factor = params.shift_factor
+
+     def sanitize(self, weights):
+         new_weights = {}
+         for k, w in weights.items():
+             if w.ndim == 4:
+                 # PyTorch conv weights (O, I, H, W) -> MLX layout (O, H, W, I);
+                 # the flatten-and-restore materializes the transpose contiguously
+                 w = w.transpose(0, 2, 3, 1)
+                 w = w.reshape(-1).reshape(w.shape)
+                 if w.shape[1:3] == (1, 1):
+                     # 1x1 convolutions are stored as linear layers here
+                     w = w.squeeze((1, 2))
+             new_weights[k] = w
+         return new_weights
+
+     def encode(self, x: mx.array):
+         z = self.reg(self.encoder(x))
+         z = self.scale_factor * (z - self.shift_factor)
+         return z
+
+     def decode(self, z: mx.array):
+         z = z / self.scale_factor + self.shift_factor
+         return self.decoder(z)
+
+     def __call__(self, x: mx.array):
+         return self.decode(self.encode(x))
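
For orientation, here is a minimal sketch of exercising this vendored autoencoder on its own. The parameter values below are an assumption mirroring the configuration FLUX checkpoints typically use; in the actual pipeline both the parameters and the weights (loaded through sanitize) come from the downloaded checkpoint. Note that MLX convolutions expect channels-last input.

import mlx.core as mx
from xinference.thirdparty.mlx.flux.autoencoder import AutoEncoder, AutoEncoderParams

# Assumed parameter values, chosen to match a typical FLUX VAE config;
# the authoritative values come from the checkpoint.
params = AutoEncoderParams(
    resolution=256,
    in_channels=3,
    ch=128,
    out_ch=3,
    ch_mult=[1, 2, 4, 4],
    num_res_blocks=2,
    z_channels=16,
    scale_factor=0.3611,
    shift_factor=0.1159,
)
ae = AutoEncoder(params)
ae.eval()  # keep DiagonalGaussian on the posterior mean instead of sampling

# MLX convolutions are channels-last: (B, H, W, C).
x = mx.random.normal(shape=(1, 256, 256, 3))
z = ae.encode(x)      # (1, 32, 32, 16): three downsamples -> 8x smaller
x_hat = ae.decode(z)  # (1, 256, 256, 3)
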
xinference/thirdparty/mlx/flux/clip.py
@@ -0,0 +1,154 @@
+ # Copyright © 2024 Apple Inc.
+
+ from dataclasses import dataclass
+ from typing import List, Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ _ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu}
+
+
+ @dataclass
+ class CLIPTextModelConfig:
+     num_layers: int = 23
+     model_dims: int = 1024
+     num_heads: int = 16
+     max_length: int = 77
+     vocab_size: int = 49408
+     hidden_act: str = "quick_gelu"
+
+     @classmethod
+     def from_dict(cls, config):
+         return cls(
+             num_layers=config["num_hidden_layers"],
+             model_dims=config["hidden_size"],
+             num_heads=config["num_attention_heads"],
+             max_length=config["max_position_embeddings"],
+             vocab_size=config["vocab_size"],
+             hidden_act=config["hidden_act"],
+         )
+
+
+ @dataclass
+ class CLIPOutput:
+     # The last_hidden_state indexed at the EOS token and possibly projected if
+     # the model has a projection layer
+     pooled_output: Optional[mx.array] = None
+
+     # The full sequence output of the transformer after the final layernorm
+     last_hidden_state: Optional[mx.array] = None
+
+     # A list of hidden states corresponding to the outputs of the transformer layers
+     hidden_states: Optional[List[mx.array]] = None
+
+
+ class CLIPEncoderLayer(nn.Module):
+     """The transformer encoder layer from CLIP."""
+
+     def __init__(self, model_dims: int, num_heads: int, activation: str):
+         super().__init__()
+
+         self.layer_norm1 = nn.LayerNorm(model_dims)
+         self.layer_norm2 = nn.LayerNorm(model_dims)
+
+         self.attention = nn.MultiHeadAttention(model_dims, num_heads, bias=True)
+
+         self.linear1 = nn.Linear(model_dims, 4 * model_dims)
+         self.linear2 = nn.Linear(4 * model_dims, model_dims)
+
+         self.act = _ACTIVATIONS[activation]
+
+     def __call__(self, x, attn_mask=None):
+         y = self.layer_norm1(x)
+         y = self.attention(y, y, y, attn_mask)
+         x = y + x
+
+         y = self.layer_norm2(x)
+         y = self.linear1(y)
+         y = self.act(y)
+         y = self.linear2(y)
+         x = y + x
+
+         return x
+
+
+ class CLIPTextModel(nn.Module):
+     """Implements the text encoder transformer from CLIP."""
+
+     def __init__(self, config: CLIPTextModelConfig):
+         super().__init__()
+
+         self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
+         self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
+         self.layers = [
+             CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act)
+             for i in range(config.num_layers)
+         ]
+         self.final_layer_norm = nn.LayerNorm(config.model_dims)
+
+     def _get_mask(self, N, dtype):
+         indices = mx.arange(N)
+         mask = indices[:, None] < indices[None]
+         mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9)
+         return mask
+
+     def sanitize(self, weights):
+         new_weights = {}
+         for key, w in weights.items():
+             # Remove prefixes
+             if key.startswith("text_model."):
+                 key = key[11:]
+             if key.startswith("embeddings."):
+                 key = key[11:]
+             if key.startswith("encoder."):
+                 key = key[8:]
+
+             # Map attention layers
+             if "self_attn." in key:
+                 key = key.replace("self_attn.", "attention.")
+             if "q_proj." in key:
+                 key = key.replace("q_proj.", "query_proj.")
+             if "k_proj." in key:
+                 key = key.replace("k_proj.", "key_proj.")
+             if "v_proj." in key:
+                 key = key.replace("v_proj.", "value_proj.")
+
+             # Map ffn layers
+             if "mlp.fc1" in key:
+                 key = key.replace("mlp.fc1", "linear1")
+             if "mlp.fc2" in key:
+                 key = key.replace("mlp.fc2", "linear2")
+
+             new_weights[key] = w
+
+         return new_weights
+
+     def __call__(self, x):
+         # Extract some shapes
+         B, N = x.shape
+         # EOS has the largest id in CLIP's vocabulary, so argmax locates it
+         eos_tokens = x.argmax(-1)
+
+         # Compute the embeddings
+         x = self.token_embedding(x)
+         x = x + self.position_embedding.weight[:N]
+
+         # Compute the features from the transformer
+         mask = self._get_mask(N, x.dtype)
+         hidden_states = []
+         for l in self.layers:
+             x = l(x, mask)
+             hidden_states.append(x)
+
+         # Apply the final layernorm and return
+         x = self.final_layer_norm(x)
+         last_hidden_state = x
+
+         # Select the EOS token
+         pooled_output = x[mx.arange(len(x)), eos_tokens]
+
+         return CLIPOutput(
+             pooled_output=pooled_output,
+             last_hidden_state=last_hidden_state,
+             hidden_states=hidden_states,
+         )
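
A rough sketch of driving this text encoder directly, with freshly initialized weights, just to show the shapes involved. The token ids are illustrative placeholders; in the FLUX pipeline they come from the CLIP tokenizer (tokenizers.py in this release), and real weights are loaded through sanitize.

import mlx.core as mx
from xinference.thirdparty.mlx.flux.clip import CLIPTextModel, CLIPTextModelConfig

# Uses the dataclass defaults above; a real pipeline would build the
# config via from_dict and load sanitized checkpoint weights.
model = CLIPTextModel(CLIPTextModelConfig())

# Placeholder ids: 49406/49407 are CLIP's BOS/EOS, the middle ids are
# illustrative. argmax(-1) finds EOS because it is the largest id.
tokens = mx.array([[49406, 320, 1125, 539, 320, 2368, 49407]])
out = model(tokens)
print(out.pooled_output.shape)      # (1, 1024)
print(out.last_hidden_state.shape)  # (1, 7, 1024)
print(len(out.hidden_states))       # 23, one per encoder layer
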
xinference/thirdparty/mlx/flux/datasets.py
@@ -0,0 +1,75 @@
+ import json
+ from pathlib import Path
+
+ from PIL import Image
+
+
+ class Dataset:
+     def __getitem__(self, index: int):
+         raise NotImplementedError()
+
+     def __len__(self):
+         raise NotImplementedError()
+
+
+ class LocalDataset(Dataset):
+     prompt_key = "prompt"
+
+     def __init__(self, dataset: str, data_file):
+         self.dataset_base = Path(dataset)
+         with open(data_file, "r") as fid:
+             self._data = [json.loads(l) for l in fid]
+
+     def __len__(self):
+         return len(self._data)
+
+     def __getitem__(self, index: int):
+         item = self._data[index]
+         image = Image.open(self.dataset_base / item["image"])
+         return image, item[self.prompt_key]
+
+
+ class LegacyDataset(LocalDataset):
+     prompt_key = "text"
+
+     def __init__(self, dataset: str):
+         self.dataset_base = Path(dataset)
+         with open(self.dataset_base / "index.json") as f:
+             self._data = json.load(f)["data"]
+
+
+ class HuggingFaceDataset(Dataset):
+
+     def __init__(self, dataset: str):
+         from datasets import load_dataset as hf_load_dataset
+
+         self._df = hf_load_dataset(dataset)["train"]
+
+     def __len__(self):
+         return len(self._df)
+
+     def __getitem__(self, index: int):
+         item = self._df[index]
+         return item["image"], item["prompt"]
+
+
+ def load_dataset(dataset: str):
+     dataset_base = Path(dataset)
+     data_file = dataset_base / "train.jsonl"
+     legacy_file = dataset_base / "index.json"
+
+     if data_file.exists():
+         print(f"Load the local dataset {data_file} .", flush=True)
+         dataset = LocalDataset(dataset, data_file)
+     elif legacy_file.exists():
+         print(f"Load the local dataset {legacy_file} .")
+         print()
+         print("   WARNING: 'index.json' is deprecated in favor of 'train.jsonl'.")
+         print("            See the README for details.")
+         print(flush=True)
+         dataset = LegacyDataset(dataset)
+     else:
+         print(f"Load the Hugging Face dataset {dataset} .", flush=True)
+         dataset = HuggingFaceDataset(dataset)
+
+     return dataset
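
As load_dataset shows, a local dataset is just a directory with a train.jsonl whose rows carry an image path (relative to the directory) and a prompt. A hypothetical sketch of preparing and loading one, with made-up file names:

import json
from pathlib import Path
from xinference.thirdparty.mlx.flux.datasets import load_dataset

# Hypothetical layout: my_dataset/img0.png plus a train.jsonl describing it.
root = Path("my_dataset")
root.mkdir(exist_ok=True)
rows = [{"image": "img0.png", "prompt": "a photo of a cat"}]
with open(root / "train.jsonl", "w") as f:
    for row in rows:
        f.write(json.dumps(row) + "\n")

ds = load_dataset(str(root))  # picks LocalDataset because train.jsonl exists
image, prompt = ds[0]         # (PIL.Image.Image, str); assumes img0.png exists
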