nmn-0.1.3-py3-none-any.whl → nmn-0.1.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nmn/nnx/nmn.py CHANGED
@@ -4,26 +4,18 @@ import typing as tp
 
 import jax
 import jax.numpy as jnp
-import numpy as np
 from jax import lax
-import opt_einsum
 
-from flax.core.frozen_dict import FrozenDict
 from flax import nnx
-from flax.nnx import rnglib, variablelib
-from flax.nnx.module import Module, first_from
+from flax.nnx import rnglib
+from flax.nnx.module import Module
 from flax.nnx.nn import dtypes, initializers
 from flax.typing import (
   Dtype,
-  Shape,
   Initializer,
   PrecisionLike,
   DotGeneralT,
-  ConvGeneralDilatedT,
-  PaddingLike,
-  LaxPadding,
   PromoteDtypeFn,
-  EinsumT,
 )
 
 Array = jax.Array
@@ -60,21 +52,26 @@ class YatNMN(Module):
     in_features: the number of input features.
     out_features: the number of output features.
     use_bias: whether to add a bias to the output (default: True).
+    use_alpha: whether to use alpha scaling (default: True).
+    use_dropconnect: whether to use DropConnect (default: False).
     dtype: the dtype of the computation (default: infer from input and params).
     param_dtype: the dtype passed to parameter initializers (default: float32).
     precision: numerical precision of the computation see ``jax.lax.Precision``
       for details.
     kernel_init: initializer function for the weight matrix.
     bias_init: initializer function for the bias.
+    alpha_init: initializer function for the alpha.
     dot_general: dot product function.
     promote_dtype: function to promote the dtype of the arrays to the desired
       dtype. The function should accept a tuple of ``(inputs, kernel, bias)``
      and a ``dtype`` keyword argument, and return a tuple of arrays with the
      promoted dtype.
+    epsilon: A small float added to the denominator to prevent division by zero.
+    drop_rate: dropout rate for DropConnect (default: 0.0).
     rngs: rng key.
   """
 
-  __data__ = ('kernel', 'bias')
+  __data__ = ('kernel', 'bias', 'alpha', 'dropconnect_key')
 
   def __init__(
     self,
@@ -83,6 +80,7 @@ class YatNMN(Module):
     *,
     use_bias: bool = True,
     use_alpha: bool = True,
+    use_dropconnect: bool = False,
     dtype: tp.Optional[Dtype] = None,
     param_dtype: Dtype = jnp.float32,
     precision: PrecisionLike = None,
@@ -91,8 +89,9 @@ class YatNMN(Module):
     alpha_init: Initializer = default_alpha_init,
     dot_general: DotGeneralT = lax.dot_general,
     promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
-    rngs: rnglib.Rngs,
     epsilon: float = 1e-5,
+    drop_rate: float = 0.0,
+    rngs: rnglib.Rngs,
   ):
 
     kernel_key = rngs.params()
@@ -109,7 +108,7 @@ class YatNMN(Module):
     self.alpha: nnx.Param[jax.Array] | None
     if use_alpha:
       alpha_key = rngs.params()
-      self.alpha = nnx.Param(alpha_init(bias_key, (1,), param_dtype))
+      self.alpha = nnx.Param(alpha_init(alpha_key, (1,), param_dtype))
     else:
       self.alpha = None
 
@@ -117,6 +116,7 @@ class YatNMN(Module):
     self.out_features = out_features
     self.use_bias = use_bias
     self.use_alpha = use_alpha
+    self.use_dropconnect = use_dropconnect
     self.dtype = dtype
     self.param_dtype = param_dtype
     self.precision = precision
@@ -125,12 +125,19 @@ class YatNMN(Module):
     self.dot_general = dot_general
     self.promote_dtype = promote_dtype
     self.epsilon = epsilon
+    self.drop_rate = drop_rate
+
+    if use_dropconnect:
+      self.dropconnect_key = rngs.params()
+    else:
+      self.dropconnect_key = None
 
-  def __call__(self, inputs: Array) -> Array:
+  def __call__(self, inputs: Array, *, deterministic: bool = False) -> Array:
     """Applies a linear transformation to the inputs along the last dimension.
 
     Args:
       inputs: The nd-array to be transformed.
+      deterministic: If true, DropConnect is not applied (e.g., during inference).
 
     Returns:
       The transformed input.
@@ -139,6 +146,11 @@ class YatNMN(Module):
     bias = self.bias.value if self.bias is not None else None
     alpha = self.alpha.value if self.alpha is not None else None
 
+    if self.use_dropconnect and not deterministic and self.drop_rate > 0.0:
+      keep_prob = 1.0 - self.drop_rate
+      mask = jax.random.bernoulli(self.dropconnect_key, p=keep_prob, shape=kernel.shape)
+      kernel = (kernel * mask) / keep_prob
+
     inputs, kernel, bias, alpha = self.promote_dtype(
       (inputs, kernel, bias, alpha), dtype=self.dtype
     )
@@ -166,5 +178,4 @@ class YatNMN(Module):
       scale = (jnp.sqrt(self.out_features) / jnp.log(1 + self.out_features)) ** alpha
       y = y * scale
 
-
     return y
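
For context, a minimal usage sketch of the DropConnect-related options added in 0.1.5 follows. It is based only on the signatures visible in this diff; the layer sizes, input shape, and RNG seed are illustrative assumptions, not values taken from the package.

# Minimal sketch based on the 0.1.5 signatures shown in the diff above.
# Layer sizes, input shape, and seed are arbitrary illustrative choices.
import jax.numpy as jnp
from flax import nnx
from nmn.nnx.nmn import YatNMN

layer = YatNMN(
    8,                     # in_features (assumed value)
    16,                    # out_features (assumed value)
    use_dropconnect=True,  # new in 0.1.5: apply DropConnect to the kernel
    drop_rate=0.1,         # new in 0.1.5: probability of dropping a weight
    rngs=nnx.Rngs(0),
)

x = jnp.ones((4, 8))
y_train = layer(x)                     # DropConnect active (deterministic defaults to False)
y_eval = layer(x, deterministic=True)  # DropConnect disabled, e.g. at inference time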