keras-hub-nightly 0.23.0.dev202509180413__py3-none-any.whl → 0.23.0.dev202509280419__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (32)
  1. keras_hub/layers/__init__.py +3 -0
  2. keras_hub/models/__init__.py +24 -0
  3. keras_hub/src/models/depth_anything/__init__.py +9 -0
  4. keras_hub/src/models/depth_anything/depth_anything_backbone.py +232 -0
  5. keras_hub/src/models/depth_anything/depth_anything_depth_estimator.py +70 -0
  6. keras_hub/src/models/depth_anything/depth_anything_depth_estimator_preprocessor.py +16 -0
  7. keras_hub/src/models/depth_anything/depth_anything_image_converter.py +10 -0
  8. keras_hub/src/models/depth_anything/depth_anything_layers.py +725 -0
  9. keras_hub/src/models/depth_anything/depth_anything_loss.py +89 -0
  10. keras_hub/src/models/depth_anything/depth_anything_presets.py +4 -0
  11. keras_hub/src/models/depth_anything/interpolate.py +62 -0
  12. keras_hub/src/models/depth_estimator.py +239 -0
  13. keras_hub/src/models/depth_estimator_preprocessor.py +78 -0
  14. keras_hub/src/models/dinov2/dinov2_backbone.py +29 -3
  15. keras_hub/src/models/dinov2/dinov2_layers.py +13 -3
  16. keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py +371 -0
  17. keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +365 -0
  18. keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm.py +357 -0
  19. keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_preprocessor.py +12 -0
  20. keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py +672 -0
  21. keras_hub/src/models/qwen3_moe/qwen3_moe_layernorm.py +45 -0
  22. keras_hub/src/models/qwen3_moe/qwen3_moe_tokenizer.py +48 -0
  23. keras_hub/src/tests/test_case.py +3 -2
  24. keras_hub/src/utils/transformers/convert_dinov2.py +1 -0
  25. keras_hub/src/utils/transformers/convert_qwen3_moe.py +216 -0
  26. keras_hub/src/utils/transformers/preset_loader.py +3 -0
  27. keras_hub/src/version.py +1 -1
  28. keras_hub/tokenizers/__init__.py +3 -0
  29. {keras_hub_nightly-0.23.0.dev202509180413.dist-info → keras_hub_nightly-0.23.0.dev202509280419.dist-info}/METADATA +1 -1
  30. {keras_hub_nightly-0.23.0.dev202509180413.dist-info → keras_hub_nightly-0.23.0.dev202509280419.dist-info}/RECORD +32 -13
  31. {keras_hub_nightly-0.23.0.dev202509180413.dist-info → keras_hub_nightly-0.23.0.dev202509280419.dist-info}/WHEEL +0 -0
  32. {keras_hub_nightly-0.23.0.dev202509180413.dist-info → keras_hub_nightly-0.23.0.dev202509280419.dist-info}/top_level.txt +0 -0
keras_hub/src/models/depth_anything/depth_anything_layers.py
@@ -0,0 +1,725 @@
+ from keras import layers
+ from keras import ops
+
+ from keras_hub.src.models.depth_anything.interpolate import interpolate
+ from keras_hub.src.utils.keras_utils import standardize_data_format
+
+
+ class DepthAnythingTokenToImage(layers.Layer):
+     """A layer that converts tokens into images.
+
+     Args:
+         hidden_dim: int. The number of units in the hidden layers.
+         patch_height: int. The height of each patch.
+         patch_width: int. The width of each patch.
+         num_cls_tokens: int. The number of class tokens at the beginning of
+             the sequence. Defaults to `1`.
+         num_register_tokens: int. The number of register tokens after the
+             class tokens. Defaults to `0`.
+         data_format: `None` or str. If specified, either `"channels_last"` or
+             `"channels_first"`. The ordering of the dimensions in the
+             inputs. `"channels_last"` corresponds to inputs with shape
+             `(batch_size, height, width, channels)`
+             while `"channels_first"` corresponds to inputs with shape
+             `(batch_size, channels, height, width)`. It defaults to the
+             `image_data_format` value found in your Keras config file at
+             `~/.keras/keras.json`. If you never set it, then it will be
+             `"channels_last"`.
+         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+             including `name`, `dtype` etc.
+     """
+
+     def __init__(
+         self,
+         hidden_dim,
+         patch_height,
+         patch_width,
+         num_cls_tokens=1,
+         num_register_tokens=0,
+         data_format=None,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.hidden_dim = int(hidden_dim)
+         self.patch_height = int(patch_height)
+         self.patch_width = int(patch_width)
+         self.num_cls_tokens = int(num_cls_tokens)
+         self.num_register_tokens = int(num_register_tokens)
+         self.data_format = standardize_data_format(data_format)
+         # Always use channels_last for reshaping first.
+         self.target_shape = (
+             self.patch_height,
+             self.patch_width,
+             self.hidden_dim,
+         )
+
+     def call(self, inputs):
+         # Remove the cls and register tokens.
+         x = inputs[:, self.num_cls_tokens + self.num_register_tokens :, ...]
+
+         x = ops.reshape(x, (ops.shape(x)[0],) + self.target_shape)
+         if self.data_format == "channels_first":
+             x = ops.transpose(x, (0, 3, 1, 2))
+         return x
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "hidden_dim": self.hidden_dim,
+                 "patch_height": self.patch_height,
+                 "patch_width": self.patch_width,
+                 "num_cls_tokens": self.num_cls_tokens,
+                 "num_register_tokens": self.num_register_tokens,
+             }
+         )
+         return config
+
+     def compute_output_shape(self, input_shape):
+         output_shape = [input_shape[0], *self.target_shape]
+         if self.data_format == "channels_first":
+             output_shape = [
+                 output_shape[0],
+                 output_shape[3],
+                 output_shape[1],
+                 output_shape[2],
+             ]
+         return output_shape
+
+
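For orientation, a minimal usage sketch of this layer; the batch size, token counts, and widths below are illustrative assumptions, and the output-shape comment assumes "channels_last":

import numpy as np

from keras_hub.src.models.depth_anything.depth_anything_layers import (
    DepthAnythingTokenToImage,
)

# A sequence of 1 cls token + 4 register tokens + 16 * 16 patch tokens.
tokens = np.zeros((2, 1 + 4 + 16 * 16, 32), dtype="float32")
layer = DepthAnythingTokenToImage(
    hidden_dim=32,
    patch_height=16,
    patch_width=16,
    num_cls_tokens=1,
    num_register_tokens=4,
)
print(layer(tokens).shape)  # (2, 16, 16, 32)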
+ class DepthAnythingReassembleLayer(layers.Layer):
+     """A layer that resizes the input images.
+
+     Args:
+         hidden_dim: int. The number of units in the hidden layers.
+         factor: float. The resizing factor. If `factor > 1`, the layer upsamples
+             the input. If `factor < 1`, the layer downsamples the input. If
+             `factor == 1`, the layer only applies a linear projection.
+         data_format: `None` or str. If specified, either `"channels_last"` or
+             `"channels_first"`. The ordering of the dimensions in the
+             inputs. `"channels_last"` corresponds to inputs with shape
+             `(batch_size, height, width, channels)`
+             while `"channels_first"` corresponds to inputs with shape
+             `(batch_size, channels, height, width)`. It defaults to the
+             `image_data_format` value found in your Keras config file at
+             `~/.keras/keras.json`. If you never set it, then it will be
+             `"channels_last"`.
+         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+             including `name`, `dtype` etc.
+     """
+
+     def __init__(self, hidden_dim, factor, data_format=None, **kwargs):
+         super().__init__(**kwargs)
+         self.hidden_dim = int(hidden_dim)
+         self.factor = float(factor)
+         self.data_format = standardize_data_format(data_format)
+
+         self.projection = layers.Conv2D(
+             filters=self.hidden_dim,
+             kernel_size=1,
+             data_format=self.data_format,
+             use_bias=True,
+             dtype=self.dtype_policy,
+             name="projection",
+         )
+         if self.factor > 1:
+             self.padding = layers.Identity(
+                 dtype=self.dtype_policy, name="padding"
+             )
+             self.resize = layers.Conv2DTranspose(
+                 filters=self.hidden_dim,
+                 kernel_size=int(self.factor),
+                 strides=int(self.factor),
+                 data_format=self.data_format,
+                 use_bias=True,
+                 dtype=self.dtype_policy,
+                 name="resize",
+             )
+         elif self.factor == 1:
+             self.padding = layers.Identity(
+                 dtype=self.dtype_policy, name="padding"
+             )
+             self.resize = layers.Identity(
+                 dtype=self.dtype_policy, name="resize"
+             )
+         elif self.factor < 1:
+             self.padding = layers.ZeroPadding2D(
+                 padding=(1, 1),
+                 data_format=self.data_format,
+                 dtype=self.dtype_policy,
+                 name="padding",
+             )
+             self.resize = layers.Conv2D(
+                 filters=self.hidden_dim,
+                 kernel_size=3,
+                 strides=int(1 / self.factor),
+                 data_format=self.data_format,
+                 use_bias=True,
+                 dtype=self.dtype_policy,
+                 name="resize",
+             )
+
+     def build(self, inputs_shape):
+         self.projection.build(inputs_shape)
+         inputs_shape = self.projection.compute_output_shape(inputs_shape)
+         self.padding.build(inputs_shape)
+         inputs_shape = self.padding.compute_output_shape(inputs_shape)
+         self.resize.build(inputs_shape)
+
+     def call(self, inputs, training=None):
+         x = self.projection(inputs, training=training)
+         x = self.padding(x, training=training)
+         return self.resize(x, training=training)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "hidden_dim": self.hidden_dim,
+                 "factor": self.factor,
+             }
+         )
+         return config
+
+     def compute_output_shape(self, input_shape):
+         output_shape = list(input_shape)
+         if self.data_format == "channels_first":
+             output_shape[1] = self.hidden_dim
+             output_shape[2] = int(output_shape[2] * self.factor)
+             output_shape[3] = int(output_shape[3] * self.factor)
+         else:
+             output_shape[1] = int(output_shape[1] * self.factor)
+             output_shape[2] = int(output_shape[2] * self.factor)
+             output_shape[3] = self.hidden_dim
+         return output_shape
+
+
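The three `factor` regimes can be checked directly. A sketch with assumed shapes; the comments assume "channels_last":

import numpy as np

from keras_hub.src.models.depth_anything.depth_anything_layers import (
    DepthAnythingReassembleLayer,
)

features = np.zeros((2, 16, 16, 32), dtype="float32")
# factor > 1: a Conv2DTranspose upsamples by the factor.
up = DepthAnythingReassembleLayer(hidden_dim=64, factor=4)
print(up(features).shape)  # (2, 64, 64, 64)
# factor < 1: pad by 1, then a 3x3 convolution with stride 1 / factor.
down = DepthAnythingReassembleLayer(hidden_dim=64, factor=0.5)
print(down(features).shape)  # (2, 8, 8, 64)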
+ class DepthAnythingPreActResidualLayer(layers.Layer):
+     """A pre-activation residual layer (two ReLU + Conv2D blocks with a
+     skip connection).
+
+     Args:
+         hidden_dim: int. The number of units in the hidden layers.
+         data_format: `None` or str. If specified, either `"channels_last"` or
+             `"channels_first"`. The ordering of the dimensions in the
+             inputs. `"channels_last"` corresponds to inputs with shape
+             `(batch_size, height, width, channels)`
+             while `"channels_first"` corresponds to inputs with shape
+             `(batch_size, channels, height, width)`. It defaults to the
+             `image_data_format` value found in your Keras config file at
+             `~/.keras/keras.json`. If you never set it, then it will be
+             `"channels_last"`.
+         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+             including `name`, `dtype` etc.
+     """
+
+     def __init__(self, hidden_dim, data_format=None, **kwargs):
+         super().__init__(**kwargs)
+         self.hidden_dim = int(hidden_dim)
+         self.data_format = standardize_data_format(data_format)
+
+         self.activation1 = layers.ReLU(
+             dtype=self.dtype_policy, name="activation1"
+         )
+         self.padding1 = layers.ZeroPadding2D(
+             padding=(1, 1),
+             data_format=self.data_format,
+             dtype=self.dtype_policy,
+             name="padding1",
+         )
+         self.convolution1 = layers.Conv2D(
+             filters=self.hidden_dim,
+             kernel_size=3,
+             strides=1,
+             data_format=self.data_format,
+             use_bias=True,
+             dtype=self.dtype_policy,
+             name="convolution1",
+         )
+         self.activation2 = layers.ReLU(
+             dtype=self.dtype_policy, name="activation2"
+         )
+         self.padding2 = layers.ZeroPadding2D(
+             padding=(1, 1),
+             data_format=self.data_format,
+             dtype=self.dtype_policy,
+             name="padding2",
+         )
+         self.convolution2 = layers.Conv2D(
+             filters=self.hidden_dim,
+             kernel_size=3,
+             strides=1,
+             data_format=self.data_format,
+             use_bias=True,
+             dtype=self.dtype_policy,
+             name="convolution2",
+         )
+
+     def build(self, inputs_shape):
+         self.activation1.build(inputs_shape)
+         self.padding1.build(inputs_shape)
+         inputs_shape = self.padding1.compute_output_shape(inputs_shape)
+         self.convolution1.build(inputs_shape)
+         inputs_shape = self.convolution1.compute_output_shape(inputs_shape)
+         self.activation2.build(inputs_shape)
+         self.padding2.build(inputs_shape)
+         inputs_shape = self.padding2.compute_output_shape(inputs_shape)
+         self.convolution2.build(inputs_shape)
+
+     def call(self, inputs, training=None):
+         residual = inputs
+         x = self.activation1(inputs, training=training)
+         x = self.padding1(x, training=training)
+         x = self.convolution1(x, training=training)
+         x = self.activation2(x, training=training)
+         x = self.padding2(x, training=training)
+         x = self.convolution2(x, training=training)
+         return ops.add(x, residual)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "hidden_dim": self.hidden_dim,
+             }
+         )
+         return config
+
+     def compute_output_shape(self, input_shape):
+         return input_shape
+
+
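Because of the residual add at the end of `call`, the block is shape-preserving, and the input must already have `hidden_dim` channels. A minimal sketch with assumed shapes:

import numpy as np

from keras_hub.src.models.depth_anything.depth_anything_layers import (
    DepthAnythingPreActResidualLayer,
)

x = np.zeros((2, 16, 16, 64), dtype="float32")
block = DepthAnythingPreActResidualLayer(hidden_dim=64)
print(block(x).shape)  # (2, 16, 16, 64), unchanged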
+ class DepthAnythingFeatureFusionLayer(layers.Layer):
+     """A layer that fuses the incoming features.
+
+     Args:
+         hidden_dim: int. The number of units in the hidden layers.
+         size: tuple of int. The target size of the output feature map.
+         data_format: `None` or str. If specified, either `"channels_last"` or
+             `"channels_first"`. The ordering of the dimensions in the
+             inputs. `"channels_last"` corresponds to inputs with shape
+             `(batch_size, height, width, channels)`
+             while `"channels_first"` corresponds to inputs with shape
+             `(batch_size, channels, height, width)`. It defaults to the
+             `image_data_format` value found in your Keras config file at
+             `~/.keras/keras.json`. If you never set it, then it will be
+             `"channels_last"`.
+         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+             including `name`, `dtype` etc.
+     """
+
+     def __init__(self, hidden_dim, size, data_format=None, **kwargs):
+         super().__init__(**kwargs)
+         self.hidden_dim = int(hidden_dim)
+         self.size = tuple(int(s) for s in size)
+         self.data_format = standardize_data_format(data_format)
+
+         self.residual_layer1 = DepthAnythingPreActResidualLayer(
+             hidden_dim=self.hidden_dim,
+             data_format=self.data_format,
+             dtype=self.dtype_policy,
+             name="residual_layer1",
+         )
+         self.residual_layer2 = DepthAnythingPreActResidualLayer(
+             hidden_dim=self.hidden_dim,
+             data_format=self.data_format,
+             dtype=self.dtype_policy,
+             name="residual_layer2",
+         )
+         self.projection = layers.Conv2D(
+             filters=self.hidden_dim,
+             kernel_size=1,
+             data_format=self.data_format,
+             use_bias=True,
+             dtype=self.dtype_policy,
+             name="projection",
+         )
+
+     def build(self, inputs_shape):
+         self.residual_layer1.build(inputs_shape)
+         self.residual_layer2.build(inputs_shape)
+         inputs_shape = list(inputs_shape)
+         if self.data_format == "channels_last":
+             inputs_shape[1] = self.size[0]
+             inputs_shape[2] = self.size[1]
+         else:
+             inputs_shape[2] = self.size[0]
+             inputs_shape[3] = self.size[1]
+         self.projection.build(inputs_shape)
+
+     def call(self, inputs, residual=None, training=None):
+         if residual is not None:
+             inputs = ops.add(
+                 inputs, self.residual_layer1(residual, training=training)
+             )
+
+         x = self.residual_layer2(inputs, training=training)
+         x = interpolate(x, size=self.size, data_format=self.data_format)
+         return self.projection(x, training=training)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "hidden_dim": self.hidden_dim,
+                 "size": self.size,
+             }
+         )
+         return config
+
+     def compute_output_shape(self, input_shape):
+         input_shape = self.residual_layer2.compute_output_shape(input_shape)
+         input_shape = list(input_shape)
+         if self.data_format == "channels_last":
+             input_shape[1] = self.size[0]
+             input_shape[2] = self.size[1]
+         else:
+             input_shape[2] = self.size[0]
+             input_shape[3] = self.size[1]
+         return self.projection.compute_output_shape(input_shape)
+
+
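A sketch of one fusion step, assuming an 8x8 map fused with an 8x8 skip connection and resized to a 16x16 target:

import numpy as np

from keras_hub.src.models.depth_anything.depth_anything_layers import (
    DepthAnythingFeatureFusionLayer,
)

coarse = np.zeros((2, 8, 8, 64), dtype="float32")
skip = np.zeros((2, 8, 8, 64), dtype="float32")
fusion = DepthAnythingFeatureFusionLayer(hidden_dim=64, size=(16, 16))
print(fusion(coarse, residual=skip).shape)  # (2, 16, 16, 64)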
+ class DepthAnythingNeck(layers.Layer):
+     """A DepthAnything neck layer.
+
+     Args:
+         patch_size: int. The size of one side of each patch.
+         image_size: tuple of ints. The (height, width) of the input images.
+         backbone_hidden_dim: int. The number of units in the backbone layers.
+         neck_hidden_dims: List of int. The number of units in each neck layer.
+         reassemble_factors: List of float. The resizing factor in each neck
+             layer.
+         fusion_hidden_dim: int. The number of units in the fusion layers.
+         num_cls_tokens: int. The number of class tokens at the beginning of
+             the sequence. Defaults to `1`.
+         num_register_tokens: int. The number of register tokens after the
+             class tokens. Defaults to `0`.
+         data_format: `None` or str. If specified, either `"channels_last"` or
+             `"channels_first"`. The ordering of the dimensions in the
+             inputs. `"channels_last"` corresponds to inputs with shape
+             `(batch_size, height, width, channels)`
+             while `"channels_first"` corresponds to inputs with shape
+             `(batch_size, channels, height, width)`. It defaults to the
+             `image_data_format` value found in your Keras config file at
+             `~/.keras/keras.json`. If you never set it, then it will be
+             `"channels_last"`.
+         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+             including `name`, `dtype` etc.
+     """
+
+     def __init__(
+         self,
+         patch_size,
+         image_size,
+         backbone_hidden_dim,
+         neck_hidden_dims,
+         reassemble_factors,
+         fusion_hidden_dim,
+         num_cls_tokens=1,
+         num_register_tokens=0,
+         data_format=None,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.patch_size = int(patch_size)
+         self.image_size = (int(image_size[0]), int(image_size[1]))
+         self.backbone_hidden_dim = int(backbone_hidden_dim)
+         self.neck_hidden_dims = tuple(int(d) for d in neck_hidden_dims)
+         self.reassemble_factors = tuple(float(f) for f in reassemble_factors)
+         self.fusion_hidden_dim = int(fusion_hidden_dim)
+         self.num_cls_tokens = int(num_cls_tokens)
+         self.num_register_tokens = int(num_register_tokens)
+         self.data_format = standardize_data_format(data_format)
+         if len(self.neck_hidden_dims) != len(self.reassemble_factors):
+             raise ValueError(
+                 "`DepthAnythingNeck` expects the length of `neck_hidden_dims` "
+                 "and `reassemble_factors` to be the same. "
+                 f"Received: neck_hidden_dims={neck_hidden_dims}, "
+                 f"reassemble_factors={reassemble_factors}"
+             )
+
+         # Calculate the patch sizes for token to image layers.
+         patch_height = self.image_size[0] // self.patch_size
+         patch_width = self.image_size[1] // self.patch_size
+         # Calculate the sizes for fusion layers.
+         fusion_sizes = [
+             (int(patch_height * factor), int(patch_width * factor))
+             for factor in reversed(self.reassemble_factors[:-1])
+         ]
+         fusion_sizes = fusion_sizes + [
+             (fusion_sizes[-1][0] * 2, fusion_sizes[-1][1] * 2)
+         ]
+
+         self.token_to_images = [
+             DepthAnythingTokenToImage(
+                 hidden_dim=backbone_hidden_dim,
+                 patch_height=patch_height,
+                 patch_width=patch_width,
+                 num_cls_tokens=num_cls_tokens,
+                 num_register_tokens=num_register_tokens,
+                 data_format=self.data_format,
+                 dtype=self.dtype_policy,
+                 name=f"token_to_images_{i}",
+             )
+             for i in range(len(self.neck_hidden_dims))
+         ]
+         self.reassemble_stage = [
+             DepthAnythingReassembleLayer(
+                 hidden_dim=hidden_dim,
+                 factor=factor,
+                 data_format=self.data_format,
+                 dtype=self.dtype_policy,
+                 name=f"reassemble_stage_{i}",
+             )
+             for i, (hidden_dim, factor) in enumerate(
+                 zip(self.neck_hidden_dims, self.reassemble_factors)
+             )
+         ]
+         self.paddings = [
+             layers.ZeroPadding2D(
+                 padding=(1, 1),
+                 data_format=self.data_format,
+                 dtype=self.dtype_policy,
+                 name=f"paddings_{i}",
+             )
+             for i in range(len(self.neck_hidden_dims))
+         ]
+         self.convs = [
+             layers.Conv2D(
+                 filters=self.fusion_hidden_dim,
+                 kernel_size=3,
+                 data_format=self.data_format,
+                 use_bias=False,
+                 dtype=self.dtype_policy,
+                 name=f"convs_{i}",
+             )
+             for i in range(len(self.neck_hidden_dims))
+         ]
+         self.fusion_stage = [
+             DepthAnythingFeatureFusionLayer(
+                 hidden_dim=self.fusion_hidden_dim,
+                 size=size,
+                 data_format=self.data_format,
+                 dtype=self.dtype_policy,
+                 name=f"fusion_stage_{i}",
+             )
+             for i, size in enumerate(fusion_sizes)
+         ]
+
+     def build(self, inputs_shape):
+         outputs_shape = []
+         # Reassemble stage.
+         for i, shape in enumerate(inputs_shape):
+             self.token_to_images[i].build(shape)
+             shape = self.token_to_images[i].compute_output_shape(shape)
+             self.reassemble_stage[i].build(shape)
+             shape = self.reassemble_stage[i].compute_output_shape(shape)
+             outputs_shape.append(shape)
+         # Convs.
+         for i, shape in enumerate(outputs_shape):
+             self.convs[i].build(shape)
+             shape = self.convs[i].compute_output_shape(shape)
+             outputs_shape[i] = shape
+         # Fusion stage.
+         for i, shape in enumerate(reversed(outputs_shape)):
+             self.fusion_stage[i].build(shape)
+
+     def call(self, inputs, training=None):
+         # Reassemble stage.
+         xs = [
+             self.reassemble_stage[i](
+                 self.token_to_images[i](x), training=training
+             )
+             for i, x in enumerate(inputs)
+         ]
+         # Convs.
+         xs = [
+             self.convs[i](self.paddings[i](x), training=training)
+             for i, x in enumerate(xs)
+         ]
+         # Fusion stage.
+         fused_xs = []
+         fused_x = None
+         for i, x in enumerate(reversed(xs)):
+             if fused_x is None:
+                 fused_x = self.fusion_stage[i](
+                     x, residual=None, training=training
+                 )
+             else:
+                 fused_x = self.fusion_stage[i](
+                     fused_x, residual=x, training=training
+                 )
+             fused_xs.append(fused_x)
+         return fused_xs
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "patch_size": self.patch_size,
+                 "image_size": self.image_size,
+                 "backbone_hidden_dim": self.backbone_hidden_dim,
+                 "neck_hidden_dims": self.neck_hidden_dims,
+                 "reassemble_factors": self.reassemble_factors,
+                 "fusion_hidden_dim": self.fusion_hidden_dim,
+                 "num_cls_tokens": self.num_cls_tokens,
+                 "num_register_tokens": self.num_register_tokens,
+             }
+         )
+         return config
+
+
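Putting the stages together, a sketch of the neck on four backbone feature maps; the ViT-S/14-style numbers are assumptions chosen only to make the shapes concrete:

import numpy as np

from keras_hub.src.models.depth_anything.depth_anything_layers import (
    DepthAnythingNeck,
)

# Four stages of (batch, 1 cls token + 14 * 14 patch tokens, 384) features.
features = [
    np.zeros((2, 1 + 14 * 14, 384), dtype="float32") for _ in range(4)
]
neck = DepthAnythingNeck(
    patch_size=14,
    image_size=(196, 196),
    backbone_hidden_dim=384,
    neck_hidden_dims=(48, 96, 192, 384),
    reassemble_factors=(4, 2, 1, 0.5),
    fusion_hidden_dim=64,
)
# Fused maps come back coarse-to-fine: 14, 28, 56, then 112 pixels per side.
print([tuple(f.shape) for f in neck(features)])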
+ class DepthAnythingDepthEstimationHead(layers.Layer):
+     """A DepthAnything depth estimation head layer.
+
+     Args:
+         patch_size: int. The size of one side of each patch.
+         patch_height: int. The height of each patch.
+         patch_width: int. The width of each patch.
+         fusion_hidden_dim: int. The number of units in the fusion layers.
+         head_hidden_dim: int. The number of units in the head layers.
+         head_in_index: int. The index of the feature map to be used as input
+             to the head.
+         data_format: `None` or str. If specified, either `"channels_last"` or
+             `"channels_first"`. The ordering of the dimensions in the
+             inputs. `"channels_last"` corresponds to inputs with shape
+             `(batch_size, height, width, channels)`
+             while `"channels_first"` corresponds to inputs with shape
+             `(batch_size, channels, height, width)`. It defaults to the
+             `image_data_format` value found in your Keras config file at
+             `~/.keras/keras.json`. If you never set it, then it will be
+             `"channels_last"`.
+         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+             including `name`, `dtype` etc.
+     """
+
+     def __init__(
+         self,
+         patch_size,
+         patch_height,
+         patch_width,
+         fusion_hidden_dim,
+         head_hidden_dim,
+         head_in_index,
+         data_format=None,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.patch_size = int(patch_size)
+         self.patch_height = int(patch_height)
+         self.patch_width = int(patch_width)
+         self.fusion_hidden_dim = int(fusion_hidden_dim)
+         self.head_hidden_dim = int(head_hidden_dim)
+         self.head_in_index = int(head_in_index)
+         self.data_format = standardize_data_format(data_format)
+
+         # Calculate the interpolate size.
+         self.interpolate_size = (
+             int(self.patch_height * self.patch_size),
+             int(self.patch_width * self.patch_size),
+         )
+
+         self.padding1 = layers.ZeroPadding2D(
+             padding=(1, 1),
+             data_format=self.data_format,
+             dtype=self.dtype_policy,
+             name="padding1",
+         )
+         self.conv1 = layers.Conv2D(
+             filters=self.fusion_hidden_dim // 2,
+             kernel_size=3,
+             data_format=self.data_format,
+             use_bias=True,
+             dtype=self.dtype_policy,
+             name="conv1",
+         )
+         self.padding2 = layers.ZeroPadding2D(
+             padding=(1, 1),
+             data_format=self.data_format,
+             dtype=self.dtype_policy,
+             name="padding2",
+         )
+         self.conv2 = layers.Conv2D(
+             filters=self.head_hidden_dim,
+             kernel_size=3,
+             data_format=self.data_format,
+             use_bias=True,
+             dtype=self.dtype_policy,
+             name="conv2",
+         )
+         self.activation1 = layers.ReLU(
+             dtype=self.dtype_policy, name="activation1"
+         )
+         self.conv3 = layers.Conv2D(
+             filters=1,
+             kernel_size=1,
+             data_format=self.data_format,
+             use_bias=True,
+             dtype=self.dtype_policy,
+             name="conv3",
+         )
+
+     def build(self, inputs_shape):
+         inputs_shape = inputs_shape[self.head_in_index]
+         self.padding1.build(inputs_shape)
+         inputs_shape = self.padding1.compute_output_shape(inputs_shape)
+         self.conv1.build(inputs_shape)
+         inputs_shape = self.conv1.compute_output_shape(inputs_shape)
+         inputs_shape = list(inputs_shape)
+         if self.data_format == "channels_last":
+             inputs_shape[1] = self.interpolate_size[0]
+             inputs_shape[2] = self.interpolate_size[1]
+         else:
+             inputs_shape[2] = self.interpolate_size[0]
+             inputs_shape[3] = self.interpolate_size[1]
+         self.padding2.build(inputs_shape)
+         inputs_shape = self.padding2.compute_output_shape(inputs_shape)
+         self.conv2.build(inputs_shape)
+         inputs_shape = self.conv2.compute_output_shape(inputs_shape)
+         self.activation1.build(inputs_shape)
+         self.conv3.build(inputs_shape)
+         inputs_shape = self.conv3.compute_output_shape(inputs_shape)
+
+     def call(self, inputs, training=None):
+         x = inputs[self.head_in_index]
+         x = self.padding1(x, training=training)
+         x = self.conv1(x, training=training)
+         x = interpolate(
+             x, size=self.interpolate_size, data_format=self.data_format
+         )
+         x = self.padding2(x, training=training)
+         x = self.conv2(x, training=training)
+         x = self.activation1(x, training=training)
+         return self.conv3(x, training=training)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "patch_size": self.patch_size,
+                 "patch_height": self.patch_height,
+                 "patch_width": self.patch_width,
+                 "fusion_hidden_dim": self.fusion_hidden_dim,
+                 "head_hidden_dim": self.head_hidden_dim,
+                 "head_in_index": self.head_in_index,
+             }
+         )
+         return config
+
+     def compute_output_shape(self, input_shape):
+         input_shape = input_shape[self.head_in_index]
+         if self.data_format == "channels_last":
+             output_shape = [
+                 input_shape[0],
+                 int(self.patch_height * self.patch_size),
+                 int(self.patch_width * self.patch_size),
+                 1,
+             ]
+         else:
+             output_shape = [
+                 input_shape[0],
+                 1,
+                 int(self.patch_height * self.patch_size),
+                 int(self.patch_width * self.patch_size),
+             ]
+         return output_shape
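Finally, a sketch of the head applied to fused maps shaped like those from the neck sketch above; `head_in_index=-1` selects the finest map, and the output-shape comment assumes "channels_last":

import numpy as np

from keras_hub.src.models.depth_anything.depth_anything_layers import (
    DepthAnythingDepthEstimationHead,
)

fused = [np.zeros((2, s, s, 64), dtype="float32") for s in (14, 28, 56, 112)]
head = DepthAnythingDepthEstimationHead(
    patch_size=14,
    patch_height=14,
    patch_width=14,
    fusion_hidden_dim=64,
    head_hidden_dim=32,
    head_in_index=-1,
)
print(head(fused).shape)  # (2, 196, 196, 1): one depth value per input pixel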