jjinx 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinx/__init__.py +0 -0
- jinx/errors.py +41 -0
- jinx/execution/executor.py +60 -0
- jinx/execution/jax/__init__.py +45 -0
- jinx/execution/jax/adverbs.py +91 -0
- jinx/execution/jax/application.py +231 -0
- jinx/execution/jax/verbs.py +90 -0
- jinx/execution/numpy/__init__.py +29 -0
- jinx/execution/numpy/adverbs.py +334 -0
- jinx/execution/numpy/application.py +343 -0
- jinx/execution/numpy/conjunctions.py +437 -0
- jinx/execution/numpy/conversion.py +62 -0
- jinx/execution/numpy/helpers.py +158 -0
- jinx/execution/numpy/printing.py +179 -0
- jinx/execution/numpy/verbs.py +850 -0
- jinx/primitives.py +490 -0
- jinx/shell.py +68 -0
- jinx/vocabulary.py +181 -0
- jinx/word_evaluation.py +375 -0
- jinx/word_formation.py +229 -0
- jinx/word_spelling.py +118 -0
- jjinx-0.0.1.dist-info/METADATA +148 -0
- jjinx-0.0.1.dist-info/RECORD +27 -0
- jjinx-0.0.1.dist-info/WHEEL +5 -0
- jjinx-0.0.1.dist-info/entry_points.txt +2 -0
- jjinx-0.0.1.dist-info/licenses/LICENSE +21 -0
- jjinx-0.0.1.dist-info/top_level.txt +1 -0
jinx/__init__.py
ADDED
File without changes
|
jinx/errors.py
ADDED
@@ -0,0 +1,41 @@
|
|
1
|
+
"""J errors for incorrect usage of J primitives."""


class BaseJError(Exception):
    """Common base class of every J error raised by Jinx."""


class LengthError(BaseJError):
    """Raised for a J length error."""


class DomainError(BaseJError):
    """Raised for a J domain error."""


class ValenceError(BaseJError):
    """Raised for a J valence error."""


class JIndexError(BaseJError):
    """Raised for a J index error (distinct from Python's IndexError)."""


class SpellingError(BaseJError):
    """Raised for a J spelling error."""


class StackError(BaseJError):
    """Raised for a J stack error."""


class EvaluationError(BaseJError):
    """Raised for an evaluation error."""


class JSyntaxError(BaseJError):
    """Raised for a J syntax error (distinct from Python's SyntaxError)."""


class JinxNotImplementedError(BaseJError):
    """Raised when a feature is not implemented in Jinx."""
|
@@ -0,0 +1,60 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
from typing import Callable
|
3
|
+
|
4
|
+
from jinx.vocabulary import Adverb, Conjunction, Noun, Verb
|
5
|
+
|
6
|
+
|
7
|
+
@dataclass(frozen=True)
|
8
|
+
class Executor[T]:
|
9
|
+
apply_monad: Callable[[Verb[T], Noun[T]], Noun[T]]
|
10
|
+
"""Apply monadic form of verb to a noun."""
|
11
|
+
|
12
|
+
apply_dyad: Callable[[Verb[T], Noun[T], Noun[T]], Noun[T]]
|
13
|
+
"""Apply dyadic form of verb to two nouns."""
|
14
|
+
|
15
|
+
apply_conjunction: Callable[
|
16
|
+
[Verb[T] | Noun[T], Conjunction, Verb[T]], Verb[T] | Noun[T]
|
17
|
+
]
|
18
|
+
"""Apply conjunction to left and right arguments."""
|
19
|
+
|
20
|
+
apply_adverb: Callable[[Verb[T] | Noun[T], Adverb], Verb[T] | Noun[T]]
|
21
|
+
"""Apply adverb to left argument."""
|
22
|
+
|
23
|
+
build_fork: Callable[[Noun[T] | Verb[T], Verb[T], Verb[T]], Verb[T]]
|
24
|
+
"""Build fork."""
|
25
|
+
|
26
|
+
build_hook: Callable[[Verb[T], Verb[T]], Verb[T]]
|
27
|
+
"""Build hook."""
|
28
|
+
|
29
|
+
ensure_noun_implementation: Callable[[Noun[T]], None]
|
30
|
+
"""Ensure that the noun has an implementation."""
|
31
|
+
|
32
|
+
primitive_verb_map: dict[
|
33
|
+
str, tuple[Callable[[T], T] | None, Callable[[T, T], T] | None]
|
34
|
+
]
|
35
|
+
"""Map of primitive verb names to implementations of monad and dyad functions."""
|
36
|
+
|
37
|
+
primitive_adverb_map: dict[str, Callable[[Verb[T]], Verb[T]]]
|
38
|
+
"""Map of primitive adverb names to implementation function."""
|
39
|
+
|
40
|
+
primitive_conjuction_map: dict[
|
41
|
+
str, Callable[[Verb[T] | Noun[T], Verb[T] | Noun[T]], Verb[T]]
|
42
|
+
]
|
43
|
+
"""Map of primitive conjunction names to implementation function."""
|
44
|
+
|
45
|
+
noun_to_string: Callable[[Noun[T]], str]
|
46
|
+
"""Convert a noun to a string representation for printing."""
|
47
|
+
|
48
|
+
|
49
|
+
def load_executor(name: str) -> Executor:
|
50
|
+
if name == "numpy":
|
51
|
+
from jinx.execution.numpy import executor as numpy_executor
|
52
|
+
|
53
|
+
return numpy_executor
|
54
|
+
|
55
|
+
if name == "jax":
|
56
|
+
from jinx.execution.jax import executor as jax_executor
|
57
|
+
|
58
|
+
return jax_executor
|
59
|
+
|
60
|
+
raise ValueError(f"Unknown executor: {name}")
|
@@ -0,0 +1,45 @@
|
|
1
|
+
import jax
|
2
|
+
from jinx.errors import JinxNotImplementedError
|
3
|
+
from jinx.execution.executor import Executor
|
4
|
+
from jinx.execution.jax.adverbs import ADVERB_MAP
|
5
|
+
from jinx.execution.jax.application import (
|
6
|
+
apply_adverb,
|
7
|
+
# apply_conjunction,
|
8
|
+
apply_dyad,
|
9
|
+
apply_monad,
|
10
|
+
build_fork,
|
11
|
+
# build_hook,
|
12
|
+
ensure_noun_implementation,
|
13
|
+
)
|
14
|
+
from jinx.execution.jax.verbs import VERB_MAP
|
15
|
+
|
16
|
+
# from jinx.execution.numpy.conjunctions import CONJUNCTION_MAP
|
17
|
+
# from jinx.execution.numpy.conversion import ensure_noun_implementation
|
18
|
+
from jinx.execution.numpy.printing import noun_to_string
|
19
|
+
|
20
|
+
# NOTE(review): this changes a process-wide JAX setting as a side effect of
# importing this module, so it also affects any other JAX code in the same
# process. "jax_dynamic_shapes" is an experimental JAX flag; presumably some
# shape-dependent verb relies on it. Confirm which code needs it.
jax.config.update("jax_dynamic_shapes", True)
|
21
|
+
|
22
|
+
|
23
|
+
def make_not_implemented(name: str):
    """Return a stand-in callable for a feature the JAX executor lacks.

    The returned function accepts and ignores any positional or keyword
    arguments, then raises ``JinxNotImplementedError`` whose message names
    *name*.
    """

    def _raise_not_implemented(*_args, **_kwargs):
        message = f"{name}: not yet implemented in the JAX executor."
        raise JinxNotImplementedError(message)

    return _raise_not_implemented
|
30
|
+
|
31
|
+
|
32
|
+
# JAX backend. Conjunctions and hooks are not supported yet: their slots hold
# stubs that raise JinxNotImplementedError when called, and there are no
# primitive conjunctions.
executor = Executor[jax.Array](
    # Implemented by the JAX backend.
    apply_monad=apply_monad,
    apply_dyad=apply_dyad,
    apply_adverb=apply_adverb,
    build_fork=build_fork,
    ensure_noun_implementation=ensure_noun_implementation,
    primitive_verb_map=VERB_MAP,
    primitive_adverb_map=ADVERB_MAP,
    # Not yet implemented for JAX.
    apply_conjunction=make_not_implemented("conjunction"),
    build_hook=make_not_implemented("hook"),
    primitive_conjuction_map={},
    # Printing reuses the NumPy implementation.
    noun_to_string=noun_to_string,  # type: ignore[arg-type]
)
|
@@ -0,0 +1,91 @@
|
|
1
|
+
"""Methods implementing J adverbs."""
|
2
|
+
|
3
|
+
import functools
|
4
|
+
from typing import Callable
|
5
|
+
|
6
|
+
import jax
|
7
|
+
import jax.numpy as jnp
|
8
|
+
from jinx.errors import JinxNotImplementedError, ValenceError
|
9
|
+
from jinx.execution.jax.application import _apply_dyad
|
10
|
+
from jinx.execution.numpy.helpers import (
|
11
|
+
maybe_parenthesise_verb_spelling,
|
12
|
+
)
|
13
|
+
from jinx.vocabulary import Dyad, Monad, Verb
|
14
|
+
|
15
|
+
# Rank given to the derived verbs built below, so they receive their whole
# argument rather than cells of it.
INFINITY = float("inf")
|
16
|
+
|
17
|
+
|
18
|
+
def slash_adverb(verb: Verb[jax.Array]) -> Verb[jax.Array]:
    """Implement the J adverb ``/`` (insert) for the JAX backend.

    The derived monad ``u/ y`` inserts the dyad of ``u`` between the items of
    ``y`` and evaluates right to left, i.e. ``a u (b u c)``. The derived dyad
    ``x u/ y`` (table) is only available when ``u``'s dyad is a commutative
    ``jnp.ufunc``, where it is ``ufunc.outer``; otherwise it is
    ``NotImplemented``.

    Raises:
        ValenceError: if ``verb`` has no dyadic valence.
    """
    if verb.dyad is None or verb.dyad.function is None:
        # Note: this differs from J which still allows the adverb to be applied
        # to a verb, but may raise an error when the new verb is applied to a noun
        # and the verb has no dyadic valence.
        raise ValenceError(f"Verb {verb.spelling} has no dyadic valence.")

    if isinstance(verb.dyad.function, jnp.ufunc) and verb.dyad.is_commutative:
        # Fast path: ufunc.reduce folds left to right along axis 0.
        # NOTE: this equals J's right-to-left fold only if the dyad is also
        # associative; commutativity alone is not sufficient.
        ufunc = verb.dyad.function

        def _ufunc_reduce(y: jax.Array) -> jax.Array:
            # J treats an atom as a one-item list (so ``u/ atom`` is the atom);
            # a 0-d array has no axis 0 for ufunc.reduce to fold over.
            return ufunc.reduce(jnp.atleast_1d(y), axis=0)

        monad = _ufunc_reduce
        dyad = ufunc.outer

    else:
        # Slow path: dyad is not a ufunc.
        # The function is either a callable, which is applied directly, or a
        # Verb object that must be applied indirectly with _apply_dyad().
        if isinstance(verb.dyad.function, Verb):
            func = functools.partial(_apply_dyad, verb)  # type: ignore[assignment]
        else:
            func = verb.dyad.function  # type: ignore[assignment]

        def _dyad_arg_swap(x: jax.Array, y: jax.Array) -> jax.Array:
            return func(y, x)

        def _reduce(y: jax.Array) -> jax.Array:
            y = jnp.atleast_1d(y)
            if y.shape[0] == 0:
                # functools.reduce has no initial value to return here; J would
                # need the identity element of the verb, which is not known.
                raise JinxNotImplementedError(
                    f"Adverb / applied to verb {verb.spelling} on an empty array "
                    "is not yet implemented."
                )
            # Flipping and swapping arguments turns functools.reduce's
            # left-to-right fold into J's right-to-left evaluation.
            y = jnp.flip(y, axis=0)
            return functools.reduce(_dyad_arg_swap, y)

        monad = _reduce  # type: ignore[assignment]
        dyad = NotImplemented

    # Parenthesise multi-word spellings (e.g. a fork) before appending "/".
    spelling = f"{maybe_parenthesise_verb_spelling(verb.spelling)}/"

    return Verb[jax.Array](
        name=spelling,
        spelling=spelling,
        monad=Monad(name=spelling, rank=INFINITY, function=monad),
        dyad=Dyad(
            name=spelling, left_rank=INFINITY, right_rank=INFINITY, function=dyad
        ),
    )
|
60
|
+
|
61
|
+
|
62
|
+
def bslash_adverb(verb: Verb[jax.Array]) -> Verb[jax.Array]:
    """Implement the J prefix adverb (backslash) for the JAX backend.

    Only prefix sums and products are supported: the verbs spelled ``+/`` and
    ``*/`` map onto ``jnp.cumulative_sum`` and ``jnp.cumulative_prod`` along
    the leading axis. The derived dyad has no function, so applying it raises
    a ValenceError.

    Raises:
        JinxNotImplementedError: for any other verb.
    """
    # Common cases that have a straightforward optimisation.
    SPECIAL_MONAD = {
        "+/": jnp.cumulative_sum,
        "*/": jnp.cumulative_prod,
    }

    if verb.spelling in SPECIAL_MONAD:
        cumulative = SPECIAL_MONAD[verb.spelling]

    else:
        raise JinxNotImplementedError(
            f"Adverb \\ applied to verb {verb.spelling} is not yet implemented."
        )

    def monad_(y: jax.Array) -> jax.Array:
        # Prefixes run over the items (leading axis) of y, and an atom is
        # treated as a one-item list. The explicit axis is required because
        # jnp.cumulative_sum/prod reject input with more than one dimension
        # when axis is None.
        return cumulative(jnp.atleast_1d(y), axis=0)

    spelling = maybe_parenthesise_verb_spelling(verb.spelling)
    spelling = f"{spelling}\\"

    return Verb(
        name=spelling,
        spelling=spelling,
        monad=Monad(name=spelling, rank=INFINITY, function=monad_),
        dyad=Dyad(name=spelling, left_rank=0, right_rank=INFINITY, function=None),  # type: ignore[arg-type]
    )
|
86
|
+
|
87
|
+
|
88
|
+
# Primitive adverb name -> function that derives a new verb from its argument.
ADVERB_MAP: dict[str, Callable[[Verb[jax.Array]], Verb[jax.Array]]] = dict(
    SLASH=slash_adverb,
    BSLASH=bslash_adverb,
)
|
@@ -0,0 +1,231 @@
|
|
1
|
+
import functools
import math
from dataclasses import dataclass

import jax
import jax.numpy as jnp

from jinx.errors import JinxNotImplementedError, LengthError, ValenceError
from jinx.vocabulary import Adverb, DataType, Dyad, Monad, Noun, RankT, Verb
|
8
|
+
|
9
|
+
# J data type -> JAX dtype used when materialising a noun's data as an array
# (see convert_noun_to_jax_array).
DATATYPE_TO_NP_MAP = {
    # JAX requires support for int64 to be set via config.
    DataType.Integer: jnp.int32,
    # NOTE(review): float64 is likewise only honoured when JAX's x64 mode is
    # enabled; otherwise JAX truncates it to float32 (with a warning).
    # Confirm whether float32 is what is actually intended here.
    DataType.Float: jnp.float64,
}
|
14
|
+
|
15
|
+
|
16
|
+
# Infinite verb rank: get_rank caps it at the noun's rank, so a verb with this
# rank is applied to its whole argument.
INFINITY = float("inf")
|
17
|
+
|
18
|
+
|
19
|
+
def get_rank(verb_rank: RankT, noun_rank: int) -> int:
    """Get the rank at which to apply the verb to the noun.

    A negative verb rank ``-r`` means ``r`` less than the noun's rank, clamped
    so the result is never below 0. Any other verb rank (including infinity)
    is capped at the noun's rank.
    """
    rank = (
        max(0, noun_rank + verb_rank)
        if verb_rank < 0
        else min(verb_rank, noun_rank)
    )
    return rank  # type: ignore[return-value]
|
28
|
+
|
29
|
+
|
30
|
+
@dataclass
class ArrayCells:
    """An array viewed as a frame of equally shaped cells.

    ``frame_shape`` and ``cell_shape`` record how the original array's shape
    splits into leading (frame) and trailing (cell) axes; ``cells`` holds the
    cells stacked along its first axis, in row-major order of the frame.
    """

    # Shape of one cell: the trailing axes of the original array.
    cell_shape: tuple[int, ...]
    # Shape of the frame: the leading axes of the original array.
    frame_shape: tuple[int, ...]
    # All cells stacked along axis 0.
    cells: jax.Array
|
35
|
+
|
36
|
+
|
37
|
+
def split_into_cells(arr: jax.Array, rank: int) -> ArrayCells:
    """
    Look at the array shape and rank to determine frame and cell shape.

    The trailing `rank` axes define the cell shape and the preceding
    axes define the frame shape. E.g. for rank=2:

        arr.shape = (n0, n1, n2, n3, n4)
                     ----------  ------
                     ^ frame     ^ cell

    If rank=0, the frame shape is the same as the shape and the monad
    applies to each atom of the array.

    The returned ``cells`` array has shape ``(prod(frame_shape), *cell_shape)``
    for rank > 0, and is ``arr`` flattened for rank=0.
    """
    if rank == 0:
        return ArrayCells(cell_shape=(), frame_shape=arr.shape, cells=arr.ravel())

    frame_shape = arr.shape[:-rank]
    cell_shape = arr.shape[-rank:]
    # Give the number of cells explicitly instead of reshaping with -1:
    # reshape cannot infer a -1 axis for an empty array whose cell shape also
    # contains a zero-length axis (e.g. shape (3, 0) at rank 1).
    n_cells = math.prod(frame_shape)
    return ArrayCells(
        cell_shape=cell_shape,
        frame_shape=frame_shape,
        cells=arr.reshape(n_cells, *cell_shape),
    )
|
59
|
+
|
60
|
+
|
61
|
+
def infer_data_type(data: jax.Array) -> DataType:
    """Return the J data type matching the dtype of a JAX array.

    Integer and boolean dtypes map to ``DataType.Integer``; floating dtypes map
    to ``DataType.Float``.

    Raises:
        NotImplementedError: for any other dtype (e.g. complex).
    """
    kind = data.dtype
    rules = (
        ((jnp.integer, jnp.bool_), DataType.Integer),
        ((jnp.floating,), DataType.Float),
    )
    for categories, data_type in rules:
        if any(jnp.issubdtype(kind, category) for category in categories):
            return data_type

    raise NotImplementedError(f"Cannot handle JAX dtype: {kind}")
|
69
|
+
|
70
|
+
|
71
|
+
def convert_noun_to_jax_array(noun: Noun[jax.Array]) -> jax.Array:
    """Build a JAX array from ``noun.data`` using the dtype for its data type.

    Data holding exactly one element becomes a 0-d array (an atom) rather
    than a length-1 vector.
    """
    target_dtype = DATATYPE_TO_NP_MAP[noun.data_type]
    payload = noun.data[0] if len(noun.data) == 1 else noun.data
    return jnp.array(payload, dtype=target_dtype)  # type: ignore[call-overload]
|
77
|
+
|
78
|
+
|
79
|
+
def ensure_noun_implementation(noun: Noun[jax.Array]) -> None:
    """Fill in ``noun.implementation`` from its data if it is not set yet.

    The noun is modified in place; a noun that already has an implementation
    is left untouched.
    """
    if noun.implementation is not None:
        return
    noun.implementation = convert_noun_to_jax_array(noun)
|
82
|
+
|
83
|
+
|
84
|
+
def jax_array_to_noun(array: jax.Array) -> Noun[jax.Array]:
    """Wrap a JAX array in a Noun whose data type is inferred from the dtype."""
    inferred = infer_data_type(array)
    noun = Noun[jax.Array](data_type=inferred, implementation=array)
    return noun
|
87
|
+
|
88
|
+
|
89
|
+
def apply_monad(verb: Verb[jax.Array], noun: Noun[jax.Array]) -> Noun[jax.Array]:
    """Apply the monad of ``verb`` to ``noun`` and wrap the result as a Noun."""
    return jax_array_to_noun(_apply_monad(verb, noun.implementation))
|
92
|
+
|
93
|
+
|
94
|
+
def _apply_monad(verb: Verb[jax.Array], arr: jax.Array) -> jax.Array:
    """Apply the monadic valence of ``verb`` to a raw JAX array.

    The array is split into cells at the verb's monadic rank, the monad is
    applied to each cell, and the results are reassembled under the frame.

    Raises:
        ValenceError: if the verb has no monadic valence.
        JinxNotImplementedError: if the monad is ``NotImplemented``, or the
            frame contains no cells.
    """
    if verb.monad is None or verb.monad.function is None:
        raise ValenceError(f"Verb {verb.spelling} has no monadic valence.")
    if verb.monad.function is NotImplemented:
        raise JinxNotImplementedError(
            f"Verb {verb.spelling} monad function is not yet implemented in Jinx."
        )

    if isinstance(verb.monad.function, Verb):
        function = functools.partial(_apply_monad, verb.monad.function)
    else:
        function = verb.monad.function  # type: ignore[assignment]

    rank = get_rank(verb.monad.rank, arr.ndim)

    if rank == 0:
        # The whole array is passed at once; this assumes rank-0 functions
        # operate element-wise.
        return function(arr)

    # For rank > 0 the cells array always has at least two dimensions
    # (the cell count plus the cell axes).
    array_cells = split_into_cells(arr, rank)

    if array_cells.cells.shape[0] == 0:
        raise JinxNotImplementedError(
            f"Verb {verb.spelling} applied over an empty frame is not yet implemented."
        )

    # TODO: Use jax.vmap instead
    results = [function(cell) for cell in array_cells.cells]

    # No filling/padding for now: every cell result must have the same shape.
    stacked = jnp.asarray(results)
    return stacked.reshape(array_cells.frame_shape + stacked.shape[1:])
|
122
|
+
|
123
|
+
|
124
|
+
def apply_dyad(
    verb: Verb[jax.Array], noun_1: Noun[jax.Array], noun_2: Noun[jax.Array]
) -> Noun[jax.Array]:
    """Apply the dyad of ``verb`` with ``noun_1`` as x and ``noun_2`` as y.

    The result array is wrapped as a Noun.
    """
    x_arr, y_arr = noun_1.implementation, noun_2.implementation
    return jax_array_to_noun(_apply_dyad(verb, x_arr, y_arr))
|
129
|
+
|
130
|
+
|
131
|
+
def _apply_dyad(
    verb: Verb[jax.Array], left_arr: jax.Array, right_arr: jax.Array
) -> jax.Array:
    """Apply the dyadic valence of ``verb`` to raw JAX arrays.

    Each argument is split into cells at the verb's left/right rank. The two
    frames must agree: either they are equal, or one is a prefix of the other.
    In the prefix case, each cell of the argument with the shorter frame is
    paired with every cell in the corresponding part of the longer frame, and
    the result has the longer frame (J's frame agreement rule).

    Raises:
        ValenceError: if the verb has no dyadic valence.
        JinxNotImplementedError: if the dyad is ``NotImplemented``, or the
            frame contains no cells.
        LengthError: if neither frame is a prefix of the other.
    """
    if verb.dyad is None or verb.dyad.function is None:
        raise ValenceError(f"Verb {verb.spelling} has no dyadic valence.")
    if verb.dyad.function is NotImplemented:
        raise JinxNotImplementedError(
            f"Verb {verb.spelling} dyad function is not yet implemented."
        )

    if isinstance(verb.dyad.function, Verb):
        function = functools.partial(_apply_dyad, verb.dyad.function)
    else:
        function = verb.dyad.function  # type: ignore[assignment]

    left_rank = get_rank(verb.dyad.left_rank, left_arr.ndim)
    right_rank = get_rank(verb.dyad.right_rank, right_arr.ndim)

    # Fast path: a rank-0 dyad with an atom on one side is applied to the
    # whole arrays at once, relying on the function broadcasting element-wise.
    if left_rank == right_rank == 0 and (left_arr.ndim == 0 or right_arr.ndim == 0):
        return function(left_arr, right_arr)

    left = split_into_cells(left_arr, left_rank)
    right = split_into_cells(right_arr, right_rank)
    left_frame, right_frame = left.frame_shape, right.frame_shape
    left_cells, right_cells = left.cells, right.cells

    if left_frame == right_frame:
        frame_shape = left_frame
    elif right_frame[: len(left_frame)] == left_frame:
        # Left frame is a proper prefix. Cells are stored in row-major frame
        # order, so repeating each left cell once per cell of the right
        # frame's extra axes lines the two sequences up.
        frame_shape = right_frame
        repeats = math.prod(right_frame[len(left_frame) :])
        left_cells = jnp.repeat(left_cells, repeats, axis=0)
    elif left_frame[: len(right_frame)] == right_frame:
        # Mirror image: right frame is a proper prefix of the left frame.
        frame_shape = left_frame
        repeats = math.prod(left_frame[len(right_frame) :])
        right_cells = jnp.repeat(right_cells, repeats, axis=0)
    else:
        raise LengthError(
            f"Verb {verb.spelling}: frames {left_frame} and {right_frame} do not agree."
        )

    if left_cells.shape[0] == 0:
        raise JinxNotImplementedError(
            f"Verb {verb.spelling} applied over an empty frame is not yet implemented."
        )

    results = [
        function(left_cell, right_cell)
        for left_cell, right_cell in zip(left_cells, right_cells, strict=True)
    ]
    # No filling/padding for now: every cell result must have the same shape.
    stacked = jnp.asarray(results)
    return stacked.reshape(frame_shape + stacked.shape[1:])
|
165
|
+
|
166
|
+
|
167
|
+
def apply_adverb(verb_or_noun: Verb | Noun, adverb: Adverb) -> Verb:
    """Apply ``adverb`` to its argument by calling the adverb's function."""
    derive = adverb.function
    return derive(verb_or_noun)
|
169
|
+
|
170
|
+
|
171
|
+
def build_fork(
    f: Verb[jax.Array] | Noun[jax.Array], g: Verb[jax.Array], h: Verb[jax.Array]
) -> Verb[jax.Array]:
    """Build a fork given verbs f, g, h.

      (f g h) y -> (f y) g (h y)
    x (f g h) y -> (x f y) g (x h y)

    When f is the cap ``[:`` the fork is capped: g is applied monadically to
    the result of h. When f is a noun, its value is used directly as the left
    argument of g (one fewer function call).

    The new verb has infinite rank.
    """

    def _is_capped() -> bool:
        return isinstance(f, Verb) and f.spelling == "[:"

    def _monad(y: jax.Array) -> jax.Array:
        if _is_capped():
            return _apply_monad(g, _apply_monad(h, y))
        left = _apply_monad(f, y) if isinstance(f, Verb) else f.implementation
        return _apply_dyad(g, left, _apply_monad(h, y))

    def _dyad(x: jax.Array, y: jax.Array) -> jax.Array:
        if _is_capped():
            return _apply_monad(g, _apply_dyad(h, x, y))
        left = _apply_dyad(f, x, y) if isinstance(f, Verb) else f.implementation
        return _apply_dyad(g, left, _apply_dyad(h, x, y))

    def _parenthesise(text: str) -> str:
        # Multi-word spellings are wrapped so the fork's spelling stays unambiguous.
        return f"({text})" if " " in text else text

    f_spelling = (
        _parenthesise(f.spelling) if isinstance(f, Verb) else str(f.implementation)
    )
    spelling = " ".join((f_spelling, _parenthesise(g.spelling), _parenthesise(h.spelling)))

    monad = Monad(name=spelling, rank=INFINITY, function=_monad)
    dyad = Dyad(
        name=spelling,
        left_rank=INFINITY,
        right_rank=INFINITY,
        function=_dyad,
    )
    return Verb[jax.Array](spelling=spelling, name=spelling, monad=monad, dyad=dyad)
|
@@ -0,0 +1,90 @@
|
|
1
|
+
from typing import Callable
|
2
|
+
|
3
|
+
import jax
|
4
|
+
import jax.numpy as jnp
|
5
|
+
|
6
|
+
# Signature of a monad implementation: one JAX array (y) in, one array out.
MonadT = Callable[[jax.Array], jax.Array]
# Signature of a dyad implementation: two JAX arrays (x, y) in, one array out.
DyadT = Callable[[jax.Array, jax.Array], jax.Array]
|
8
|
+
|
9
|
+
|
10
|
+
def number_monad(y: jax.Array) -> jax.Array:
    """# monad: count number of items in y.

    An atom (0-d array or Python scalar) has one item; otherwise the count is
    the length of the leading axis.

    The result depends only on y's static shape, so this is deliberately not
    wrapped in ``jax.jit``: jitting would only trigger a fresh compilation for
    every distinct input shape without doing any array work.
    """
    if jnp.ndim(y) == 0:
        return jnp.array(1)
    return jnp.array(jnp.shape(y)[0])
|
16
|
+
|
17
|
+
|
18
|
+
@jax.jit
def percentco_dyad(x: jax.Array, y: jax.Array) -> jax.Array:
    """%: dyad: the x-th root of y, computed as y raised to the power 1/x."""
    exponent = 1 / x
    return jnp.power(y, exponent)
|
21
|
+
|
22
|
+
|
23
|
+
@jax.jit
def plusdot_monad(y: jax.Array) -> jax.Array:
    """+. monad: split each number into its real and imaginary parts.

    The two parts are stacked along a new trailing axis of length 2.
    """
    parts = (jnp.real(y), jnp.imag(y))
    return jnp.stack(parts, axis=-1)
|
27
|
+
|
28
|
+
|
29
|
+
@jax.jit
def plusco_monad(y: jax.Array) -> jax.Array:
    """+: monad: double each value of y."""
    return jnp.multiply(2, y)
|
33
|
+
|
34
|
+
|
35
|
+
@jax.jit
def minusdot_monad(y: jax.Array) -> jax.Array:
    """-. monad: complement, i.e. 1 minus each value of y."""
    return jnp.subtract(1, y)
|
39
|
+
|
40
|
+
|
41
|
+
@jax.jit
def minusco_monad(y: jax.Array) -> jax.Array:
    """-: monad: halve each value of y (true division)."""
    return jnp.divide(y, 2)
|
45
|
+
|
46
|
+
|
47
|
+
@jax.jit
def minusco_dyad(x: jax.Array, y: jax.Array) -> jax.Array:
    """-: dyad (match): true when x and y have the same shape and values.

    NaNs in matching positions count as equal.
    """
    return jnp.asarray(jnp.array_equal(x, y, equal_nan=True))
|
52
|
+
|
53
|
+
|
54
|
+
@jax.jit
def hatdot_dyad(x: jax.Array, y: jax.Array) -> jax.Array:
    """^. dyad: base-x logarithm of y, via the change-of-base formula."""
    return jnp.divide(jnp.log(y), jnp.log(x))
|
58
|
+
|
59
|
+
|
60
|
+
@jax.jit
def ltco_monad(y: jax.Array) -> jax.Array:
    """<: monad: decrement, i.e. subtract 1 from each value of y."""
    return jnp.subtract(y, 1)
|
64
|
+
|
65
|
+
|
66
|
+
@jax.jit
def gtco_monad(y: jax.Array) -> jax.Array:
    """>: monad: increment, i.e. add 1 to each value of y."""
    return jnp.add(y, 1)
|
70
|
+
|
71
|
+
|
72
|
+
# Primitive verb name -> (monad implementation, dyad implementation).
# ``NotImplemented`` marks a valence that this JAX backend does not implement
# yet; applying it raises JinxNotImplementedError.
VERB_MAP: dict[str, tuple[MonadT | None, DyadT | None]] = dict(
    PLUS=(jnp.conj, jnp.add),
    PLUSDOT=(plusdot_monad, NotImplemented),
    PLUSCO=(plusco_monad, NotImplemented),
    MINUS=(jnp.negative, jnp.subtract),
    MINUSDOT=(minusdot_monad, NotImplemented),
    MINUSCO=(minusco_monad, minusco_dyad),
    STAR=(jnp.sign, jnp.multiply),
    PERCENT=(jnp.reciprocal, jnp.divide),
    HAT=(jnp.exp, jnp.power),
    HATDOT=(jnp.log, hatdot_dyad),
    LTDOT=(jnp.floor, jnp.minimum),
    GTDOT=(jnp.ceil, jnp.maximum),
    LTCO=(ltco_monad, jnp.less_equal),
    GTCO=(gtco_monad, jnp.greater_equal),
    NUMBER=(number_monad, NotImplemented),
    BAR=(jnp.abs, NotImplemented),
)
|
@@ -0,0 +1,29 @@
|
|
1
|
+
import numpy as np
|
2
|
+
from jinx.execution.executor import Executor
|
3
|
+
from jinx.execution.numpy.adverbs import ADVERB_MAP
|
4
|
+
from jinx.execution.numpy.application import (
|
5
|
+
apply_adverb,
|
6
|
+
apply_conjunction,
|
7
|
+
apply_dyad,
|
8
|
+
apply_monad,
|
9
|
+
build_fork,
|
10
|
+
build_hook,
|
11
|
+
)
|
12
|
+
from jinx.execution.numpy.conjunctions import CONJUNCTION_MAP
|
13
|
+
from jinx.execution.numpy.conversion import ensure_noun_implementation
|
14
|
+
from jinx.execution.numpy.printing import noun_to_string
|
15
|
+
from jinx.execution.numpy.verbs import VERB_MAP
|
16
|
+
|
17
|
+
# NumPy backend: every slot of the Executor is filled by the NumPy
# implementations imported above.
executor = Executor[np.ndarray](
    # Lookup tables for primitive verbs, adverbs and conjunctions.
    primitive_verb_map=VERB_MAP,
    primitive_adverb_map=ADVERB_MAP,
    primitive_conjuction_map=CONJUNCTION_MAP,  # type: ignore[arg-type]
    # Applying verbs, adverbs and conjunctions.
    apply_monad=apply_monad,
    apply_dyad=apply_dyad,
    apply_adverb=apply_adverb,
    apply_conjunction=apply_conjunction,
    # Building forks and hooks.
    build_fork=build_fork,
    build_hook=build_hook,
    # Noun materialisation and printing.
    ensure_noun_implementation=ensure_noun_implementation,
    noun_to_string=noun_to_string,
)
|