brainstate 0.1.0.post20241219__py2.py3-none-any.whl → 0.1.0.post20241220__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -40,7 +40,7 @@ class JittedFunction(Callable):
40
40
  jitted_fun: jax.stages.Wrapped # the jitted function
41
41
  clear_cache: Callable # clear the cache of the jitted function
42
42
  eval_shape: Callable # evaluate the shape of the jitted function
43
- lower: Callable # lower the jitted function
43
+ compile: Callable # lower the jitted function
44
44
  trace: Callable # trace the jitted
45
45
 
46
46
  def __call__(self, *args, **kwargs):
@@ -104,7 +104,18 @@ def _get_jitted_fun(
104
104
  def eval_shape():
105
105
  raise NotImplementedError
106
106
 
107
- def lower():
107
+ def trace():
108
+ """Trace this function explicitly for the given arguments.
109
+
110
+ A traced function is staged out of Python and translated to a jaxpr. It is
111
+ ready for lowering but not yet lowered.
112
+
113
+ Returns:
114
+ A ``Traced`` instance representing the tracing.
115
+ """
116
+ raise NotImplementedError
117
+
118
+ def compile(*args, **params):
108
119
  """Lower this function explicitly for the given arguments.
109
120
 
110
121
  A lowered function is staged out of Python and translated to a
@@ -114,18 +125,13 @@ def _get_jitted_fun(
114
125
  Returns:
115
126
  A ``Lowered`` instance representing the lowering.
116
127
  """
117
- raise NotImplementedError
118
-
119
- def trace():
120
- """Trace this function explicitly for the given arguments.
128
+ # compile the function and get the state trace
129
+ state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
130
+ read_state_vals = state_trace.get_read_state_values(True)
121
131
 
122
- A traced function is staged out of Python and translated to a jaxpr. It is
123
- ready for lowering but not yet lowered.
132
+ # call the jitted function
133
+ return jit_fun.lower(state_trace.get_state_values(), *args, **params).compile()
124
134
 
125
- Returns:
126
- A ``Traced`` instance representing the tracing.
127
- """
128
- raise NotImplementedError
129
135
 
130
136
  jitted_fun: JittedFunction
131
137
 
@@ -144,8 +150,8 @@ def _get_jitted_fun(
144
150
  # evaluate the shape of the jitted function
145
151
  jitted_fun.eval_shape = eval_shape
146
152
 
147
- # lower the jitted function
148
- jitted_fun.lower = lower
153
+ # compile the jitted function
154
+ jitted_fun.compile = compile
149
155
 
150
156
  # trace the jitted
151
157
  jitted_fun.trace = trace
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: brainstate
3
- Version: 0.1.0.post20241219
3
+ Version: 0.1.0.post20241220
4
4
  Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
5
5
  Home-page: https://github.com/chaobrain/brainstate
6
6
  Author: BrainState Developers
@@ -24,7 +24,7 @@ brainstate/compile/_conditions.py,sha256=gApsHKGQrf1QBjoKXDVL7VsoeJ2zFtSc-hFz9nb
24
24
  brainstate/compile/_conditions_test.py,sha256=s9LF6h9LvigvgxUIugTqvgCHBIU8TXS1Ar1OlIxXfrw,8389
25
25
  brainstate/compile/_error_if.py,sha256=TFvhqITKkRO9m30GdlUP4eEjJvLWQUhjkujXO9zvrWs,2689
26
26
  brainstate/compile/_error_if_test.py,sha256=SJmAfosVoGd4vhfFtb1IvjeFVW914bfTccCg6DoLWYk,1992
27
- brainstate/compile/_jit.py,sha256=3mQ-RUFz35wceZyKE_MoR58OBL0RK_i6sHm4rWYzMLs,13698
27
+ brainstate/compile/_jit.py,sha256=3WBXNTALWPYC9rQH0TPH6w4bjG0BpnZt3RAzUQF5kkc,14045
28
28
  brainstate/compile/_jit_test.py,sha256=zD7kck9SQJGmUDolh9P4luKwQ21fBGje1Z4STTEXIuA,4135
29
29
  brainstate/compile/_loop_collect_return.py,sha256=_iOVPytbctgyaIOxQZH3A2ZbsSoT7VXnFk6Q6R8-gvA,23360
30
30
  brainstate/compile/_loop_collect_return_test.py,sha256=bA-_11E8A_0jR5umEO3e409y7bb5QYDTgSL-SBaX7kQ,1802
@@ -137,8 +137,8 @@ brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7
137
137
  brainstate/util/_struct.py,sha256=0exv0oOiSt1hmx20Y4J2-pCGtCTx13WcAlEYSBkyung,17640
138
138
  brainstate/util/_tracers.py,sha256=0r5T4nhxMzI79NtqroqitsdMT4YfpgV5RdYJLS5uJ0w,2285
139
139
  brainstate/util/_visualization.py,sha256=n4ZVz10z7VBqA0cKO6vyHwEMprWJgPeEqtITzDMai2Y,1519
140
- brainstate-0.1.0.post20241219.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
141
- brainstate-0.1.0.post20241219.dist-info/METADATA,sha256=nOo68Iuy-6B4po__cxVsBZlUEKbj0tZ1n0gq6x7MqLk,3533
142
- brainstate-0.1.0.post20241219.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
143
- brainstate-0.1.0.post20241219.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
144
- brainstate-0.1.0.post20241219.dist-info/RECORD,,
140
+ brainstate-0.1.0.post20241220.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
141
+ brainstate-0.1.0.post20241220.dist-info/METADATA,sha256=lfkUbD1vYx4bikkUolrL_pCp4zryfvGExXyWP_QH5tM,3533
142
+ brainstate-0.1.0.post20241220.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
143
+ brainstate-0.1.0.post20241220.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
144
+ brainstate-0.1.0.post20241220.dist-info/RECORD,,