lt-tensor 0.0.1.dev0__py3-none-any.whl

@@ -0,0 +1,114 @@
+ __all__ = [
+     "ResidualBlock1D_B",
+     "Downsample1D",
+     "Upsample1D",
+     "DiffusionUNet",
+     "DiffusionUNetT",
+ ]
+
+ from ..._torch_commons import *
+ from ..._basics import Model
+ from ..residual import ResBlock1D, ResBlock1DT
+ from ...misc_utils import log_tensor
+
+
+ class Downsample1D(Model):
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+     ):
+         super().__init__()
+         self.pool = nn.Conv1d(in_channels, out_channels, 4, stride=2, padding=1)
+
+     def forward(self, x):
+         return self.pool(x)
+
+
+ class Upsample1D(Model):
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         activation=nn.ReLU(inplace=True),
+     ):
+         super().__init__()
+         self.up = nn.Sequential(
+             nn.ConvTranspose1d(
+                 in_channels, out_channels, kernel_size=4, stride=2, padding=1
+             ),
+             nn.BatchNorm1d(out_channels),
+             activation,
+         )
+
+     def forward(self, x):
+         return self.up(x)
+
+
+ class DiffusionUNet(Model):
+     def __init__(self, in_channels=1, base_channels=64, out_channels=1, depth=4):
+         super().__init__()
+
+         self.depth = depth
+         self.encoder_blocks = nn.ModuleList()
+         self.downsamples = nn.ModuleList()
+         self.upsamples = nn.ModuleList()
+         self.decoder_blocks = nn.ModuleList()
+         # Keep track of channel sizes per layer for skip connections
+         self.channels = [in_channels]  # starting input channel
+         for i in range(depth):
+             enc_in = self.channels[-1]
+             enc_out = base_channels * (2**i)
+             # Encoder block and downsample
+             self.encoder_blocks.append(ResBlock1D(enc_in, enc_out))
+             self.downsamples.append(
+                 Downsample1D(enc_out, enc_out)
+             )  # halve time, keep channels
+             self.channels.append(enc_out)
+         # Bottleneck
+         bottleneck_ch = self.channels[-1]
+         self.bottleneck = ResBlock1D(bottleneck_ch, bottleneck_ch)
+         # Decoder blocks (reverse channel flow)
+         for i in reversed(range(depth)):
+             skip_ch = self.channels[i + 1]  # from encoder
+             dec_out = self.channels[i]  # match earlier stage's output
+             self.upsamples.append(Upsample1D(skip_ch, skip_ch))
+             self.decoder_blocks.append(ResBlock1D(skip_ch * 2, dec_out))
+         # Final output projection (out_channels)
+         self.final = nn.Conv1d(in_channels, out_channels, kernel_size=1)
+
+     def forward(self, x: Tensor):
+         skips = []
+
+         # Encoder
+         for enc, down in zip(self.encoder_blocks, self.downsamples):
+             # log_tensor(x, "before enc")
+             x = enc(x)
+             skips.append(x)
+             x = down(x)
+
+         # Bottleneck
+         x = self.bottleneck(x)
+
+         # Decoder
+         for up, dec, skip in zip(self.upsamples, self.decoder_blocks, reversed(skips)):
+             x = up(x)
+
+             # Match lengths via trimming or padding
+             if x.shape[-1] > skip.shape[-1]:
+                 x = x[..., : skip.shape[-1]]
+             elif x.shape[-1] < skip.shape[-1]:
+                 diff = skip.shape[-1] - x.shape[-1]
+                 x = F.pad(x, (0, diff))
+
+             x = torch.cat([x, skip], dim=1)  # concat on channels
+             x = dec(x)
+
+         # Final 1x1 conv
+         return self.final(x)
+
+
+
+
+
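Below is a minimal smoke test for `DiffusionUNet`, an editorial sketch rather than part of the package: the channel counts, depth, and input length are illustrative assumptions, and it presumes `torch` and the classes above are importable.

```python
import torch

# Illustrative only: channel counts, depth, and input length are assumptions.
model = DiffusionUNet(in_channels=1, base_channels=64, out_channels=1, depth=4)
x = torch.randn(2, 1, 256)  # [batch, channels, time]; 256 is divisible by 2**depth
y = model(x)
print(y.shape)  # torch.Size([2, 1, 256]): trimming/padding preserves the input length
```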
@@ -0,0 +1,236 @@
+ from .._torch_commons import *
+ from .._basics import Model
+ import math
+ from ..misc_utils import log_tensor
+
+
+ def initialize_weights(model: nn.Module, method: str = "xavier"):
+     """Initialize model weights using the specified method."""
+     for name, param in model.named_parameters():
+         if "weight" in name:
+             if method == "xavier":
+                 nn.init.xavier_uniform_(param)
+             elif method == "kaiming":
+                 nn.init.kaiming_uniform_(param, nonlinearity="relu")
+         elif "bias" in name:
+             nn.init.constant_(param, 0)
+
+
+ def spectral_norm_select(module: Module, enabled: bool):
+     if enabled:
+         return spectral_norm(module)
+     return module
+
+
+ def init_weights(m, mean=0.0, std=0.01):
+     classname = m.__class__.__name__
+     if "Conv" in classname:
+         m.weight.data.normal_(mean, std)
+
+
+ class ResBlock1D(Model):
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: int = 3,
+         dilation: Union[Sequence[int], int] = (1, 3, 5),
+         activation: nn.Module = nn.LeakyReLU(0.1),
+         num_groups: int = 1,
+         batched: bool = True,
+     ):
+         super().__init__()
+         self.conv = nn.ModuleList()
+         if isinstance(dilation, int):
+             dilation = [dilation]
+
+         if batched:
+             layernorm_fn = lambda x: nn.GroupNorm(num_groups=num_groups, num_channels=x)
+         else:
+             layernorm_fn = lambda x: nn.LayerNorm(normalized_shape=x)
+         for i, dil in enumerate(dilation):
+             self.conv.append(
+                 nn.ModuleDict(
+                     dict(
+                         net=nn.Sequential(
+                             self._get_conv_layer(
+                                 in_channels, in_channels, kernel_size, dil
+                             ),
+                             activation,
+                             self._get_conv_layer(
+                                 in_channels, in_channels, kernel_size, 1, True
+                             ),
+                             activation,
+                         ),
+                         l_norm=layernorm_fn(in_channels),
+                     )
+                 )
+             )
+         self.final = nn.Sequential(
+             self._get_conv_layer(in_channels, out_channels, kernel_size, 1, True),
+             activation,
+         )
+         self.conv.apply(init_weights)
+
+     def _get_conv_layer(
+         self,
+         channels_in: int,
+         channels_out: int,
+         kernel_size: int,
+         dilation: int,
+         pad_gate: bool = False,
+     ):
+         return weight_norm(
+             nn.Conv1d(
+                 in_channels=channels_in,
+                 out_channels=channels_out,
+                 kernel_size=kernel_size,
+                 stride=1,
+                 dilation=dilation,
+                 padding=(
+                     int((kernel_size * dilation - dilation) / 2)
+                     if not pad_gate
+                     else int((kernel_size * 1 - 1) / 2)
+                 ),
+             )
+         )
+
+     def forward(self, x: Tensor):
+         for i, layer in enumerate(self.conv):
+             xt = layer["net"](x)
+             x = xt + x
+             x = layer["l_norm"](x)
+         return self.final(x)
+
+     def remove_weight_norm(self):
+         for module in self.modules():
+             try:
+                 remove_weight_norm(module)
+             except ValueError:
+                 pass  # Not normed, skip
+
+
+ class ResBlock2D(Model):
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         downsample=False,
+         spec_norm: bool = False,
+     ):
+         super().__init__()
+         stride = 2 if downsample else 1
+
+         self.block = nn.Sequential(
+             spectral_norm_select(
+                 nn.Conv2d(in_channels, out_channels, 3, stride, 1), spec_norm
+             ),
+             nn.LeakyReLU(0.2),
+             spectral_norm_select(
+                 nn.Conv2d(out_channels, out_channels, 3, 1, 1), spec_norm
+             ),
+         )
+
+         self.skip = nn.Identity()
+         if downsample or in_channels != out_channels:
+             self.skip = spectral_norm_select(
+                 nn.Conv2d(in_channels, out_channels, 1, stride), spec_norm
+             )
+         # Precomputed once so it is not recomputed on every forward pass
+         self.sqrt_2 = math.sqrt(2)
+
+     def forward(self, x):
+         return (self.block(x) + self.skip(x)) / self.sqrt_2
+
+
+ class ResBlock1DT(Model):
+     """For time-based residual layers."""
+
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: int = 3,
+         dilation: Union[Sequence[int], int] = (1, 3, 5),
+         activation: nn.Module = nn.LeakyReLU(0.1),
+         num_groups: int = 1,
+         time_emb_dim: int = 1,
+         batched: bool = True,
+     ):
+         super().__init__()
+         self.conv = nn.ModuleList()
+         if isinstance(dilation, int):
+             dilation = [dilation]
+         # Project to in_channels: the residual stages below operate at
+         # in_channels, so the added embedding must match that width.
+         self.time_proj = nn.Linear(time_emb_dim, in_channels)
+         if batched:
+             layernorm_fn = lambda x: nn.GroupNorm(num_groups=num_groups, num_channels=x)
+         else:
+             layernorm_fn = lambda x: nn.LayerNorm(normalized_shape=x)
+         for i, dil in enumerate(dilation):
+             self.conv.append(
+                 nn.ModuleDict(
+                     dict(
+                         net=nn.Sequential(
+                             self._get_conv_layer(
+                                 in_channels, in_channels, kernel_size, dil
+                             ),
+                             activation,
+                             self._get_conv_layer(
+                                 in_channels, in_channels, kernel_size, 1, True
+                             ),
+                             activation,
+                         ),
+                         l_norm=layernorm_fn(in_channels),
+                     )
+                 )
+             )
+         self.final = nn.Sequential(
+             self._get_conv_layer(in_channels, out_channels, kernel_size, 1, True),
+             activation,
+         )
+         self.conv.apply(init_weights)
+
+     def _get_conv_layer(
+         self,
+         channels_in: int,
+         channels_out: int,
+         kernel_size: int,
+         dilation: int,
+         pad_gate: bool = False,
+     ):
+         return weight_norm(
+             nn.Conv1d(
+                 in_channels=channels_in,
+                 out_channels=channels_out,
+                 kernel_size=kernel_size,
+                 stride=1,
+                 dilation=dilation,
+                 padding=(
+                     int((kernel_size * dilation - dilation) / 2)
+                     if not pad_gate
+                     else int((kernel_size * 1 - 1) / 2)
+                 ),
+             )
+         )
+
+     def forward(self, x: Tensor, t_embed: Optional[Tensor] = None):
+         if t_embed is not None:
+             t_emb = self.time_proj(t_embed).unsqueeze(-1)  # [B, C, 1]
+
+         for i, layer in enumerate(self.conv):
+             if t_embed is not None:
+                 xt = layer["net"](x) + t_emb
+             else:
+                 xt = layer["net"](x)
+             x = xt + x
+             x = layer["l_norm"](x)
+         return self.final(x)
+
+     def remove_weight_norm(self):
+         for module in self.modules():
+             try:
+                 remove_weight_norm(module)
+             except ValueError:
+                 pass  # Not normed, skip
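A quick shape check for the residual blocks above, as an illustrative sketch: the batch size, channel widths, and `time_emb_dim` are assumptions, not package defaults.

```python
import torch

block = ResBlock1D(in_channels=32, out_channels=64)
x = torch.randn(4, 32, 100)  # [B, C, T]; dilated convs keep T unchanged
print(block(x).shape)        # torch.Size([4, 64, 100])

t_block = ResBlock1DT(in_channels=32, out_channels=64, time_emb_dim=128)
t = torch.randn(4, 128)      # one time embedding per batch element
print(t_block(x, t).shape)   # torch.Size([4, 64, 100])
```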
@@ -0,0 +1,6 @@
+ from .models import (
+     TransformerEncoderLayer,
+     TransformerDecoderLayer,
+     TransformerEncoder,
+     TransformerDecoder,
+ )
@@ -0,0 +1,132 @@
+ import math
+ from ..._torch_commons import *
+ from ..._basics import Model
+ from lt_utils.misc_utils import default
+
+ from .positional_encoders import *
+ from ..basic import FeedForward
+
+
+ def init_weights(module):
+     if isinstance(module, nn.Linear):
+         nn.init.xavier_uniform_(module.weight)
+         if module.bias is not None:
+             nn.init.constant_(module.bias, 0)
+     elif isinstance(module, nn.Embedding):
+         nn.init.normal_(module.weight, mean=0.0, std=0.02)
+     elif isinstance(module, nn.LayerNorm):
+         nn.init.constant_(module.bias, 0)
+         nn.init.constant_(module.weight, 1.0)
+
+
+ class TransformerEncoderLayer(nn.Module):
+     def __init__(self, d_model: int, n_heads: int, ff_size: int, dropout: float = 0.1):
+         super().__init__()
+         self.self_attn = nn.MultiheadAttention(
+             d_model, n_heads, dropout=dropout, batch_first=True
+         )
+         self.norm1 = nn.LayerNorm(d_model)
+         self.ff = FeedForward(d_model, ff_size, dropout)
+         self.norm2 = nn.LayerNorm(d_model)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x: Tensor, src_mask: Optional[Tensor] = None) -> Tensor:
+         attn_output, _ = self.self_attn(x, x, x, attn_mask=src_mask)
+         x = self.norm1(x + self.dropout(attn_output))
+         ff_output = self.ff(x)
+         x = self.norm2(x + self.dropout(ff_output))
+         return x
+
+
+ class TransformerDecoderLayer(nn.Module):
+     def __init__(self, d_model: int, n_heads: int, ff_size: int, dropout: float = 0.1):
+         super().__init__()
+         self.self_attn = nn.MultiheadAttention(
+             d_model, n_heads, dropout=dropout, batch_first=True
+         )
+         self.norm1 = nn.LayerNorm(d_model)
+
+         self.cross_attn = nn.MultiheadAttention(
+             d_model, n_heads, dropout=dropout, batch_first=True
+         )
+         self.norm2 = nn.LayerNorm(d_model)
+
+         self.ff = FeedForward(d_model, ff_size, dropout)
+         self.norm3 = nn.LayerNorm(d_model)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(
+         self,
+         x: Tensor,  # Decoder input [B, T, d_model]
+         encoder_out: Tensor,  # Encoder output [B, S, d_model]
+         tgt_mask: Optional[Tensor] = None,
+         memory_mask: Optional[Tensor] = None,
+     ) -> Tensor:
+         # 1. Masked self-attention
+         attn_output, _ = self.self_attn(x, x, x, attn_mask=tgt_mask)
+         x = self.norm1(x + self.dropout(attn_output))
+
+         # 2. Cross-attention
+         cross_output, _ = self.cross_attn(
+             x, encoder_out, encoder_out, attn_mask=memory_mask
+         )
+         x = self.norm2(x + self.dropout(cross_output))
+
+         # 3. Feed-forward
+         ff_output = self.ff(x)
+         x = self.norm3(x + self.dropout(ff_output))
+         return x
+
+
+ class TransformerEncoder(Model):
+     def __init__(
+         self,
+         d_model: int = 64,
+         n_heads: int = 4,
+         ff_size: int = 256,
+         num_layers: int = 2,
+         dropout: float = 0.1,
+     ):
+         super().__init__()
+         self.layers = nn.ModuleList(
+             [
+                 TransformerEncoderLayer(d_model, n_heads, ff_size, dropout)
+                 for _ in range(num_layers)
+             ]
+         )
+
+     def forward(self, x: Tensor, src_mask: Optional[Tensor] = None) -> Tensor:
+         for layer in self.layers:
+             x = layer(x, src_mask)
+         return x
+
+
+ class TransformerDecoder(Model):
+     def __init__(
+         self,
+         d_model: int = 64,
+         n_heads: int = 2,
+         ff_size: int = 256,
+         num_layers: int = 2,
+         dropout: float = 0.1,
+     ):
+         super().__init__()
+
+         self.layers = nn.ModuleList(
+             [
+                 TransformerDecoderLayer(d_model, n_heads, ff_size, dropout)
+                 for _ in range(num_layers)
+             ]
+         )
+
+     def forward(
+         self,
+         x: Tensor,
+         encoder_out: Tensor,
+         tgt_mask: Optional[Tensor] = None,
+         memory_mask: Optional[Tensor] = None,
+     ) -> Tensor:
+         for layer in self.layers:
+             x = layer(x, encoder_out, tgt_mask, memory_mask)
+         return x
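A hedged usage sketch for the encoder/decoder stack above: the dimensions are illustrative, and it assumes `FeedForward` (imported from the package's `..basic` module) resolves at import time. The boolean `attn_mask` convention (True = blocked position) follows `nn.MultiheadAttention`.

```python
import torch

enc = TransformerEncoder(d_model=64, n_heads=4, num_layers=2)
dec = TransformerDecoder(d_model=64, n_heads=2, num_layers=2)

src = torch.randn(2, 10, 64)  # [B, S, d_model]
tgt = torch.randn(2, 7, 64)   # [B, T, d_model]

# Causal mask for the decoder's self-attention (True = blocked position).
causal = torch.triu(torch.ones(7, 7, dtype=torch.bool), diagonal=1)

memory = enc(src)
out = dec(tgt, memory, tgt_mask=causal)
print(out.shape)  # torch.Size([2, 7, 64])
```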
@@ -0,0 +1,95 @@
+ __all__ = ["RotaryEmbedding", "PositionalEncoding"]
+ import math
+ from ..._torch_commons import *
+ from ..._basics import Model
+
+
+ class RotaryEmbedding(Module):
+     def __init__(self, dim: int, base: int = 10000):
+         """
+         Rotary Positional Embedding Module.
+         Args:
+             dim (int): The dimension of the rotary embedding (must be even).
+             base (int): The base frequency scale (default: 10000).
+         """
+         super().__init__()
+         assert dim % 2 == 0, "Rotary dimension must be even"
+         self.dim = dim
+         self.base = base
+
+     def forward(self, x, seq_len=None):
+         """
+         Apply rotary embeddings to the input tensor.
+         Args:
+             x (torch.Tensor): Input tensor of shape [batch, seq_len, dim].
+             seq_len (int, optional): Override for the sequence length.
+         Returns:
+             torch.Tensor: Tensor with rotary embeddings applied.
+         """
+         bsz, seq_len = x.shape[0], seq_len or x.shape[1]
+         device = x.device
+
+         pos = torch.arange(seq_len, dtype=torch.float32, device=device).unsqueeze(1)
+         freqs = torch.pow(
+             self.base, -torch.arange(0, self.dim, 2, device=device).float() / self.dim
+         )
+         angle = pos * freqs  # [seq_len, dim/2]
+
+         sin = torch.sin(angle)
+         cos = torch.cos(angle)
+
+         # Expand and interleave to [seq_len, dim]
+         sin = torch.stack((sin, sin), dim=-1).reshape(seq_len, self.dim)
+         cos = torch.stack((cos, cos), dim=-1).reshape(seq_len, self.dim)
+
+         sin = sin.unsqueeze(0).expand(bsz, -1, -1)  # [batch, seq_len, dim]
+         cos = cos.unsqueeze(0).expand(bsz, -1, -1)
+
+         return self.apply_rotary(x, sin, cos)
+
+     def _apply_rotary(self, x, sin, cos):
+         """This version may still be useful, but for now it is what breaks the text model."""
+         x1, x2 = x.chunk(2, dim=-1)
+         return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
+
+     def apply_rotary(self, x, sin, cos):
+         """x: [batch, seq_len, dim] → assumes dim is even"""
+         b, s, d = x.shape
+         x = x.view(b, s, d // 2, 2)  # [b, s, d//2, 2]
+         sin = sin.view(b, s, d // 2, 2)
+         cos = cos.view(b, s, d // 2, 2)
+
+         # Apply rotation: even, odd = x[..., 0], x[..., 1]
+         x_rotated = torch.stack(
+             [
+                 x[..., 0] * cos[..., 0] - x[..., 1] * sin[..., 0],
+                 x[..., 0] * sin[..., 0] + x[..., 1] * cos[..., 0],
+             ],
+             dim=-1,
+         )
+
+         return x_rotated.view(b, s, d)  # Back to [b, s, d]
+
+
+ class PositionalEncoding(Module):
+     def __init__(self, d_model: int, max_len: int = 8192):
+         super().__init__()
+         # Matrix of [max_len, d_model] holding the positional encoding for each position
+         pe = torch.zeros(max_len, d_model)
+         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(
+             1
+         )  # (max_len, 1)
+         div_term = torch.exp(
+             torch.arange(0, d_model, 2, dtype=torch.float)
+             * (-math.log(10000.0) / d_model)
+         )
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         pe = pe.unsqueeze(0)
+         self.register_buffer("pe", pe, persistent=False)  # Shape: (1, max_len, d_model)
+
+     def forward(self, x: Tensor, seq_len: Optional[int] = None):
+         # x shape: (batch_size, seq_len, d_model)
+         s_sz = seq_len or x.size(1)
+         x = x + self.pe[:, :s_sz]
+         return x
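An illustrative sketch of both positional encoders; the shapes are assumptions, and `dim` must be even for `RotaryEmbedding`.

```python
import torch

rope = RotaryEmbedding(dim=64)
pe = PositionalEncoding(d_model=64)

x = torch.randn(2, 16, 64)  # [batch, seq_len, dim]
print(rope(x).shape)        # torch.Size([2, 16, 64]), rotation preserves shape
print(pe(x).shape)          # torch.Size([2, 16, 64]), additive encoding
```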
@@ -0,0 +1,70 @@
+ from numba import njit, prange
+
+
+ @njit()
+ def maximum_path_each(path, value, t_x, t_y, max_neg_val):
+     index = t_x - 1
+     # Forward pass: calculate max path sums
+     for y in range(t_y):
+         for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
+             v_cur = max_neg_val if x == y else value[x, y - 1]
+             v_prev = (
+                 0.0
+                 if (x == 0 and y == 0)
+                 else (max_neg_val if x == 0 else value[x - 1, y - 1])
+             )
+             value[x, y] = max(v_cur, v_prev) + value[x, y]
+
+     # Backtrack to store the path
+     for y in range(t_y - 1, -1, -1):
+         path[index, y] = 1
+         if index != 0 and (index == y or value[index, y - 1] < value[index - 1, y - 1]):
+             index -= 1
+
+
+ @njit()  # Took almost 10x the time while testing with "parallel=True".
+ def maximum_path(paths, values, t_xs, t_ys, max_neg_val=-1e9):
+     """
+     Example:
+     ```python
+     import torch
+
+     paths = torch.randn((1, 3, 3)).numpy()
+     values = torch.randn((1, 3, 3)).numpy()
+     t_xs = torch.tensor([3]).numpy()
+     t_ys = torch.tensor([3]).numpy()
+
+     # To display values (before) and paths:
+     print("=====================")
+     print("Paths:")
+     print(paths)
+     print("Original Values:")
+     print(values)
+
+     maximum_path(paths, values, t_xs, t_ys)
+
+     print("Updated Values:")
+     print(values)
+     print("=====================")
+     ```
+     Outputs:
+     ```md
+     =====================
+     Paths:
+     [[[ 2.310408   -1.9375949  -0.57884663]
+       [ 1.0308106   1.0793993   0.4461908 ]
+       [ 0.26789713  0.48924422  0.3409592 ]]]
+     Original Values:
+     [[[-0.48256454  0.51348686 -1.8236492 ]
+       [ 0.9949021  -0.6066166   0.18991096]
+       [ 1.2555764  -0.24222293 -0.78757876]]]
+     Updated Values:
+     [[[-0.48256454  0.51348686 -1.8236492 ]
+       [ 0.9949021  -1.0891812   0.18991096]
+       [ 1.2555764  -0.24222293 -1.87676   ]]]
+     =====================
+     ```
+     This may not be the standard approach, but it may work for your project.
+     """
+     batch_size = values.shape[0]
+     for i in prange(batch_size):
+         maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val)
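Hypothetical glue for calling the kernel from torch code, with illustrative shapes: `numpy()` shares memory with the source tensor, so the kernel's in-place writes land in `paths`.

```python
import torch

# Shapes are illustrative assumptions; the kernels expect numpy arrays.
b, t_x, t_y = 2, 4, 6
values = torch.randn(b, t_x, t_y)
paths = torch.zeros(b, t_x, t_y)  # filled with a 0/1 alignment in place
t_xs = torch.full((b,), t_x, dtype=torch.int64)
t_ys = torch.full((b,), t_y, dtype=torch.int64)

maximum_path(paths.numpy(), values.numpy(), t_xs.numpy(), t_ys.numpy())
print(paths[0])  # a monotonic 0/1 alignment path
```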