jjinx 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinx/__init__.py ADDED
File without changes
jinx/errors.py ADDED
@@ -0,0 +1,41 @@
1
+ """J errors for incorrect usage of J primitives."""
2
+
3
+
4
class BaseJError(Exception):
    """Base class for every error raised for incorrect use of J primitives."""


class LengthError(BaseJError):
    """J length error."""


class DomainError(BaseJError):
    """J domain error."""


class ValenceError(BaseJError):
    """J valence error."""


class JIndexError(BaseJError):
    """J index error.

    Named with a ``J`` prefix; it is not a subclass of the builtin ``IndexError``.
    """


class SpellingError(BaseJError):
    """J spelling error."""


class StackError(BaseJError):
    """J stack error."""


class EvaluationError(BaseJError):
    """J evaluation error."""


class JSyntaxError(BaseJError):
    """J syntax error.

    Named with a ``J`` prefix; it is not a subclass of the builtin ``SyntaxError``.
    """


class JinxNotImplementedError(BaseJError):
    """Raised when a feature is not implemented in Jinx."""
@@ -0,0 +1,60 @@
1
+ from dataclasses import dataclass
2
+ from typing import Callable
3
+
4
+ from jinx.vocabulary import Adverb, Conjunction, Noun, Verb
5
+
6
+
7
+ @dataclass(frozen=True)
8
+ class Executor[T]:
9
+ apply_monad: Callable[[Verb[T], Noun[T]], Noun[T]]
10
+ """Apply monadic form of verb to a noun."""
11
+
12
+ apply_dyad: Callable[[Verb[T], Noun[T], Noun[T]], Noun[T]]
13
+ """Apply dyadic form of verb to two nouns."""
14
+
15
+ apply_conjunction: Callable[
16
+ [Verb[T] | Noun[T], Conjunction, Verb[T]], Verb[T] | Noun[T]
17
+ ]
18
+ """Apply conjunction to left and right arguments."""
19
+
20
+ apply_adverb: Callable[[Verb[T] | Noun[T], Adverb], Verb[T] | Noun[T]]
21
+ """Apply adverb to left argument."""
22
+
23
+ build_fork: Callable[[Noun[T] | Verb[T], Verb[T], Verb[T]], Verb[T]]
24
+ """Build fork."""
25
+
26
+ build_hook: Callable[[Verb[T], Verb[T]], Verb[T]]
27
+ """Build hook."""
28
+
29
+ ensure_noun_implementation: Callable[[Noun[T]], None]
30
+ """Ensure that the noun has an implementation."""
31
+
32
+ primitive_verb_map: dict[
33
+ str, tuple[Callable[[T], T] | None, Callable[[T, T], T] | None]
34
+ ]
35
+ """Map of primitive verb names to implementations of monad and dyad functions."""
36
+
37
+ primitive_adverb_map: dict[str, Callable[[Verb[T]], Verb[T]]]
38
+ """Map of primitive adverb names to implementation function."""
39
+
40
+ primitive_conjuction_map: dict[
41
+ str, Callable[[Verb[T] | Noun[T], Verb[T] | Noun[T]], Verb[T]]
42
+ ]
43
+ """Map of primitive conjunction names to implementation function."""
44
+
45
+ noun_to_string: Callable[[Noun[T]], str]
46
+ """Convert a noun to a string representation for printing."""
47
+
48
+
49
+ def load_executor(name: str) -> Executor:
50
+ if name == "numpy":
51
+ from jinx.execution.numpy import executor as numpy_executor
52
+
53
+ return numpy_executor
54
+
55
+ if name == "jax":
56
+ from jinx.execution.jax import executor as jax_executor
57
+
58
+ return jax_executor
59
+
60
+ raise ValueError(f"Unknown executor: {name}")
@@ -0,0 +1,45 @@
1
+ import jax
2
+ from jinx.errors import JinxNotImplementedError
3
+ from jinx.execution.executor import Executor
4
+ from jinx.execution.jax.adverbs import ADVERB_MAP
5
+ from jinx.execution.jax.application import (
6
+ apply_adverb,
7
+ # apply_conjunction,
8
+ apply_dyad,
9
+ apply_monad,
10
+ build_fork,
11
+ # build_hook,
12
+ ensure_noun_implementation,
13
+ )
14
+ from jinx.execution.jax.verbs import VERB_MAP
15
+
16
+ # from jinx.execution.numpy.conjunctions import CONJUNCTION_MAP
17
+ # from jinx.execution.numpy.conversion import ensure_noun_implementation
18
+ from jinx.execution.numpy.printing import noun_to_string
19
+
20
+ jax.config.update("jax_dynamic_shapes", True)
21
+
22
+
23
def make_not_implemented(name: str):
    """Build a stand-in callable for a feature the JAX executor lacks.

    The returned function accepts any positional and keyword arguments and
    always raises JinxNotImplementedError with a message naming ``name``.
    """

    def _raise_not_implemented(*_, **__):
        message = f"{name}: not yet implemented in the JAX executor."
        raise JinxNotImplementedError(message)

    return _raise_not_implemented
30
+
31
+
32
# The JAX backend. Verbs, adverbs and forks run on jax.Array; conjunctions and
# hooks are stubs that raise JinxNotImplementedError when used.
executor = Executor[jax.Array](
    # Application of verbs and modifiers.
    apply_monad=apply_monad,
    apply_dyad=apply_dyad,
    apply_adverb=apply_adverb,
    apply_conjunction=make_not_implemented("conjunction"),
    # Trains.
    build_fork=build_fork,
    build_hook=make_not_implemented("hook"),
    # Nouns. Printing reuses the NumPy implementation.
    ensure_noun_implementation=ensure_noun_implementation,
    noun_to_string=noun_to_string,  # type: ignore[arg-type]
    # Primitive lookup tables; no conjunctions are implemented yet.
    primitive_verb_map=VERB_MAP,
    primitive_adverb_map=ADVERB_MAP,
    primitive_conjuction_map={},
)
@@ -0,0 +1,91 @@
1
+ """Methods implementing J adverbs."""
2
+
3
+ import functools
4
+ from typing import Callable
5
+
6
+ import jax
7
+ import jax.numpy as jnp
8
+ from jinx.errors import JinxNotImplementedError, ValenceError
9
+ from jinx.execution.jax.application import _apply_dyad
10
+ from jinx.execution.numpy.helpers import (
11
+ maybe_parenthesise_verb_spelling,
12
+ )
13
+ from jinx.vocabulary import Dyad, Monad, Verb
14
+
15
# Infinite rank: the verbs derived below are handed their whole argument.
INFINITY = float("inf")
16
+
17
+
18
def slash_adverb(verb: Verb[jax.Array]) -> Verb[jax.Array]:
    """Derive ``u/`` (insert) from ``verb``.

    The derived monad inserts the dyad of ``verb`` between the items of its
    argument, folding from the right: ``u/ a b c`` is ``a u (b u c)``. An atom
    is treated as a one-item list, so ``u/`` of an atom returns it unchanged.

    The derived dyad is ``ufunc.outer`` on the fast path and NotImplemented
    otherwise. Both valences have infinite rank.

    Raises:
        ValenceError: if ``verb`` has no dyadic valence.
    """
    if verb.dyad is None or verb.dyad.function is None:
        # Note: this differs from J which still allows the adverb to be applied
        # to a verb, but may raise an error when the new verb is applied to a noun
        # and the verb has no dyadic valence.
        raise ValenceError(f"Verb {verb.spelling} has no dyadic valence.")

    dyad_function = verb.dyad.function

    if isinstance(dyad_function, jnp.ufunc) and verb.dyad.is_commutative:
        # Fast path: a vectorised reduce along the leading axis. ufunc.reduce
        # folds left-to-right; this is only taken for operations flagged
        # commutative (presumably also associative), where the order is moot.
        ufunc = dyad_function

        def _reduce_ufunc(y: jax.Array) -> jax.Array:
            # atleast_1d: reduce an atom like the slow path does, rather than
            # failing on a 0-d array that has no axis 0.
            return ufunc.reduce(jnp.atleast_1d(y))

        monad = _reduce_ufunc
        dyad = ufunc.outer

    else:
        # Slow path: dyad is not a ufunc.
        # The function is either callable, in which cases it is applied directly,
        # or a Verb object that needs to be applied indirectly with _apply_dyad().
        if isinstance(dyad_function, Verb):
            func = functools.partial(_apply_dyad, verb)  # type: ignore[assignment]
        else:
            func = dyad_function  # type: ignore[assignment]

        def _dyad_arg_swap(x: jax.Array, y: jax.Array) -> jax.Array:
            return func(y, x)

        def _reduce(y: jax.Array) -> jax.Array:
            # Reverse the items and fold with swapped arguments so that
            # functools.reduce (a left fold) yields a u (b u c ...).
            # No initial value is supplied, so an empty argument raises TypeError.
            y = jnp.atleast_1d(y)
            y = jnp.flip(y, axis=0)
            return functools.reduce(_dyad_arg_swap, y)

        monad = _reduce  # type: ignore[assignment]
        dyad = NotImplemented

    # Use the (possibly parenthesised) operand spelling, as bslash_adverb does.
    spelling = f"{maybe_parenthesise_verb_spelling(verb.spelling)}/"

    return Verb[jax.Array](
        name=spelling,
        spelling=spelling,
        monad=Monad(name=spelling, rank=INFINITY, function=monad),
        dyad=Dyad(
            name=spelling, left_rank=INFINITY, right_rank=INFINITY, function=dyad
        ),
    )
60
+
61
+
62
def bslash_adverb(verb: Verb[jax.Array]) -> Verb[jax.Array]:
    """Derive ``u\\`` (prefix) from ``verb``.

    Only the scans ``+/\\`` and ``*/\\`` are supported. They compute running
    sums and products along the leading axis, and an atom is treated as a
    one-item list. The derived dyad (infix) is not implemented yet.

    Raises:
        JinxNotImplementedError: if ``verb`` is not one of the supported scans.
    """
    # Common cases that have a straightforward optimisation.
    special_scans = {
        "+/": jnp.cumulative_sum,
        "*/": jnp.cumulative_prod,
    }
    scan = special_scans.get(verb.spelling)
    if scan is None:
        raise JinxNotImplementedError(
            f"Adverb \\ applied to verb {verb.spelling} is not yet implemented."
        )

    def monad_(y: jax.Array) -> jax.Array:
        # Prefixes run over the leading axis. The jnp cumulative functions need
        # an explicit axis for N-D input and reject 0-d input, so scan along
        # axis 0 of at-least-1-D data.
        return scan(jnp.atleast_1d(y), axis=0)

    spelling = maybe_parenthesise_verb_spelling(verb.spelling)
    spelling = f"{spelling}\\"

    return Verb[jax.Array](
        name=spelling,
        spelling=spelling,
        monad=Monad(name=spelling, rank=INFINITY, function=monad_),
        # The dyad x u\ y (infixes) is not implemented yet. NotImplemented (as in
        # slash_adverb) makes _apply_dyad raise JinxNotImplementedError instead of
        # reporting a missing valence.
        dyad=Dyad(
            name=spelling,
            left_rank=0,
            right_rank=INFINITY,
            function=NotImplemented,  # type: ignore[arg-type]
        ),
    )
86
+
87
+
88
# Primitive adverb name -> function deriving a new verb from the adverb's operand.
ADVERB_MAP: dict[str, Callable[[Verb[jax.Array]], Verb[jax.Array]]] = dict(
    SLASH=slash_adverb,
    BSLASH=bslash_adverb,
)
@@ -0,0 +1,231 @@
1
+ import functools
2
+ from dataclasses import dataclass
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ from jinx.errors import JinxNotImplementedError, ValenceError
7
+ from jinx.vocabulary import Adverb, DataType, Dyad, Monad, Noun, RankT, Verb
8
+
9
# Jinx data type -> dtype used when materialising a Noun's data as a jax.Array
# (see convert_noun_to_jax_array).
DATATYPE_TO_NP_MAP = {
    # JAX requires support for int64 to be set via config.
    # NOTE(review): float64 is gated by the same JAX x64 setting; unless it is
    # enabled elsewhere, JAX truncates float64 requests to float32 (with a
    # warning). Confirm whether Float should map to float32 here.
    DataType.Integer: jnp.int32,
    DataType.Float: jnp.float64,
}


# Infinite verb rank; get_rank() caps it at the noun's rank.
INFINITY = float("inf")
17
+
18
+
19
def get_rank(verb_rank: RankT, noun_rank: int) -> int:
    """Return the effective rank at which a verb applies to a noun.

    A non-negative ``verb_rank`` is capped at ``noun_rank``, so an infinite
    rank means "the whole noun". A negative ``verb_rank`` is subtracted from
    the noun's rank (e.g. -1 on a rank-3 noun gives 2), never going below 0.
    """
    return (
        max(0, noun_rank + verb_rank)  # type: ignore[return-value]
        if verb_rank < 0
        else min(verb_rank, noun_rank)  # type: ignore[return-value]
    )
28
+
29
+
30
@dataclass
class ArrayCells:
    """An array viewed as a frame of equally shaped cells."""

    # Shape of each cell: the trailing ``rank`` axes of the source array.
    cell_shape: tuple[int, ...]
    # Shape of the frame: the leading axes, which index the cells.
    frame_shape: tuple[int, ...]
    # The cells stacked along one leading axis; for rank > 0 the shape is
    # (number of cells, *cell_shape), for rank 0 the array is simply flattened.
    cells: jax.Array


def split_into_cells(arr: jax.Array, rank: int) -> ArrayCells:
    """
    Look at the array shape and rank to determine frame and cell shape.

    The trailing `rank` axes define the cell shape and the preceding
    axes define the frame shape. E.g. for rank=2:

        arr.shape = (n0, n1, n2, n3, n4)
                     ----------  ------
                     ^ frame     ^ cell

    If rank=0, the frame shape is the same as the shape and the monad
    applies to each atom of the array.

    Arrays with zero-length axes are supported: the number of cells is
    computed explicitly rather than inferred by ``reshape``.
    """
    import math  # local import: only needed to size the frame

    if rank == 0:
        return ArrayCells(cell_shape=(), frame_shape=arr.shape, cells=arr.ravel())

    frame_shape = arr.shape[:-rank]
    cell_shape = arr.shape[-rank:]
    # reshape(-1, *cell_shape) cannot infer -1 when a cell axis has length 0
    # (the array is empty), so pass the frame size explicitly.
    n_cells = math.prod(frame_shape)
    return ArrayCells(
        cell_shape=cell_shape,
        frame_shape=frame_shape,
        cells=arr.reshape(n_cells, *cell_shape),
    )
59
+
60
+
61
def infer_data_type(data: jax.Array) -> DataType:
    """Map the dtype of ``data`` onto a Jinx ``DataType``.

    Boolean and integer dtypes become ``DataType.Integer``; floating dtypes
    become ``DataType.Float``.

    Raises:
        NotImplementedError: for any other dtype (e.g. complex).
    """
    dtype = data.dtype
    integral_kinds = (jnp.integer, jnp.bool_)
    if any(jnp.issubdtype(dtype, kind) for kind in integral_kinds):
        return DataType.Integer
    if jnp.issubdtype(dtype, jnp.floating):
        return DataType.Float
    raise NotImplementedError(f"Cannot handle JAX dtype: {dtype}")
69
+
70
+
71
def convert_noun_to_jax_array(noun: Noun[jax.Array]) -> jax.Array:
    """Build a jax.Array from ``noun.data`` using the dtype for its data type.

    Single-element data yields a 0-d (scalar) array rather than shape ``(1,)``.
    """
    dtype = DATATYPE_TO_NP_MAP[noun.data_type]
    values = noun.data[0] if len(noun.data) == 1 else noun.data
    return jnp.array(values, dtype=dtype)  # type: ignore[call-overload]
77
+
78
+
79
def ensure_noun_implementation(noun: Noun[jax.Array]) -> None:
    """Attach a jax.Array built from ``noun.data`` if none is attached yet.

    Mutates ``noun`` in place; an existing implementation is left untouched.
    """
    if noun.implementation is not None:
        return
    noun.implementation = convert_noun_to_jax_array(noun)
82
+
83
+
84
def jax_array_to_noun(array: jax.Array) -> Noun[jax.Array]:
    """Wrap ``array`` in a Noun whose data type is inferred from its dtype."""
    return Noun[jax.Array](data_type=infer_data_type(array), implementation=array)
87
+
88
+
89
def apply_monad(verb: Verb[jax.Array], noun: Noun[jax.Array]) -> Noun[jax.Array]:
    """Apply the monadic valence of ``verb`` to ``noun`` and wrap the result."""
    return jax_array_to_noun(_apply_monad(verb, noun.implementation))
92
+
93
+
94
def _apply_monad(verb: Verb[jax.Array], arr: jax.Array) -> jax.Array:
    """Apply the monad of ``verb`` to ``arr``, honouring the monad's rank.

    At effective rank 0 the function is called once on the whole array, on the
    assumption that it acts elementwise. Otherwise it is called on each cell and
    the results are stacked under the frame. All cell results must share one
    shape, because no fill/padding is performed yet.

    Raises:
        ValenceError: if ``verb`` has no monadic valence.
        JinxNotImplementedError: if the monad is marked NotImplemented, or if
            the frame is empty (no cells to apply the function to).
    """
    if verb.monad is None or verb.monad.function is None:
        raise ValenceError(f"Verb {verb.spelling} has no monadic valence.")
    if verb.monad.function is NotImplemented:
        raise JinxNotImplementedError(
            f"Verb {verb.spelling} monad function is not yet implemented in Jinx."
        )

    if isinstance(verb.monad.function, Verb):
        # The monad is defined by another verb: recurse so its rank is honoured.
        function = functools.partial(_apply_monad, verb.monad.function)
    else:
        function = verb.monad.function  # type: ignore[assignment]

    rank = get_rank(verb.monad.rank, arr.ndim)

    if rank == 0:
        return function(arr)

    # rank > 0 from here on, so each cell has at least one axis and the stacked
    # cells are at least 2-D.
    array_cells = split_into_cells(arr, rank)

    if array_cells.cells.shape[0] == 0:
        # Without a cell there is no result shape to build the output from.
        raise JinxNotImplementedError(
            f"Verb {verb.spelling} applied to an empty frame is not yet implemented."
        )

    # TODO: Use jax.vmap instead
    cells = jnp.asarray([function(cell) for cell in array_cells.cells])

    # No filling/padding for now...
    return cells.reshape(array_cells.frame_shape + cells.shape[1:])
122
+
123
+
124
def apply_dyad(
    verb: Verb[jax.Array], noun_1: Noun[jax.Array], noun_2: Noun[jax.Array]
) -> Noun[jax.Array]:
    """Apply the dyadic valence of ``verb`` with ``noun_1`` as the left argument
    and ``noun_2`` as the right, and wrap the result as a Noun."""
    left, right = noun_1.implementation, noun_2.implementation
    return jax_array_to_noun(_apply_dyad(verb, left, right))
129
+
130
+
131
def _apply_dyad(
    verb: Verb[jax.Array], left_arr: jax.Array, right_arr: jax.Array
) -> jax.Array:
    """Apply the dyad of ``verb`` to ``left_arr`` and ``right_arr`` by rank.

    If both effective ranks are 0 and at least one argument is an atom, the
    function is called once on the whole arguments. Otherwise both arguments
    are split into cells. If their frames match, the function is applied to
    corresponding cell pairs and the results are stacked under the common
    frame. All results must share one shape, because no fill is applied.

    Raises:
        ValenceError: if ``verb`` has no dyadic valence.
        JinxNotImplementedError: if the dyad is marked NotImplemented, the two
            frames differ, or the common frame is empty.
    """
    if verb.dyad is None or verb.dyad.function is None:
        raise ValenceError(f"Verb {verb.spelling} has no dyadic valence.")
    if verb.dyad.function is NotImplemented:
        raise JinxNotImplementedError(
            f"Verb {verb.spelling} dyad function is not yet implemented."
        )

    if isinstance(verb.dyad.function, Verb):
        # The dyad is defined by another verb: recurse so its ranks are honoured.
        function = functools.partial(_apply_dyad, verb.dyad.function)
    else:
        function = verb.dyad.function  # type: ignore[assignment]

    left_rank = get_rank(verb.dyad.left_rank, left_arr.ndim)
    right_rank = get_rank(verb.dyad.right_rank, right_arr.ndim)

    if left_rank == right_rank == 0 and (left_arr.ndim == 0 or right_arr.ndim == 0):
        return function(left_arr, right_arr)

    left = split_into_cells(left_arr, left_rank)
    right = split_into_cells(right_arr, right_rank)

    if left.frame_shape != right.frame_shape:
        raise JinxNotImplementedError(
            "Dyadic verbs with non-zero rank and different frame shape are not yet implemented."
        )

    if left.cells.shape[0] == 0:
        # Without a cell pair there is no result shape to build the output from.
        raise JinxNotImplementedError(
            f"Verb {verb.spelling} applied to an empty frame is not yet implemented."
        )

    results = [
        function(left_cell, right_cell)
        for left_cell, right_cell in zip(left.cells, right.cells, strict=True)
    ]
    stacked = jnp.asarray(results)
    # No filling/padding for now: every result must have the same shape.
    return stacked.reshape(left.frame_shape + stacked.shape[1:])
165
+
166
+
167
def apply_adverb(verb_or_noun: Verb | Noun, adverb: Adverb) -> Verb:
    """Derive a new verb by calling ``adverb``'s function on its operand."""
    derive = adverb.function
    return derive(verb_or_noun)
169
+
170
+
171
def build_fork(
    f: Verb[jax.Array] | Noun[jax.Array], g: Verb[jax.Array], h: Verb[jax.Array]
) -> Verb[jax.Array]:
    """Combine ``f``, ``g`` and ``h`` into the fork ``(f g h)``.

          (f g h) y  ->  (f y) g (h y)
        x (f g h) y  ->  (x f y) g (x h y)

    ``f`` may be a noun, in which case its value is used directly as the left
    argument of ``g`` (one fewer function call). If ``f`` is the verb spelled
    ``[:`` the fork is capped: ``g`` is applied monadically to the result of
    ``h``.

    The derived verb has infinite rank in both valences.
    """

    def _wrap(part_spelling: str) -> str:
        # Multi-word spellings are parenthesised so the fork reads back as one train.
        return f"({part_spelling})" if " " in part_spelling else part_spelling

    def _monad(y: jax.Array) -> jax.Array:
        if not isinstance(f, Verb):
            left = f.implementation
        elif f.spelling == "[:":
            return _apply_monad(g, _apply_monad(h, y))
        else:
            left = _apply_monad(f, y)
        return _apply_dyad(g, left, _apply_monad(h, y))

    def _dyad(x: jax.Array, y: jax.Array) -> jax.Array:
        if not isinstance(f, Verb):
            left = f.implementation
        elif f.spelling == "[:":
            return _apply_monad(g, _apply_dyad(h, x, y))
        else:
            left = _apply_dyad(f, x, y)
        return _apply_dyad(g, left, _apply_dyad(h, x, y))

    f_spelling = _wrap(f.spelling) if isinstance(f, Verb) else str(f.implementation)
    spelling = " ".join([f_spelling, _wrap(g.spelling), _wrap(h.spelling)])

    return Verb[jax.Array](
        spelling=spelling,
        name=spelling,
        monad=Monad(
            name=spelling,
            rank=INFINITY,
            function=_monad,
        ),
        dyad=Dyad(
            name=spelling,
            left_rank=INFINITY,
            right_rank=INFINITY,
            function=_dyad,
        ),
    )
@@ -0,0 +1,90 @@
1
+ from typing import Callable
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+
6
# Signature of a primitive's monadic implementation: y -> result.
MonadT = Callable[[jax.Array], jax.Array]
# Signature of a primitive's dyadic implementation: (x, y) -> result.
DyadT = Callable[[jax.Array, jax.Array], jax.Array]
8
+
9
+
10
@jax.jit
def number_monad(y: jax.Array) -> jax.Array:
    """# monad: count the items of y (an atom counts as one item).

    Shapes are static under ``jax.jit``, so the branch is decided at trace time.
    """
    is_atom = jnp.isscalar(y) or y.shape == ()
    item_count = 1 if is_atom else y.shape[0]
    return jnp.array(item_count)
16
+
17
+
18
@jax.jit
def percentco_dyad(x: jax.Array, y: jax.Array) -> jax.Array:
    """%: dyad: the x-th root of y, computed as y raised to the power 1/x."""
    exponent = 1 / x
    return y**exponent
21
+
22
+
23
@jax.jit
def plusdot_monad(y: jax.Array) -> jax.Array:
    """+. monad: split each number into its real and imaginary parts.

    The two parts are stacked along a new trailing axis of length 2.
    """
    parts = [y.real, y.imag]
    return jnp.stack(parts, axis=-1)
27
+
28
+
29
@jax.jit
def plusco_monad(y: jax.Array) -> jax.Array:
    """+: monad: double every value of y."""
    return jnp.multiply(y, 2)
33
+
34
+
35
@jax.jit
def minusdot_monad(y: jax.Array) -> jax.Array:
    """-. monad: complement, i.e. 1 - y for every value of y."""
    return jnp.subtract(1, y)
39
+
40
+
41
@jax.jit
def minusco_monad(y: jax.Array) -> jax.Array:
    """-: monad: halve every value of y (true division)."""
    return jnp.true_divide(y, 2)
45
+
46
+
47
@jax.jit
def minusco_dyad(x: jax.Array, y: jax.Array) -> jax.Array:
    """-: dyad: match, true when x and y have the same shape and values.

    NaNs in the same positions are treated as equal.
    """
    return jnp.asarray(jnp.array_equal(x, y, equal_nan=True))
52
+
53
+
54
@jax.jit
def hatdot_dyad(x: jax.Array, y: jax.Array) -> jax.Array:
    """^. dyad: logarithm of y to the base x, via the change-of-base rule."""
    return jnp.divide(jnp.log(y), jnp.log(x))
58
+
59
+
60
@jax.jit
def ltco_monad(y: jax.Array) -> jax.Array:
    """<: monad: decrement every value of y by one."""
    return jnp.subtract(y, 1)
64
+
65
+
66
@jax.jit
def gtco_monad(y: jax.Array) -> jax.Array:
    """>: monad: increment every value of y by one."""
    return jnp.add(y, 1)
70
+
71
+
72
# Primitive verb name -> (monad, dyad) implementations for the JAX backend.
# NotImplemented marks a valence that is not supported yet; _apply_monad and
# _apply_dyad raise JinxNotImplementedError when they meet it.
VERB_MAP: dict[str, tuple[MonadT | None, DyadT | None]] = {
    # VERB: (MONAD, DYAD)
    "PLUS": (jnp.conj, jnp.add),
    "PLUSDOT": (plusdot_monad, NotImplemented),
    "PLUSCO": (plusco_monad, NotImplemented),
    "MINUS": (jnp.negative, jnp.subtract),
    "MINUSDOT": (minusdot_monad, NotImplemented),
    "MINUSCO": (minusco_monad, minusco_dyad),
    "STAR": (jnp.sign, jnp.multiply),
    "PERCENT": (jnp.reciprocal, jnp.divide),
    # %: monad is the square root; the dyad is the x-th root of y. The dyad was
    # defined (percentco_dyad) but previously never registered here.
    "PERCENTCO": (jnp.sqrt, percentco_dyad),
    "HAT": (jnp.exp, jnp.power),
    "HATDOT": (jnp.log, hatdot_dyad),
    "LTDOT": (jnp.floor, jnp.minimum),
    "GTDOT": (jnp.ceil, jnp.maximum),
    "LTCO": (ltco_monad, jnp.less_equal),
    "GTCO": (gtco_monad, jnp.greater_equal),
    "NUMBER": (number_monad, NotImplemented),
    "BAR": (jnp.abs, NotImplemented),
}
@@ -0,0 +1,29 @@
1
+ import numpy as np
2
+ from jinx.execution.executor import Executor
3
+ from jinx.execution.numpy.adverbs import ADVERB_MAP
4
+ from jinx.execution.numpy.application import (
5
+ apply_adverb,
6
+ apply_conjunction,
7
+ apply_dyad,
8
+ apply_monad,
9
+ build_fork,
10
+ build_hook,
11
+ )
12
+ from jinx.execution.numpy.conjunctions import CONJUNCTION_MAP
13
+ from jinx.execution.numpy.conversion import ensure_noun_implementation
14
+ from jinx.execution.numpy.printing import noun_to_string
15
+ from jinx.execution.numpy.verbs import VERB_MAP
16
+
17
# The NumPy backend: every Executor field is backed by the
# jinx.execution.numpy implementation modules.
executor = Executor[np.ndarray](
    # Application of verbs and modifiers.
    apply_monad=apply_monad,
    apply_dyad=apply_dyad,
    apply_adverb=apply_adverb,
    apply_conjunction=apply_conjunction,
    # Trains.
    build_fork=build_fork,
    build_hook=build_hook,
    # Nouns.
    ensure_noun_implementation=ensure_noun_implementation,
    noun_to_string=noun_to_string,
    # Primitive lookup tables.
    primitive_verb_map=VERB_MAP,
    primitive_adverb_map=ADVERB_MAP,
    primitive_conjuction_map=CONJUNCTION_MAP,  # type: ignore[arg-type]
)
+ )