brainstate 0.1.0__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/augment/_autograd.py +9 -6
  4. brainstate/event/__init__.py +4 -2
  5. brainstate/event/_csr.py +26 -18
  6. brainstate/event/_csr_benchmark.py +14 -0
  7. brainstate/event/_fixed_probability.py +589 -152
  8. brainstate/event/_fixed_probability_benchmark.py +128 -0
  9. brainstate/event/_fixed_probability_test.py +13 -10
  10. brainstate/event/_linear.py +267 -127
  11. brainstate/event/_linear_benckmark.py +82 -0
  12. brainstate/event/_linear_test.py +8 -3
  13. brainstate/event/_xla_custom_op.py +312 -0
  14. brainstate/event/_xla_custom_op_test.py +55 -0
  15. brainstate/nn/_dyn_impl/_dynamics_synapse.py +6 -11
  16. brainstate/nn/_dyn_impl/_rate_rnns.py +1 -1
  17. brainstate/nn/_dynamics/_projection_base.py +1 -1
  18. brainstate/nn/_exp_euler.py +1 -1
  19. brainstate/nn/_interaction/__init__.py +13 -4
  20. brainstate/nn/_interaction/{_connections.py → _conv.py} +0 -227
  21. brainstate/nn/_interaction/{_connections_test.py → _conv_test.py} +0 -15
  22. brainstate/nn/_interaction/_linear.py +582 -0
  23. brainstate/nn/_interaction/_linear_test.py +42 -0
  24. brainstate/optim/_lr_scheduler.py +1 -1
  25. brainstate/optim/_optax_optimizer.py +18 -0
  26. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +1 -1
  27. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/RECORD +30 -21
  28. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  29. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  30. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
benchmark/COBA_2005.py ADDED
@@ -0,0 +1,125 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ #
17
+ # Implementation of the paper:
18
+ #
19
+ # - Brette, R., Rudolph, M., Carnevale, T., Hines, M., Beeman, D., Bower, J. M., et al. (2007),
20
+ # Simulation of networks of spiking neurons: a review of tools and strategies., J. Comput. Neurosci., 23, 3, 349–98
21
+ #
22
+ # which is based on the balanced network proposed by:
23
+ #
24
+ # - Vogels, T. P. and Abbott, L. F. (2005), Signal propagation and logic gating in networks of integrate-and-fire neurons., J. Neurosci., 25, 46, 10786–95
25
+ #
26
+ import os
27
+ import sys
28
+
29
+ sys.path.append('../')
30
+ os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.99'
31
+ os.environ['JAX_TRACEBACK_FILTERING'] = 'off'
32
+
33
+
34
+ import jax
35
+ import brainunit as u
36
+ import time
37
+ import brainstate as bst
38
+
39
+
40
+ class EINet(bst.nn.DynamicsGroup):
41
+ def __init__(self, scale):
42
+ super().__init__()
43
+ self.n_exc = int(3200 * scale)
44
+ self.n_inh = int(800 * scale)
45
+ self.num = self.n_exc + self.n_inh
46
+ self.N = bst.nn.LIFRef(self.num, V_rest=-60. * u.mV, V_th=-50. * u.mV, V_reset=-60. * u.mV,
47
+ tau=20. * u.ms, tau_ref=5. * u.ms,
48
+ V_initializer=bst.init.Normal(-55., 2., unit=u.mV))
49
+ self.E = bst.nn.AlignPostProj(
50
+ comm=bst.event.FixedProb(self.n_exc, self.num, prob=80 / self.num, weight=0.6 * u.mS),
51
+ syn=bst.nn.Expon.desc(self.num, tau=5. * u.ms),
52
+ out=bst.nn.COBA.desc(E=0. * u.mV),
53
+ post=self.N
54
+ )
55
+ self.I = bst.nn.AlignPostProj(
56
+ comm=bst.event.FixedProb(self.n_inh, self.num, prob=80 / self.num, weight=6.7 * u.mS),
57
+ syn=bst.nn.Expon.desc(self.num, tau=10. * u.ms),
58
+ out=bst.nn.COBA.desc(E=-80. * u.mV),
59
+ post=self.N
60
+ )
61
+
62
+ def init_state(self, *args, **kwargs):
63
+ self.rate = bst.ShortTermState(u.math.zeros(self.num))
64
+
65
+ def update(self, t, inp):
66
+ with bst.environ.context(t=t):
67
+ spk = self.N.get_spike() != 0.
68
+ self.E(spk[:self.n_exc])
69
+ self.I(spk[self.n_exc:])
70
+ self.N(inp)
71
+ self.rate.value += self.N.get_spike()
72
+
73
+
74
+ @bst.compile.jit(static_argnums=0)
75
+ def run(scale: float):
76
+ # network
77
+ net = EINet(scale)
78
+ bst.nn.init_all_states(net)
79
+
80
+ duration = 1e4 * u.ms
81
+ # simulation
82
+ with bst.environ.context(dt=0.1 * u.ms):
83
+ times = u.math.arange(0. * u.ms, duration, bst.environ.get_dt())
84
+ bst.compile.for_loop(lambda t: net.update(t, 20. * u.mA), times)
85
+
86
+ return net.num, net.rate.value.sum() / net.num / duration.to_decimal(u.second)
87
+
88
+
89
+ for s in [1, 2, 4, 6, 8, 10, 20, 40, 60, 80, 100]:
90
+ jax.block_until_ready(run(s))
91
+
92
+ t0 = time.time()
93
+ n, rate = jax.block_until_ready(run(s))
94
+ t1 = time.time()
95
+ print(f'scale={s}, size={n}, time = {t1 - t0} s, firing rate = {rate} Hz')
96
+
97
+
98
+ # A6000 NVIDIA GPU
99
+
100
+ # scale=1, size=4000, time = 2.659956455230713 s, firing rate = 50.62445068359375 Hz
101
+ # scale=2, size=8000, time = 2.7318649291992188 s, firing rate = 50.613040924072266 Hz
102
+ # scale=4, size=16000, time = 2.807222604751587 s, firing rate = 50.60573959350586 Hz
103
+ # scale=6, size=24000, time = 3.026782512664795 s, firing rate = 50.60918045043945 Hz
104
+ # scale=8, size=32000, time = 3.1258811950683594 s, firing rate = 50.607574462890625 Hz
105
+ # scale=10, size=40000, time = 3.172346353530884 s, firing rate = 50.60942840576172 Hz
106
+ # scale=20, size=80000, time = 3.751189947128296 s, firing rate = 50.612369537353516 Hz
107
+ # scale=40, size=160000, time = 5.0217814445495605 s, firing rate = 50.617958068847656 Hz
108
+ # scale=60, size=240000, time = 7.002646207809448 s, firing rate = 50.61948776245117 Hz
109
+ # scale=80, size=320000, time = 9.384576320648193 s, firing rate = 50.618499755859375 Hz
110
+ # scale=100, size=400000, time = 11.69654369354248 s, firing rate = 50.61605453491211 Hz
111
+
112
+
113
+ # AMD Ryzen 7 7840HS
114
+
115
+ # scale=1, size=4000, time = 4.436027526855469 s, firing rate = 50.6119270324707 Hz
116
+ # scale=2, size=8000, time = 8.349745273590088 s, firing rate = 50.612266540527344 Hz
117
+ # scale=4, size=16000, time = 16.39163303375244 s, firing rate = 50.61349105834961 Hz
118
+ # scale=6, size=24000, time = 15.725558042526245 s, firing rate = 50.6125602722168 Hz
119
+ # scale=8, size=32000, time = 21.31995177268982 s, firing rate = 50.61244583129883 Hz
120
+ # scale=10, size=40000, time = 27.811061143875122 s, firing rate = 50.61423873901367 Hz
121
+ # scale=20, size=80000, time = 45.54235219955444 s, firing rate = 50.61320877075195 Hz
122
+ # scale=40, size=160000, time = 82.22228026390076 s, firing rate = 50.61309814453125 Hz
123
+ # scale=60, size=240000, time = 125.44037556648254 s, firing rate = 50.613094329833984 Hz
124
+ # scale=80, size=320000, time = 171.20458459854126 s, firing rate = 50.613365173339844 Hz
125
+ # scale=100, size=400000, time = 215.4547393321991 s, firing rate = 50.6129150390625 Hz
benchmark/CUBA_2005.py ADDED
@@ -0,0 +1,149 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ #
17
+ # Implementation of the paper:
18
+ #
19
+ # - Brette, R., Rudolph, M., Carnevale, T., Hines, M., Beeman, D., Bower, J. M., et al. (2007),
20
+ # Simulation of networks of spiking neurons: a review of tools and strategies., J. Comput. Neurosci., 23, 3, 349–98
21
+ #
22
+ # which is based on the balanced network proposed by:
23
+ #
24
+ # - Vogels, T. P. and Abbott, L. F. (2005), Signal propagation and logic gating in networks of integrate-and-fire neurons., J. Neurosci., 25, 46, 10786–95
25
+ #
26
+
27
+ import os
28
+ import sys
29
+
30
+ sys.path.append('../')
31
+ os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.99'
32
+ os.environ['JAX_TRACEBACK_FILTERING'] = 'off'
33
+
34
+
35
+ import jax
36
+ import time
37
+
38
+ import brainunit as u
39
+
40
+ import brainstate as bst
41
+
42
+
43
+
44
+ class FixedProb(bst.nn.Module):
45
+ def __init__(self, n_pre, n_post, prob, weight):
46
+ super().__init__()
47
+ self.prob = prob
48
+ self.weight = weight
49
+ self.n_pre = n_pre
50
+ self.n_post = n_post
51
+
52
+ self.mask = bst.random.rand(n_pre, n_post) < prob
53
+
54
+ def update(self, x):
55
+ return (x @ self.mask) * self.weight
56
+
57
+
58
+ class EINet(bst.nn.DynamicsGroup):
59
+ def __init__(self, scale=1.0):
60
+ super().__init__()
61
+ self.n_exc = int(3200 * scale)
62
+ self.n_inh = int(800 * scale)
63
+ self.num = self.n_exc + self.n_inh
64
+ self.N = bst.nn.LIFRef(
65
+ self.num, V_rest=-49. * u.mV, V_th=-50. * u.mV, V_reset=-60. * u.mV,
66
+ tau=20. * u.ms, tau_ref=5. * u.ms,
67
+ V_initializer=bst.init.Normal(-55., 2., unit=u.mV)
68
+ )
69
+ self.E = bst.nn.AlignPostProj(
70
+ comm=bst.event.FixedProb(self.n_exc, self.num, prob=80 / self.num, weight=1.62 * u.mS),
71
+ # comm=FixedProb(self.n_exc, self.num, prob=80 / self.num, weight=1.62 * u.mS),
72
+ syn=bst.nn.Expon.desc(self.num, tau=5. * u.ms),
73
+ out=bst.nn.CUBA.desc(scale=u.volt),
74
+ post=self.N
75
+ )
76
+ self.I = bst.nn.AlignPostProj(
77
+ comm=bst.event.FixedProb(self.n_inh, self.num, prob=80 / self.num, weight=-9.0 * u.mS),
78
+ # comm=FixedProb(self.n_inh, self.num, prob=80 / self.num, weight=-9.0 * u.mS),
79
+ syn=bst.nn.Expon.desc(self.num, tau=10. * u.ms),
80
+ out=bst.nn.CUBA.desc(scale=u.volt),
81
+ post=self.N
82
+ )
83
+
84
+ def init_state(self, *args, **kwargs):
85
+ self.rate = bst.ShortTermState(u.math.zeros(self.num))
86
+
87
+ def update(self, t, inp):
88
+ with bst.environ.context(t=t):
89
+ spk = self.N.get_spike()
90
+ self.E(spk[:self.n_exc])
91
+ self.I(spk[self.n_exc:])
92
+ self.N(inp)
93
+ self.rate.value += self.N.get_spike()
94
+
95
+
96
+ @bst.compile.jit(static_argnums=0)
97
+ def run(scale: float):
98
+ # network
99
+ net = EINet(scale)
100
+ bst.nn.init_all_states(net)
101
+
102
+ duration = 1e4 * u.ms
103
+ # simulation
104
+ with bst.environ.context(dt=0.1 * u.ms):
105
+ times = u.math.arange(0. * u.ms, duration, bst.environ.get_dt())
106
+ bst.compile.for_loop(lambda t: net.update(t, 20. * u.mA), times,
107
+ # pbar=bst.compile.ProgressBar(100)
108
+ )
109
+
110
+ return net.num, net.rate.value.sum() / net.num / duration.to_decimal(u.second)
111
+
112
+
113
+ for s in [1, 2, 4, 6, 8, 10, 20, 40, 60, 80, 100]:
114
+ jax.block_until_ready(run(s))
115
+
116
+ t0 = time.time()
117
+ n, rate = jax.block_until_ready(run(s))
118
+ t1 = time.time()
119
+ print(f'scale={s}, size={n}, time = {t1 - t0} s, firing rate = {rate} Hz')
120
+
121
+
122
+ # A6000 NVIDIA GPU
123
+
124
+ # scale=1, size=4000, time = 2.6354849338531494 s, firing rate = 24.982027053833008 Hz
125
+ # scale=2, size=8000, time = 2.6781561374664307 s, firing rate = 23.719463348388672 Hz
126
+ # scale=4, size=16000, time = 2.7448785305023193 s, firing rate = 24.592931747436523 Hz
127
+ # scale=6, size=24000, time = 2.8237478733062744 s, firing rate = 24.159996032714844 Hz
128
+ # scale=8, size=32000, time = 2.9344418048858643 s, firing rate = 24.956790924072266 Hz
129
+ # scale=10, size=40000, time = 3.042517900466919 s, firing rate = 23.644424438476562 Hz
130
+ # scale=20, size=80000, time = 3.6727631092071533 s, firing rate = 24.226743698120117 Hz
131
+ # scale=40, size=160000, time = 4.857396602630615 s, firing rate = 24.329742431640625 Hz
132
+ # scale=60, size=240000, time = 6.812030792236328 s, firing rate = 24.370006561279297 Hz
133
+ # scale=80, size=320000, time = 9.227966547012329 s, firing rate = 24.41067886352539 Hz
134
+ # scale=100, size=400000, time = 11.405697584152222 s, firing rate = 24.32524871826172 Hz
135
+
136
+
137
+ # AMD Ryzen 7 7840HS
138
+
139
+ # scale=1, size=4000, time = 1.1661601066589355 s, firing rate = 22.438201904296875 Hz
140
+ # scale=2, size=8000, time = 3.3255884647369385 s, firing rate = 23.868364334106445 Hz
141
+ # scale=4, size=16000, time = 6.950139999389648 s, firing rate = 24.21693229675293 Hz
142
+ # scale=6, size=24000, time = 10.011993169784546 s, firing rate = 24.240270614624023 Hz
143
+ # scale=8, size=32000, time = 13.027734518051147 s, firing rate = 24.753198623657227 Hz
144
+ # scale=10, size=40000, time = 16.449942350387573 s, firing rate = 24.7176570892334 Hz
145
+ # scale=20, size=80000, time = 30.754598140716553 s, firing rate = 24.119956970214844 Hz
146
+ # scale=40, size=160000, time = 63.6387836933136 s, firing rate = 24.72784996032715 Hz
147
+ # scale=60, size=240000, time = 78.58532166481018 s, firing rate = 24.402742385864258 Hz
148
+ # scale=80, size=320000, time = 102.4250214099884 s, firing rate = 24.59092140197754 Hz
149
+ # scale=100, size=400000, time = 145.35173273086548 s, firing rate = 24.33751106262207 Hz
brainstate/augment/_autograd.py CHANGED
@@ -45,7 +45,7 @@ from brainstate.typing import PyTree, Missing
45
45
  from brainstate.util import PrettyType, PrettyAttr, PrettyRepr
46
46
 
47
47
  __all__ = [
48
- 'vector_grad', 'grad', 'jacrev', 'jacfwd', 'jacobian', 'hessian',
48
+ 'GradientTransform', 'vector_grad', 'grad', 'jacrev', 'jacfwd', 'jacobian', 'hessian',
49
49
  ]
50
50
 
51
51
  A = TypeVar('A')
@@ -159,6 +159,9 @@ def _jacfwd(fun, argnums=0, holomorphic=False, has_aux=False, return_value=False
159
159
  return jacfun
160
160
 
161
161
 
162
+ TransformFn = Callable
163
+
164
+
162
165
  class GradientTransform(PrettyRepr):
163
166
  """
164
167
  Automatic Differentiation Transformations for the ``State`` system.
@@ -168,11 +171,11 @@ class GradientTransform(PrettyRepr):
168
171
  def __init__(
169
172
  self,
170
173
  target: Callable,
171
- transform: Callable,
172
- grad_states: Any,
173
- argnums: Optional[Union[int, Sequence[int]]],
174
- return_value: bool,
175
- has_aux: bool,
174
+ transform: TransformFn,
175
+ grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
176
+ argnums: Optional[Union[int, Sequence[int]]] = None,
177
+ return_value: bool = False,
178
+ has_aux: bool = False,
176
179
  transform_params: Optional[Dict[str, Any]] = None,
177
180
  ):
178
181
  # gradient variables
brainstate/event/__init__.py CHANGED
@@ -19,7 +19,9 @@ from ._csr import __all__ as __all_csr
19
19
  from ._fixed_probability import *
20
20
  from ._fixed_probability import __all__ as __all_fixed_probability
21
21
  from ._linear import *
22
+ from ._xla_custom_op import *
23
+ from ._xla_custom_op import __all__ as __all_xla_custom_op
22
24
  from ._linear import __all__ as __all_linear
23
25
 
24
- __all__ = __all_fixed_probability + __all_linear + __all_csr
25
- del __all_fixed_probability, __all_linear, __all_csr
26
+ __all__ = __all_fixed_probability + __all_linear + __all_csr + __all_xla_custom_op
27
+ del __all_fixed_probability, __all_linear, __all_csr, __all_xla_custom_op
brainstate/event/_csr.py CHANGED
@@ -21,12 +21,11 @@ import jax
21
21
  import jax.numpy as jnp
22
22
  import numpy as np
23
23
 
24
- from brainstate._state import ParamState, State
24
+ from brainstate._state import ParamState
25
25
  from brainstate._utils import set_module_as
26
26
  from brainstate.init import param
27
27
  from brainstate.nn._module import Module
28
- from brainstate.typing import ArrayLike
29
- from ._misc import IntScalar
28
+ from brainstate.typing import ArrayLike, Size
30
29
 
31
30
  __all__ = [
32
31
  'CSRLinear',
@@ -39,12 +38,12 @@ class CSRLinear(Module):
39
38
 
40
39
  Parameters
41
40
  ----------
42
- n_pre : int
43
- Number of pre-synaptic neurons.
44
- n_post : int
45
- Number of post-synaptic neurons.
41
+ in_size : Size
42
+ Number of pre-synaptic neurons, i.e., input size.
43
+ out_size : Size
44
+ Number of post-synaptic neurons, i.e., output size.
46
45
  weight : float or callable or jax.Array or brainunit.Quantity
47
- Maximum synaptic conductance.
46
+ Maximum synaptic conductance or a function that returns the maximum synaptic conductance.
48
47
  name : str, optional
49
48
  Name of the module.
50
49
  """
@@ -53,8 +52,8 @@ class CSRLinear(Module):
53
52
 
54
53
  def __init__(
55
54
  self,
56
- n_pre: IntScalar,
57
- n_post: IntScalar,
55
+ in_size: Size,
56
+ out_size: Size,
58
57
  indptr: ArrayLike,
59
58
  indices: ArrayLike,
60
59
  weight: Union[Callable, ArrayLike],
@@ -63,10 +62,11 @@ class CSRLinear(Module):
63
62
  ):
64
63
  super().__init__(name=name)
65
64
 
66
- self.in_size = n_pre
67
- self.out_size = n_post
68
- self.n_pre = n_pre
69
- self.n_post = n_post
65
+ # network size
66
+ self.in_size = in_size
67
+ self.out_size = out_size
68
+ self.n_pre = self.in_size[-1]
69
+ self.n_post = self.out_size[-1]
70
70
 
71
71
  # gradient mode
72
72
  assert grad_mode in ['vjp', 'jvp'], f"Unsupported grad_mode: {grad_mode}"
@@ -77,9 +77,10 @@ class CSRLinear(Module):
77
77
  indices = jnp.asarray(indices)
78
78
  assert indptr.ndim == 1, f"indptr must be 1D. Got: {indptr.ndim}"
79
79
  assert indices.ndim == 1, f"indices must be 1D. Got: {indices.ndim}"
80
- assert indptr.size == n_pre + 1, f"indptr must have size {n_pre + 1}. Got: {indptr.size}"
81
- self.indptr = u.math.asarray(indptr)
82
- self.indices = u.math.asarray(indices)
80
+ assert indptr.size == self.n_pre + 1, f"indptr must have size {self.n_pre + 1}. Got: {indptr.size}"
81
+ with jax.ensure_compile_time_eval():
82
+ self.indptr = u.math.asarray(indptr)
83
+ self.indices = u.math.asarray(indices)
83
84
 
84
85
  # maximum synaptic conductance
85
86
  weight = param(weight, (len(indices),), allow_none=False)
@@ -88,7 +89,9 @@ class CSRLinear(Module):
88
89
  self.weight = ParamState(weight)
89
90
 
90
91
  def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
91
- weight = self.weight.value if isinstance(self.weight, State) else self.weight
92
+ weight = self.weight.value
93
+
94
+ # return zero if no pre-synaptic neurons
92
95
  if len(self.indices) == 0:
93
96
  r = u.math.zeros(spk.shape[:-1] + (self.n_post,),
94
97
  dtype=weight.dtype,
@@ -96,6 +99,8 @@ class CSRLinear(Module):
96
99
  return u.maybe_decimal(r)
97
100
 
98
101
  device_kind = jax.devices()[0].platform # spk.device.device_kind
102
+
103
+ # CPU implementation
99
104
  if device_kind == 'cpu':
100
105
  return cpu_event_csr(
101
106
  u.math.asarray(spk),
@@ -104,8 +109,11 @@ class CSRLinear(Module):
104
109
  u.math.asarray(weight),
105
110
  n_post=self.n_post, grad_mode=self.grad_mode
106
111
  )
112
+
113
+ # GPU/TPU implementation
107
114
  elif device_kind in ['gpu', 'tpu']:
108
115
  raise NotImplementedError()
116
+
109
117
  else:
110
118
  raise ValueError(f"Unsupported device: {device_kind}")
111
119
 
brainstate/event/_csr_benchmark.py ADDED
@@ -0,0 +1,14 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================