arc-agi 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
arc_agi-0.2.0/LICENSE ADDED
@@ -0,0 +1 @@
1
+ test
arc_agi-0.2.0/PKG-INFO ADDED
@@ -0,0 +1,8 @@
1
+ Metadata-Version: 2.1
2
+ Name: arc-agi
3
+ Version: 0.2.0
4
+ Summary: Core package of arc.
5
+ Requires-Python: >=3.13
6
+ Requires-Dist: numpy>=2.2.4
7
+ Description-Content-Type: text/markdown
8
+
File without changes
@@ -0,0 +1,6 @@
1
+ from .arc import Grid, Pair, Task, COLOR
2
+ from .datasets import ARC1, ARC2
3
+
4
+ __all__ = ["Grid", "Pair", "Task", "COLOR", "ARC1", "ARC2"]
5
+
6
# Keep in sync with the `version` field of the package metadata (0.2.0).
__version__ = "0.2.0"
@@ -0,0 +1,356 @@
1
+ from enum import Enum
2
+ from typing import List, Union, Dict, Tuple, Self, Optional, Literal
3
+ import numpy as np
4
+ import json
5
+ from pathlib import Path
6
+ from utils import Layout
7
+ import warnings
8
+
9
+
10
class COLOR(Enum):
    """The ten cell colours of a grid, identified by the integers 0 through 9."""

    # Members take the values 0..9 in declaration order; ZERO is the background.
    (ZERO, ONE, TWO, THREE, FOUR, FIVE, SIX, SEVEN, EIGHT, NINE) = range(10)
21
+
22
+
23
class Grid:
    """A 2D grid of `COLOR` values stored in a numpy array."""

    # ANSI 256-colour background escape sequence used to draw one cell per colour.
    PALETTE: Dict[COLOR, str] = {
        COLOR.ZERO: "\033[48;5;0m \033[0m",
        COLOR.ONE: "\033[48;5;20m \033[0m",
        COLOR.TWO: "\033[48;5;124m \033[0m",
        COLOR.THREE: "\033[48;5;10m \033[0m",
        COLOR.FOUR: "\033[48;5;11m \033[0m",
        COLOR.FIVE: "\033[48;5;7m \033[0m",
        COLOR.SIX: "\033[48;5;5m \033[0m",
        COLOR.SEVEN: "\033[48;5;208m \033[0m",
        COLOR.EIGHT: "\033[48;5;14m \033[0m",
        COLOR.NINE: "\033[48;5;1m \033[0m",
    }

    @classmethod
    def show_palette(cls):
        """Prints the color palette as `<swatch>=<value>` entries."""
        print(
            " | ".join(
                f"{swatch}={member.value}" for member, swatch in cls.PALETTE.items()
            )
        )

    def __init__(self, array: Union[List[List[int]], np.ndarray, None] = None) -> None:
        """
        Initializes the `Grid` with a 2D array of integers.

        Args:
            array (Union[List[List[int]], np.ndarray, None]): A 2D list of int or numpy ndarray representing the grid.
                If None (or empty), a default 1x1 `Grid` of `COLOR.ZERO` is used.

        Raises:
            ValueError: If the input array is not a 2D list or numpy ndarray.
            ValueError: If any element in the array is not a value of `COLOR`.

        Returns:
            None
        """
        # `not array` is ambiguous for multi-element ndarrays (numpy raises),
        # so emptiness is tested per type instead of through truthiness.
        if isinstance(array, np.ndarray):
            is_empty = array.size == 0
        else:
            is_empty = not array
        if is_empty:
            array = [[0]]

        if not isinstance(array, (np.ndarray, list)):
            raise ValueError("Input array must be a 2D list or numpy ndarray.")

        try:
            # `asarray` does not copy an existing ndarray.
            cells = np.asarray(array)
        except ValueError as exc:
            # Raised by numpy for ragged nested lists.
            raise ValueError(
                "Input array must be a 2D list or numpy ndarray."
            ) from exc

        if cells.ndim != 2:
            raise ValueError("Input array must be a 2D list or numpy ndarray.")

        if not all(item in COLOR for row in cells.tolist() for item in row):
            raise ValueError("Array elements must be values of `COLOR`")

        self._array = cells

    @classmethod
    def from_json(cls, file_path: Union[str, Path]) -> Self:
        """
        Creates a `Grid` instance from a JSON file at a given path.

        Args:
            file_path (Union[str, Path]): File path of the JSON file to be loaded.

        Raises:
            ValueError: If the file is not valid JSON or if the decoded content is rejected by the constructor.

        Returns:
            Grid: A `Grid` instance created from the JSON file.
        """
        try:
            with open(file_path) as f:
                array = json.load(f)
        except json.JSONDecodeError as exc:
            raise ValueError("Invalid JSON format.") from exc

        return cls(array)

    @classmethod
    def from_npy(cls, filePath: Union[str, Path]) -> Self:
        """Creates a `Grid` from an array stored in a `.npy` file."""
        return cls(np.load(filePath))

    def save_as_json(self, path: Union[str, Path]) -> None:
        """Writes the grid to `path` as a JSON nested list."""
        with open(path, "w") as f:
            json.dump(self.to_list(), f)

    def save_as_npy(self, path: str | Path) -> None:
        """Writes the grid to `path` with `numpy.save`."""
        np.save(path, self.to_numpy())

    def to_list(self) -> List[List[int]]:
        """Returns the grid as a nested list."""
        return self._array.tolist()

    def to_numpy(self) -> np.ndarray:
        """Returns the underlying numpy array (not a copy)."""
        return self._array

    @property
    def shape(self) -> Tuple[int, int]:
        """Shape of the underlying array."""
        return self.to_numpy().shape

    def __repr__(self) -> str:
        """Renders every cell as its `PALETTE` swatch, one text line per row."""
        return (
            "\n".join(
                "".join(self.PALETTE[COLOR(value)] for value in row)
                for row in self.to_numpy()
            )
            + "\n"
        )

    def __eq__(self, other: object) -> bool:
        """Two grids are equal when they have the same shape and the same cells."""
        if not isinstance(other, Grid):
            # Let Python fall back to its default handling instead of raising,
            # so `grid == None` or `grid in some_list` work as expected.
            return NotImplemented
        # `array_equal` checks the shapes first, so grids of different sizes
        # compare unequal instead of being broadcast against each other.
        return bool(np.array_equal(self.to_numpy(), other.to_numpy()))

    def __sub__(self, other: object) -> int:
        """Number of different pixels"""
        if not isinstance(other, Grid):
            raise NotImplementedError(
                "Cannot compare with non-`Grid` object. "
                "If the object is 2d list or numpy array, try converting it to `Grid` and then compare. "
            )
        if self.shape != other.shape:
            raise ValueError(
                f"Connot compare `Grid`s of different shape. {self.shape} != {other.shape}"
            )
        # Convert the numpy scalar so the result matches the `int` annotation.
        return int(np.sum(self.to_numpy() != other.to_numpy()))
145
+
146
+
147
class Pair:
    """An input/output pair of `Grid`s whose output can be hidden ("censored")."""

    def __init__(
        self,
        input: Union[Grid, List[List[int]]],
        output: Union[Grid, List[List[int]]],
        censor: bool = False,
    ) -> None:
        """
        Args:
            input: Input grid; anything that is not a `Grid` is passed to `Grid(...)`.
            output: Output grid; anything that is not a `Grid` is passed to `Grid(...)`.
            censor: When True, the output starts out censored.
        """
        self._input = input if isinstance(input, Grid) else Grid(input)
        self._output = output if isinstance(output, Grid) else Grid(output)
        self._is_censored = censor

    @property
    def input(self):
        """The input `Grid`."""
        return self._input

    @input.setter
    def input(self, grid: Union[Grid, List[List[int]]]):
        self._input = grid if isinstance(grid, Grid) else Grid(grid)

    @property
    def output(self):
        """The output `Grid`, or None (with a `UserWarning`) while censored."""
        if self._is_censored:
            warnings.warn(
                "Access to `output` is censored. Call `.uncensor()` to gain access. ",
                UserWarning,
            )
            return None
        return self._output

    @output.setter
    def output(self, grid: Union[Grid, List[List[int]]]):
        # While censored the assignment is ignored and a warning is emitted.
        if self._is_censored:
            warnings.warn(
                "Access to `output` is censored. Call `.uncensor()` to gain access. ",
                UserWarning,
            )
            return
        self._output = grid if isinstance(grid, Grid) else Grid(grid)

    def censor(self):
        """Hides the output until `uncensor()` is called."""
        self._is_censored = True

    def uncensor(self):
        """Makes the output accessible again."""
        self._is_censored = False

    def __repr__(self):
        """Renders `INPUT -> OUTPUT`, with a placeholder for a censored output."""
        # Check the flag directly: going through the `output` property would
        # emit censor warnings just for printing the placeholder.
        shown_output = "*CENSORED*" if self._is_censored else self._output
        return repr(
            Layout(
                Layout(
                    "INPUT",
                    self._input,
                    direction="vertical",
                    align="center",
                ),
                "->",
                Layout(
                    "OUTPUT",
                    shown_output,
                    direction="vertical",
                    align="center",
                ),
                align="center",
            )
        )

    def to_dict(self):
        """
        Returns the pair as `{"input": ..., "output": ...}` with nested lists.

        A censored output is serialised as None rather than crashing on the
        missing grid; the hidden data is not exposed.
        """
        output = None if self._is_censored else self._output.to_list()
        return {"input": self._input.to_list(), "output": output}
216
+
217
+
218
class Task:
    """A task made of train pairs and test pairs, with an optional id."""

    def __init__(
        self,
        train: Union[
            List[Pair],
            List[Tuple[List[List[int]], List[List[int]]]],
            List[Tuple[Grid, Grid]],
        ],
        test: Union[
            List[Pair],
            List[Tuple[List[List[int]], List[List[int]]]],
            List[Tuple[Grid, Grid]],
        ],
        task_id: Optional[str] = None,
    ):
        """
        Args:
            train: Train pairs; items that are not `Pair`s are unpacked into `Pair(*item)`.
            test: Test pairs; items that are not `Pair`s are unpacked into `Pair(*item)`.
            task_id: Optional identifier of the task.
        """
        self.train = [pair if isinstance(pair, Pair) else Pair(*pair) for pair in train]
        self.test = [pair if isinstance(pair, Pair) else Pair(*pair) for pair in test]
        self.task_id = task_id

    @classmethod
    def from_dict(
        cls,
        task_dict: Dict[
            Literal["train", "test"],
            List[Dict[Literal["input", "output"], List[List[int]]]],
        ],
        task_id: Optional[str] = None,
    ):
        """Builds a `Task` from a dict with "train" and "test" lists of input/output dicts."""
        train = [Pair(pair["input"], pair["output"]) for pair in task_dict["train"]]
        test = [Pair(pair["input"], pair["output"]) for pair in task_dict["test"]]

        return cls(train, test, task_id)

    def to_dict(self):
        """Returns `{"train": [...], "test": [...]}` built from each pair's `to_dict()`."""
        return {
            "train": [pair.to_dict() for pair in self.train],
            "test": [pair.to_dict() for pair in self.test],
        }

    @classmethod
    def from_json(cls, file_path: Union[str, Path]):
        """
        Loads a task from a JSON file; the file stem becomes the task id.

        Raises:
            RuntimeError: If the file cannot be opened or decoded as JSON.
        """
        file_path = file_path if isinstance(file_path, Path) else Path(file_path)
        task_id = file_path.stem

        try:
            with file_path.open() as f:
                task = json.load(f)
            # TODO: validate schema
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and text decoding errors.
            raise RuntimeError(
                f"Failed to load and parse task json file at '{file_path}': {e}"
            ) from e

        return cls.from_dict(task, task_id)

    @property
    def inputs(self):
        """Inputs of all train pairs followed by all test pairs."""
        return [pair.input for pair in self.train + self.test]

    @property
    def outputs(self):
        """Outputs of all train pairs followed by all test pairs."""
        return [pair.output for pair in self.train + self.test]

    @staticmethod
    def _output_or_placeholder(pair):
        """Returns the pair's output, or "*CENSORED*" when it reads as None."""
        # Rendering already shows a placeholder, so the warning a censored pair
        # emits on access is suppressed here.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            output = pair.output
        return output if output is not None else "*CENSORED*"

    @classmethod
    def _render_pairs(cls, pairs):
        """Stacks one `INPUT i -> OUTPUT i` row per pair into a vertical `Layout`."""
        return Layout(
            *[
                Layout(
                    Layout(
                        f"INPUT {i}",
                        pair.input,
                        direction="vertical",
                        align="center",
                    ),
                    " -> ",
                    Layout(
                        f"OUTPUT {i}",
                        cls._output_or_placeholder(pair),
                        direction="vertical",
                        align="center",
                    ),
                )
                for i, pair in enumerate(pairs)
            ],
            direction="vertical",
        )

    def __repr__(self):
        """Renders a title line followed by the train and test sections."""
        train = self._render_pairs(self.train)
        test = self._render_pairs(self.test)
        # Headers are centred over the wider of the two sections.
        width = max(train.width, test.width)
        return repr(
            Layout(
                f"< Task{' ' + self.task_id if self.task_id else ''} >".center(
                    width, "="
                ),
                " Train ".center(width, "-"),
                train,
                " Test ".center(width, "-"),
                test,
                direction="vertical",
            )
        )

    def __str__(self):
        return repr(self)

    def censor_outputs(self):
        """Censors the outputs of every train and test pair."""
        for pair in self.train + self.test:
            pair.censor()

    def uncensor_outputs(self):
        """Uncensors the outputs of every train and test pair."""
        for pair in self.train + self.test:
            pair.uncensor()

    def censor_test_outputs(self):
        """Censors the outputs of the test pairs only."""
        for pair in self.test:
            pair.censor()

    def uncensor_test_outputs(self):
        """Uncensors the outputs of the test pairs only."""
        for pair in self.test:
            pair.uncensor()
@@ -0,0 +1,69 @@
1
+ from typing import Union, List, Dict
2
+ from pathlib import Path
3
+ from utils import download_from_github
4
+ import arc
5
+
6
+
7
class ARC1:
    """A collection of tasks loaded from the `*.json` files of a directory."""

    def __init__(
        self, dataset_path: Union[str, Path], train: bool = True, download: bool = True
    ):
        """
        Args:
            dataset_path: Directory holding (or receiving) the task JSON files.
            train: Selects the "training" split when True, "evaluation" otherwise.
            download: When True, `download()` is called before loading.
        """
        self._tasks: List[arc.Task] = list()
        # Maps a task id to the position of that task in `_tasks`.
        self._tasks_map: Dict[str, int] = dict()
        self._dataset_path = (
            dataset_path if isinstance(dataset_path, Path) else Path(dataset_path)
        )
        self._train = train

        if download:
            self.download()

        self.load()

    def load(self):
        """
        Loads every `*.json` file of the dataset directory as a task.

        Raises:
            FileNotFoundError: If the dataset path does not exist.
            NotADirectoryError: If the dataset path is not a directory.
        """
        if not self._dataset_path.exists():
            raise FileNotFoundError(
                f"Dataset path '{self._dataset_path}' does not exist. "
            )
        if not self._dataset_path.is_dir():
            raise NotADirectoryError(
                f"Dataset path '{self._dataset_path}' is not a directory. "
            )

        # Start from a clean state so that calling `load()` again does not
        # duplicate tasks.
        self._tasks = []
        self._tasks_map = {}

        # Sorted so that integer indexing is stable across runs and platforms.
        for file_path in sorted(self._dataset_path.glob("*.json")):
            task = arc.Task.from_json(file_path)
            self._tasks_map[task.task_id] = len(self._tasks)
            self._tasks.append(task)

    def download(self):
        """Downloads the selected split of fchollet/ARC-AGI into the dataset path."""
        download_from_github(
            "fchollet",
            "ARC-AGI",
            f"data/{'training' if self._train else 'evaluation'}",
            "main",
            self._dataset_path,
        )

    def __contains__(self, task_id: str) -> bool:
        return task_id in self._tasks_map

    def get(self, task_id: str) -> arc.Task:
        """
        Returns the task with the given id.

        Raises:
            KeyError: If no loaded task has this id.
        """
        index = self._tasks_map.get(task_id)
        if index is None:
            raise KeyError(f"Task with task id: {task_id} is not in this dataset. ")
        return self._tasks[index]

    def __getitem__(self, i: int) -> arc.Task:
        return self._tasks[i]

    def __len__(self) -> int:
        return len(self._tasks)
59
+
60
+
61
class ARC2(ARC1):
    """Same as `ARC1`, but downloads from the arcprize/ARC-AGI-2 repository."""

    def download(self):
        """Downloads the selected split of arcprize/ARC-AGI-2 into the dataset path."""
        split = "training" if self._train else "evaluation"
        download_from_github(
            "arcprize", "ARC-AGI-2", f"data/{split}", "main", self._dataset_path
        )
@@ -0,0 +1,6 @@
1
def main():
    """Print the greeting of the core package."""
    greeting = "Hello from core!"
    print(greeting)


# Only run when executed as a script, not on import.
if __name__ == "__main__":
    main()
@@ -0,0 +1,20 @@
1
+ [build-system]
2
+ requires = ["pdm-backend"]
3
+ build-backend = "pdm.backend"
4
+
5
+ [project]
6
+ name = "arc-core"
7
+ version = "0.2.0"
8
+ description = "Core package of arc. "
9
+ readme = "README.md"
10
+ requires-python = ">=3.13"
11
+ dependencies = [
12
+ "numpy>=2.2.4",
13
+ ]
14
+
15
+
16
+ [[tool.uv.index]]
17
+ name = "testpypi"
18
+ url = "https://test.pypi.org/simple/"
19
+ publish-url = "https://test.pypi.org/legacy/"
20
+ explicit = true
@@ -0,0 +1,187 @@
1
+ import re
2
+ from typing import List, Literal, Any
3
+ import os
4
+ import requests
5
+ import zipfile
6
+ import io
7
+
8
+
9
# Matches ANSI SGR sequences: ESC "[" digits/semicolons "m".
ANSI_ESCAPE_PATTERN = re.compile(r"\x1b\[[0-9;]*m")


def strip_ansi(text: str) -> str:
    """Return `text` with every ANSI escape sequence deleted."""
    plain = re.sub(ANSI_ESCAPE_PATTERN, "", text)
    return plain
15
+
16
+
17
def ansi_width(text: str) -> int:
    """Count the characters of `text` that remain once ANSI escapes are stripped."""
    visible = strip_ansi(text)
    return len(visible)
20
+
21
+
22
def align_lines(
    lines: List[str],
    target_width: int,
    align: Literal["start", "center", "end"],
) -> List[str]:
    """
    Pads each line with spaces up to `target_width`, preserving ANSI codes.

    Widths are measured without ANSI escape sequences. Lines that are already
    at least `target_width` wide are returned unchanged.

    :param lines: Lines to pad.
    :param target_width: Visible width every line should reach.
    :param align: "start" pads on the right, "end" pads on the left, "center"
        splits the padding and puts the extra space (if any) on the right.
    :raises ValueError: If `align` is not one of the three supported values.
    """
    # Fail loudly instead of silently returning None for an unknown mode.
    if align not in ("start", "center", "end"):
        raise ValueError(
            f"Unknown align {align!r}; expected 'start', 'center' or 'end'."
        )

    aligned = []
    for line in lines:
        # Measure once per line; the padding is never negative.
        padding = max(target_width - ansi_width(line), 0)
        if align == "start":
            aligned.append(line + " " * padding)
        elif align == "end":
            aligned.append(" " * padding + line)
        else:
            aligned.append(" " * (padding // 2) + line + " " * ((padding + 1) // 2))
    return aligned
41
+
42
+
43
class Layout:
    """Arranges text blocks side by side or stacked, aware of ANSI escapes."""

    def __init__(
        self,
        *elements: Any,
        direction: Literal["horizontal", "vertical"] = "horizontal",
        align: Literal["start", "center", "end"] = "start",
        show_divider: bool = False,
        min_width: int = 0,
    ):
        """
        :param elements: Items to lay out; each is converted to text when rendering.
        :param direction: "horizontal" places elements side by side, "vertical" stacks them.
        :param align: Alignment passed to `align_lines` for every element.
        :param show_divider: Insert " | " between columns, or a dashed line between stacked elements.
        :param min_width: Minimum width of each rendered row in horizontal mode.
        """
        self.elements = elements
        self.direction = direction
        self.align = align
        self.show_divider = show_divider
        self.min_width = min_width

    @property
    def width(self):
        """Visible width of the widest rendered line (0 when nothing is rendered)."""
        return max((ansi_width(line) for line in repr(self).splitlines()), default=0)

    @property
    def height(self):
        """Number of rendered lines."""
        return len(repr(self).splitlines())

    def _render(self, to_text) -> str:
        """Renders the layout, converting each element to text with `to_text`."""
        blocks = [to_text(element).splitlines() for element in self.elements]

        if self.direction == "horizontal":
            # `default=0` keeps empty layouts and empty strings from raising.
            max_height = max((len(block) for block in blocks), default=0)
            # Shorter elements get blank lines so every column has the same height.
            padded = [block + [""] * (max_height - len(block)) for block in blocks]
            widths = [
                max((ansi_width(line) for line in block), default=0)
                for block in blocks
            ]
            columns = [
                align_lines(block, width, self.align)
                for block, width in zip(padded, widths)
            ]
            divider = " | " if self.show_divider else ""
            rows = [
                divider.join(column[row] for column in columns)
                for row in range(max_height)
            ]
            return "\n".join(align_lines(rows, self.min_width, self.align))

        if self.direction == "vertical":
            max_width = max(
                (ansi_width(line) for block in blocks for line in block), default=0
            )
            stacked = [align_lines(block, max_width, self.align) for block in blocks]
            divider = "\n" + "-" * max_width if self.show_divider else ""
            return f"{divider}\n".join("\n".join(block) for block in stacked)

        raise ValueError(
            f"Unknown direction {self.direction!r}; expected 'horizontal' or 'vertical'."
        )

    def __repr__(self) -> str:
        # Strings are used verbatim; every other element is rendered with repr().
        return self._render(
            lambda element: element if isinstance(element, str) else repr(element)
        )

    def __str__(self) -> str:
        # Same layout as __repr__, but elements are rendered with str().
        return self._render(str)
141
+
142
+
143
def download_from_github(
    repo_owner,
    repo_name,
    path,
    branch="main",
    destination="./downloaded",
    timeout=60,
):
    """
    Downloads and extracts a specific path from a GitHub repository.

    The whole branch archive is fetched into memory and only the files below
    `path` are written to `destination`. Failures of the HTTP status check are
    reported with a printed message, not an exception.

    :param repo_owner: GitHub username or organization.
    :param repo_name: Repository name.
    :param path: Path to the folder inside the repo to extract.
    :param branch: Repo branch (default: "main").
    :param destination: Local destination folder for saving files.
    :param timeout: Seconds passed to `requests.get` as its timeout, so a
        stalled connection cannot hang forever.
    """
    url = f"https://github.com/{repo_owner}/{repo_name}/archive/refs/heads/{branch}.zip"
    print(f"Downloading from {url}")
    # The body is read in full below, so streaming would bring no benefit.
    response = requests.get(url, timeout=timeout)

    if response.status_code != 200:
        print(f"❌ Failed to download: HTTP {response.status_code}")
        return

    destination_root = os.path.realpath(destination)
    # The archive wraps everything in a "<repo>-<branch>/" top-level folder.
    target_folder = f"{repo_name}-{branch}/{path}/"
    extracted_files = 0

    with zipfile.ZipFile(io.BytesIO(response.content)) as zip_file:
        for file in zip_file.namelist():
            # Skip entries outside the requested folder and directory entries.
            if not file.startswith(target_folder) or file.endswith("/"):
                continue

            relative_path = file[len(target_folder) :]
            save_path = os.path.realpath(os.path.join(destination_root, relative_path))

            # Never write outside the destination, even if an archive entry
            # contains ".." components (zip-slip).
            if os.path.commonpath([destination_root, save_path]) != destination_root:
                continue

            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            with zip_file.open(file) as source, open(save_path, "wb") as target:
                target.write(source.read())

            extracted_files += 1

    if extracted_files > 0:
        print(
            f"✅ Successfully downloaded '{path}' from {repo_owner}/{repo_name} into '{destination}'."
        )
    else:
        print(
            f"⚠️ No files extracted. Check if the path '{path}' exists in the repository."
        )
@@ -0,0 +1,21 @@
1
+ [build-system]
2
+ requires = [
3
+ "pdm-backend",
4
+ ]
5
+ build-backend = "pdm.backend"
6
+
7
+ [project]
8
+ name = "arc-agi"
9
+ version = "0.2.0"
10
+ description = "Core package of arc. "
11
+ readme = "README.md"
12
+ requires-python = ">=3.13"
13
+ dependencies = [
14
+ "numpy>=2.2.4",
15
+ ]
16
+
17
+ [[tool.uv.index]]
18
+ name = "testpypi"
19
+ url = "https://test.pypi.org/simple/"
20
+ publish-url = "https://test.pypi.org/legacy/"
21
+ explicit = true
@@ -0,0 +1 @@
1
+ from .test import A
@@ -0,0 +1,20 @@
1
+ [build-system]
2
+ requires = ["pdm-backend"]
3
+ build-backend = "pdm.backend"
4
+
5
+ [project]
6
+ name = "arc-torch"
7
+ version = "0.2.0"
8
+ description = "arc torch"
9
+ readme = "README.md"
10
+ requires-python = ">=3.13"
11
+ dependencies = [
12
+ "numpy>=2.2.4",
13
+ ]
14
+
15
+
16
+ [[tool.uv.index]]
17
+ name = "testpypi"
18
+ url = "https://test.pypi.org/simple/"
19
+ publish-url = "https://test.pypi.org/legacy/"
20
+ explicit = true
@@ -0,0 +1,2 @@
1
class A:
    """Placeholder class exposing a single class attribute."""

    # Class-level attribute shared by all instances.
    x = 1