brainstate 0.1.0.post20241219__py2.py3-none-any.whl → 0.1.0.post20241220__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/compile/_jit.py +20 -14
- {brainstate-0.1.0.post20241219.dist-info → brainstate-0.1.0.post20241220.dist-info}/METADATA +1 -1
- {brainstate-0.1.0.post20241219.dist-info → brainstate-0.1.0.post20241220.dist-info}/RECORD +6 -6
- {brainstate-0.1.0.post20241219.dist-info → brainstate-0.1.0.post20241220.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20241219.dist-info → brainstate-0.1.0.post20241220.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20241219.dist-info → brainstate-0.1.0.post20241220.dist-info}/top_level.txt +0 -0
brainstate/compile/_jit.py
CHANGED
@@ -40,7 +40,7 @@ class JittedFunction(Callable):
|
|
40
40
|
jitted_fun: jax.stages.Wrapped # the jitted function
|
41
41
|
clear_cache: Callable # clear the cache of the jitted function
|
42
42
|
eval_shape: Callable # evaluate the shape of the jitted function
|
43
|
-
|
43
|
+
compile: Callable # lower the jitted function
|
44
44
|
trace: Callable # trace the jitted
|
45
45
|
|
46
46
|
def __call__(self, *args, **kwargs):
|
@@ -104,7 +104,18 @@ def _get_jitted_fun(
|
|
104
104
|
def eval_shape():
|
105
105
|
raise NotImplementedError
|
106
106
|
|
107
|
-
def
|
107
|
+
def trace():
|
108
|
+
"""Trace this function explicitly for the given arguments.
|
109
|
+
|
110
|
+
A traced function is staged out of Python and translated to a jaxpr. It is
|
111
|
+
ready for lowering but not yet lowered.
|
112
|
+
|
113
|
+
Returns:
|
114
|
+
A ``Traced`` instance representing the tracing.
|
115
|
+
"""
|
116
|
+
raise NotImplementedError
|
117
|
+
|
118
|
+
def compile(*args, **params):
|
108
119
|
"""Lower this function explicitly for the given arguments.
|
109
120
|
|
110
121
|
A lowered function is staged out of Python and translated to a
|
@@ -114,18 +125,13 @@ def _get_jitted_fun(
|
|
114
125
|
Returns:
|
115
126
|
A ``Lowered`` instance representing the lowering.
|
116
127
|
"""
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
"""Trace this function explicitly for the given arguments.
|
128
|
+
# compile the function and get the state trace
|
129
|
+
state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
|
130
|
+
read_state_vals = state_trace.get_read_state_values(True)
|
121
131
|
|
122
|
-
|
123
|
-
|
132
|
+
# call the jitted function
|
133
|
+
return jit_fun.lower(state_trace.get_state_values(), *args, **params).compile()
|
124
134
|
|
125
|
-
Returns:
|
126
|
-
A ``Traced`` instance representing the tracing.
|
127
|
-
"""
|
128
|
-
raise NotImplementedError
|
129
135
|
|
130
136
|
jitted_fun: JittedFunction
|
131
137
|
|
@@ -144,8 +150,8 @@ def _get_jitted_fun(
|
|
144
150
|
# evaluate the shape of the jitted function
|
145
151
|
jitted_fun.eval_shape = eval_shape
|
146
152
|
|
147
|
-
#
|
148
|
-
jitted_fun.
|
153
|
+
# compile the jitted function
|
154
|
+
jitted_fun.compile = compile
|
149
155
|
|
150
156
|
# trace the jitted
|
151
157
|
jitted_fun.trace = trace
|
{brainstate-0.1.0.post20241219.dist-info → brainstate-0.1.0.post20241220.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: brainstate
|
3
|
-
Version: 0.1.0.
|
3
|
+
Version: 0.1.0.post20241220
|
4
4
|
Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
|
5
5
|
Home-page: https://github.com/chaobrain/brainstate
|
6
6
|
Author: BrainState Developers
|
@@ -24,7 +24,7 @@ brainstate/compile/_conditions.py,sha256=gApsHKGQrf1QBjoKXDVL7VsoeJ2zFtSc-hFz9nb
|
|
24
24
|
brainstate/compile/_conditions_test.py,sha256=s9LF6h9LvigvgxUIugTqvgCHBIU8TXS1Ar1OlIxXfrw,8389
|
25
25
|
brainstate/compile/_error_if.py,sha256=TFvhqITKkRO9m30GdlUP4eEjJvLWQUhjkujXO9zvrWs,2689
|
26
26
|
brainstate/compile/_error_if_test.py,sha256=SJmAfosVoGd4vhfFtb1IvjeFVW914bfTccCg6DoLWYk,1992
|
27
|
-
brainstate/compile/_jit.py,sha256=
|
27
|
+
brainstate/compile/_jit.py,sha256=3WBXNTALWPYC9rQH0TPH6w4bjG0BpnZt3RAzUQF5kkc,14045
|
28
28
|
brainstate/compile/_jit_test.py,sha256=zD7kck9SQJGmUDolh9P4luKwQ21fBGje1Z4STTEXIuA,4135
|
29
29
|
brainstate/compile/_loop_collect_return.py,sha256=_iOVPytbctgyaIOxQZH3A2ZbsSoT7VXnFk6Q6R8-gvA,23360
|
30
30
|
brainstate/compile/_loop_collect_return_test.py,sha256=bA-_11E8A_0jR5umEO3e409y7bb5QYDTgSL-SBaX7kQ,1802
|
@@ -137,8 +137,8 @@ brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7
|
|
137
137
|
brainstate/util/_struct.py,sha256=0exv0oOiSt1hmx20Y4J2-pCGtCTx13WcAlEYSBkyung,17640
|
138
138
|
brainstate/util/_tracers.py,sha256=0r5T4nhxMzI79NtqroqitsdMT4YfpgV5RdYJLS5uJ0w,2285
|
139
139
|
brainstate/util/_visualization.py,sha256=n4ZVz10z7VBqA0cKO6vyHwEMprWJgPeEqtITzDMai2Y,1519
|
140
|
-
brainstate-0.1.0.
|
141
|
-
brainstate-0.1.0.
|
142
|
-
brainstate-0.1.0.
|
143
|
-
brainstate-0.1.0.
|
144
|
-
brainstate-0.1.0.
|
140
|
+
brainstate-0.1.0.post20241220.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
|
141
|
+
brainstate-0.1.0.post20241220.dist-info/METADATA,sha256=lfkUbD1vYx4bikkUolrL_pCp4zryfvGExXyWP_QH5tM,3533
|
142
|
+
brainstate-0.1.0.post20241220.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
|
143
|
+
brainstate-0.1.0.post20241220.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
144
|
+
brainstate-0.1.0.post20241220.dist-info/RECORD,,
|
File without changes
|
File without changes
|
{brainstate-0.1.0.post20241219.dist-info → brainstate-0.1.0.post20241220.dist-info}/top_level.txt
RENAMED
File without changes
|