arc-agi 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arc_agi-0.2.0.dist-info/METADATA +8 -0
- arc_agi-0.2.0.dist-info/RECORD +14 -0
- arc_agi-0.2.0.dist-info/WHEEL +4 -0
- arc_agi-0.2.0.dist-info/entry_points.txt +4 -0
- arc_agi-0.2.0.dist-info/licenses/LICENSE +1 -0
- core/__init__.py +6 -0
- core/arc.py +356 -0
- core/datasets.py +69 -0
- core/main.py +6 -0
- core/pyproject.toml +20 -0
- core/utils.py +187 -0
- torch/__init__.py +1 -0
- torch/pyproject.toml +20 -0
- torch/test.py +2 -0
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
arc_agi-0.2.0.dist-info/METADATA,sha256=MnrKuhhtkitJ3iZYufMqBKeQtzxIc30Bi8pq3GqkS7U,175
|
|
2
|
+
arc_agi-0.2.0.dist-info/WHEEL,sha256=thaaA2w1JzcGC48WYufAs8nrYZjJm8LqNfnXFOFyCC4,90
|
|
3
|
+
arc_agi-0.2.0.dist-info/entry_points.txt,sha256=6OYgBcLyFCUgeqLgnvMyOJxPCWzgy7se4rLPKtNonMs,34
|
|
4
|
+
arc_agi-0.2.0.dist-info/licenses/LICENSE,sha256=n4bQgYhMfWWaL-qgxVrQFaO_TxsrC4Is0V1sFbDwCgg,4
|
|
5
|
+
core/__init__.py,sha256=TWfvvMoRaBDHWrxwkMJqrgYGdydL-tZykY18gq4UTlA,158
|
|
6
|
+
core/arc.py,sha256=PH3miMhi5891_guZ_iyFbepPjlgB0sOVEaQEJXi5_wo,10835
|
|
7
|
+
core/datasets.py,sha256=_vmhrk5umeDpHhDOBL2_SMcgZcZKcNztBj1xDfNjJZQ,2048
|
|
8
|
+
core/main.py,sha256=Wlb2wVYp0PRS9thbjKpSKIuIgELezk4rYOHS9rvG7B0,82
|
|
9
|
+
core/pyproject.toml,sha256=prRYKq3EuvB78XU5QXMdtqe0wpPV0jYQyIuY14BcMWg,381
|
|
10
|
+
core/utils.py,sha256=g-x51fPI5uCTQaG1UiHkYEXFOZ8uqC8QCVfztIKT8JE,6683
|
|
11
|
+
torch/__init__.py,sha256=tpkf7K2Ug708wjW-MWaHPKQ-QeOAgoJaqsH0rE7lqxs,19
|
|
12
|
+
torch/pyproject.toml,sha256=ZUZw1Xynn9wIeg3IlUXVaK-U9O-CJhiPIBc6I9LfL38,370
|
|
13
|
+
torch/test.py,sha256=Pt7adtoQ7FjpOUebvZ-sswRzJfv7e6Y8gj92MfHchDE,18
|
|
14
|
+
arc_agi-0.2.0.dist-info/RECORD,,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
test
|
core/__init__.py
ADDED
core/arc.py
ADDED
|
@@ -0,0 +1,356 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
from typing import List, Union, Dict, Tuple, Self, Optional, Literal
|
|
3
|
+
import numpy as np
|
|
4
|
+
import json
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from utils import Layout
|
|
7
|
+
import warnings
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class COLOR(Enum):
|
|
11
|
+
ZERO = 0 # Background
|
|
12
|
+
ONE = 1
|
|
13
|
+
TWO = 2
|
|
14
|
+
THREE = 3
|
|
15
|
+
FOUR = 4
|
|
16
|
+
FIVE = 5
|
|
17
|
+
SIX = 6
|
|
18
|
+
SEVEN = 7
|
|
19
|
+
EIGHT = 8
|
|
20
|
+
NINE = 9
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class Grid:
|
|
24
|
+
PALETTE: Dict[COLOR, str] = {
|
|
25
|
+
COLOR.ZERO: "\033[48;5;0m \033[0m",
|
|
26
|
+
COLOR.ONE: "\033[48;5;20m \033[0m",
|
|
27
|
+
COLOR.TWO: "\033[48;5;124m \033[0m",
|
|
28
|
+
COLOR.THREE: "\033[48;5;10m \033[0m",
|
|
29
|
+
COLOR.FOUR: "\033[48;5;11m \033[0m",
|
|
30
|
+
COLOR.FIVE: "\033[48;5;7m \033[0m",
|
|
31
|
+
COLOR.SIX: "\033[48;5;5m \033[0m",
|
|
32
|
+
COLOR.SEVEN: "\033[48;5;208m \033[0m",
|
|
33
|
+
COLOR.EIGHT: "\033[48;5;14m \033[0m",
|
|
34
|
+
COLOR.NINE: "\033[48;5;1m \033[0m",
|
|
35
|
+
}
|
|
36
|
+
|
|
37
|
+
@classmethod
|
|
38
|
+
def show_palette(cls):
|
|
39
|
+
"""Prints the color palette."""
|
|
40
|
+
print(
|
|
41
|
+
" | ".join(
|
|
42
|
+
f"{color}={symbol.value}" for symbol, color in cls.PALETTE.items()
|
|
43
|
+
)
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
def __init__(self, array: Union[List[List[int]], np.ndarray, None] = None) -> None:
|
|
47
|
+
"""
|
|
48
|
+
Initializes the `Grid` with a 2D array of integers.
|
|
49
|
+
|
|
50
|
+
Args:
|
|
51
|
+
array (Union[List[List[int]], np.ndarray, None]): A 2D list of int or numpy ndarray representing the grid.
|
|
52
|
+
If None, a default 1x1 `Grid` of `COLOR.ZERO` is used.
|
|
53
|
+
|
|
54
|
+
Raises:
|
|
55
|
+
ValueError: If the input array is not a 2D list or numpy ndarray of integers.
|
|
56
|
+
ValueError: If any element in the array is not a value of `COLOR` (values of `COLOR` is 0~9 integers by default if `COLOR` is not modified. )
|
|
57
|
+
|
|
58
|
+
Returns:
|
|
59
|
+
None
|
|
60
|
+
|
|
61
|
+
"""
|
|
62
|
+
if not array:
|
|
63
|
+
array = [[0]]
|
|
64
|
+
|
|
65
|
+
if not isinstance(array, (np.ndarray, list)):
|
|
66
|
+
raise ValueError("Input array must be a 2D list or numpy ndarray.")
|
|
67
|
+
|
|
68
|
+
if not all(item in COLOR for row in array for item in row):
|
|
69
|
+
raise ValueError("Array elements must be values of `COLOR`")
|
|
70
|
+
|
|
71
|
+
self._array = array if isinstance(array, np.ndarray) else np.array(array)
|
|
72
|
+
|
|
73
|
+
@classmethod
|
|
74
|
+
def from_json(cls, file_path: Union[str, Path]) -> Self:
|
|
75
|
+
"""
|
|
76
|
+
Creates a `Grid` instance from a JSON file at a given path.
|
|
77
|
+
|
|
78
|
+
Args:
|
|
79
|
+
file_path (Union[str, Path]): File path of the JSON file to be loaded.
|
|
80
|
+
|
|
81
|
+
Raises:
|
|
82
|
+
ValueError: If the string format is incorrect or if any element is not a valid `COLOR` value.
|
|
83
|
+
|
|
84
|
+
Returns:
|
|
85
|
+
Grid: A `Grid` instance created from the JSON file.
|
|
86
|
+
"""
|
|
87
|
+
try:
|
|
88
|
+
with open(file_path) as f:
|
|
89
|
+
array = json.load(f)
|
|
90
|
+
except json.JSONDecodeError:
|
|
91
|
+
raise ValueError("Invalid JSON format.")
|
|
92
|
+
|
|
93
|
+
return cls(array)
|
|
94
|
+
|
|
95
|
+
@classmethod
|
|
96
|
+
def from_npy(cls, filePath: Union[str, Path]) -> Self:
|
|
97
|
+
return cls(np.load(filePath))
|
|
98
|
+
|
|
99
|
+
def save_as_json(self, path: Union[str, Path]) -> None:
|
|
100
|
+
with open(path, "w") as f:
|
|
101
|
+
json.dump(self.to_list(), f)
|
|
102
|
+
|
|
103
|
+
def save_as_npy(self, path: str | Path) -> None:
|
|
104
|
+
np.save(path, self.to_numpy())
|
|
105
|
+
|
|
106
|
+
def to_list(self) -> List[List[int]]:
|
|
107
|
+
return self._array.tolist()
|
|
108
|
+
|
|
109
|
+
def to_numpy(self) -> np.ndarray:
|
|
110
|
+
return self._array
|
|
111
|
+
|
|
112
|
+
@property
|
|
113
|
+
def shape(self) -> Tuple[int, int]:
|
|
114
|
+
return self.to_numpy().shape
|
|
115
|
+
|
|
116
|
+
def __repr__(self) -> str:
|
|
117
|
+
return (
|
|
118
|
+
"\n".join(
|
|
119
|
+
"".join(self.PALETTE[COLOR(value)] for value in row)
|
|
120
|
+
for row in self.to_numpy()
|
|
121
|
+
)
|
|
122
|
+
+ "\n"
|
|
123
|
+
)
|
|
124
|
+
|
|
125
|
+
def __eq__(self, other: object) -> bool:
|
|
126
|
+
if not isinstance(other, Grid):
|
|
127
|
+
raise ValueError(
|
|
128
|
+
"Cannot compare with non-`Grid` object. "
|
|
129
|
+
"If the object is 2d list or numpy array, try converting it to `Grid` and then compare. "
|
|
130
|
+
)
|
|
131
|
+
return bool((self.to_numpy() == other.to_numpy()).all())
|
|
132
|
+
|
|
133
|
+
def __sub__(self, other: object) -> int:
|
|
134
|
+
"""Number of different pixels"""
|
|
135
|
+
if not isinstance(other, Grid):
|
|
136
|
+
raise NotImplementedError(
|
|
137
|
+
"Cannot compare with non-`Grid` object. "
|
|
138
|
+
"If the object is 2d list or numpy array, try converting it to `Grid` and then compare. "
|
|
139
|
+
)
|
|
140
|
+
if self.shape != other.shape:
|
|
141
|
+
raise ValueError(
|
|
142
|
+
f"Connot compare `Grid`s of different shape. {self.shape} != {other.shape}"
|
|
143
|
+
)
|
|
144
|
+
return np.sum(self.to_numpy() != other.to_numpy())
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
class Pair:
    """An input/output pair of `Grid`s whose output can be hidden ("censored").

    While censored, reading `output` emits a `UserWarning` and returns None, and
    assigning `output` emits the same warning and leaves the stored grid unchanged.
    """

    def __init__(
        self,
        input: Union[Grid, List[List[int]]],
        output: Union[Grid, List[List[int]]],
        censor: bool = False,
    ) -> None:
        """
        Args:
            input: Input grid, as a `Grid` or a 2D list of ints (converted to `Grid`).
            output: Output grid, as a `Grid` or a 2D list of ints (converted to `Grid`).
            censor: If True, the pair starts with its output hidden.
        """
        self._input = input if isinstance(input, Grid) else Grid(input)
        self._output = output if isinstance(output, Grid) else Grid(output)
        self._is_censored = censor

    @staticmethod
    def _warn_censored() -> None:
        """Emit the warning raised when a censored output is read or assigned."""
        # stacklevel=3 attributes the warning to the caller of the property,
        # not to this helper or the property body.
        warnings.warn(
            "Access to `output` is censored. Call `.uncensor()` to gain access. ",
            UserWarning,
            stacklevel=3,
        )

    @property
    def input(self) -> Grid:
        """The input grid."""
        return self._input

    @input.setter
    def input(self, grid: Union[Grid, List[List[int]]]):
        self._input = grid if isinstance(grid, Grid) else Grid(grid)

    @property
    def output(self) -> Optional[Grid]:
        """The output grid, or None (with a `UserWarning`) while censored."""
        if self._is_censored:
            self._warn_censored()
            return None
        return self._output

    @output.setter
    def output(self, grid: Union[Grid, List[List[int]]]):
        if self._is_censored:
            self._warn_censored()
            return
        self._output = grid if isinstance(grid, Grid) else Grid(grid)

    def censor(self):
        """Hide the output."""
        self._is_censored = True

    def uncensor(self):
        """Restore access to the output."""
        self._is_censored = False

    def __repr__(self):
        # Read the private field directly so rendering a censored pair does not
        # emit the censorship warning.
        output = "*CENSORED*" if self._is_censored else self._output
        return repr(
            Layout(
                Layout(
                    "INPUT",
                    self.input,
                    direction="vertical",
                    align="center",
                ),
                "->",
                Layout(
                    "OUTPUT",
                    output,
                    direction="vertical",
                    align="center",
                ),
                align="center",
            )
        )

    def to_dict(self):
        """Serialize to `{"input": [[...]], "output": [[...]]}`.

        A censored pair omits the "output" key so the hidden grid is not leaked.
        """
        data = {"input": self.input.to_list()}
        if not self._is_censored:
            data["output"] = self._output.to_list()
        return data
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
class Task:
    """An ARC task: training `Pair`s, test `Pair`s and an optional task id."""

    def __init__(
        self,
        train: Union[
            List[Pair],
            List[Tuple[List[List[int]], List[List[int]]]],
            List[Tuple[Grid, Grid]],
        ],
        test: Union[
            List[Pair],
            List[Tuple[List[List[int]], List[List[int]]]],
            List[Tuple[Grid, Grid]],
        ],
        task_id: Optional[str] = None,
    ):
        """
        Args:
            train: Training pairs, as `Pair`s or (input, output) tuples.
            test: Test pairs, as `Pair`s or (input, output) tuples.
            task_id: Optional identifier shown in the rendered title.
        """
        self.train = [pair if isinstance(pair, Pair) else Pair(*pair) for pair in train]
        self.test = [pair if isinstance(pair, Pair) else Pair(*pair) for pair in test]
        self.task_id = task_id

    @classmethod
    def from_dict(
        cls,
        task_dict: Dict[
            Literal["train", "test"],
            List[Dict[Literal["input", "output"], List[List[int]]]],
        ],
        task_id: Optional[str] = None,
    ):
        """Build a `Task` from `{"train": [...], "test": [...]}` of input/output dicts."""
        train = [Pair(pair["input"], pair["output"]) for pair in task_dict["train"]]
        test = [Pair(pair["input"], pair["output"]) for pair in task_dict["test"]]

        return cls(train, test, task_id)

    def to_dict(self):
        """Serialize to `{"train": [...], "test": [...]}` using `Pair.to_dict`."""
        return {
            "train": [pair.to_dict() for pair in self.train],
            "test": [pair.to_dict() for pair in self.test],
        }

    @classmethod
    def from_json(cls, file_path: Union[str, Path]):
        """Load a task JSON file; the file name stem becomes the task id.

        Raises:
            RuntimeError: If the file cannot be read or is not valid JSON.
        """
        file_path = Path(file_path)

        try:
            with file_path.open(encoding="utf-8") as f:
                task = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers JSONDecodeError and UnicodeDecodeError.
            raise RuntimeError(
                f"Failed to load and parse task json file at '{file_path}': {e}"
            ) from e
        # TODO: validate schema

        return cls.from_dict(task, file_path.stem)

    @property
    def inputs(self):
        """Input grids of all train pairs followed by all test pairs."""
        return [pair.input for pair in self.train + self.test]

    @property
    def outputs(self):
        """Output grids of all train then test pairs (censored pairs yield None)."""
        return [pair.output for pair in self.train + self.test]

    @staticmethod
    def _pairs_layout(pairs: List[Pair]) -> Layout:
        """Stack one "INPUT i -> OUTPUT i" row per pair vertically."""
        rows = []
        for i, pair in enumerate(pairs):
            # Rendering should not emit the censorship warning; a censored
            # output reads as None and is drawn as a placeholder.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", UserWarning)
                output = pair.output
            rows.append(
                Layout(
                    Layout(
                        f"INPUT {i}",
                        pair.input,
                        direction="vertical",
                        align="center",
                    ),
                    " -> ",
                    Layout(
                        f"OUTPUT {i}",
                        output if output is not None else "*CENSORED*",
                        direction="vertical",
                        align="center",
                    ),
                )
            )
        return Layout(*rows, direction="vertical")

    def __repr__(self):
        train = self._pairs_layout(self.train)
        test = self._pairs_layout(self.test)
        # Title and section rules are padded to the wider of the two sections.
        width = max(train.width, test.width)
        return repr(
            Layout(
                f"< Task{' ' + self.task_id if self.task_id else ''} >".center(
                    width, "="
                ),
                " Train ".center(width, "-"),
                train,
                " Test ".center(width, "-"),
                test,
                direction="vertical",
            )
        )

    def __str__(self):
        return repr(self)

    def censor_outputs(self):
        """Censor the outputs of every train and test pair."""
        for pair in self.train + self.test:
            pair.censor()

    def uncensor_outputs(self):
        """Uncensor the outputs of every train and test pair."""
        for pair in self.train + self.test:
            pair.uncensor()

    def censor_test_outputs(self):
        """Censor only the test pairs' outputs."""
        for pair in self.test:
            pair.censor()

    def uncensor_test_outputs(self):
        """Uncensor only the test pairs' outputs."""
        for pair in self.test:
            pair.uncensor()
|
core/datasets.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
from typing import Union, List, Dict
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from utils import download_from_github
|
|
4
|
+
import arc
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ARC1:
    """ARC-AGI-1 dataset: `arc.Task`s loaded from a directory of task JSON files.

    Args:
        dataset_path: Directory holding the task `*.json` files (also the download target).
        train: Use the "training" split if True, otherwise "evaluation" (used by `download`).
        download: If True, fetch the split from GitHub before loading.
    """

    def __init__(
        self, dataset_path: Union[str, Path], train: bool = True, download: bool = True
    ):
        self._tasks: List[arc.Task] = list()
        self._tasks_map: Dict[str, int] = dict()  # task_id -> index into _tasks
        self._dataset_path = Path(dataset_path)
        self._train = train

        if download:
            self.download()

        self.load()

    @property
    def dataset_path(self) -> Path:
        """Directory the task JSON files are loaded from."""
        return self._dataset_path

    def load(self):
        """(Re)load every `*.json` task file in `dataset_path`, ordered by file name.

        Raises:
            FileNotFoundError: If `dataset_path` does not exist.
            NotADirectoryError: If `dataset_path` is not a directory.
        """
        if not self.dataset_path.exists():
            raise FileNotFoundError(
                f"Dataset path '{self._dataset_path}' does not exist. "
            )
        if not self.dataset_path.is_dir():
            raise NotADirectoryError(
                f"Dataset path '{self._dataset_path}' is not a directory. "
            )

        # Start fresh so calling load() again does not duplicate tasks.
        self._tasks = []
        self._tasks_map = {}
        # glob() order is filesystem-dependent; sort for stable integer indices.
        for file_path in sorted(self._dataset_path.glob("*.json")):
            task = arc.Task.from_json(file_path)
            self._tasks_map[task.task_id] = len(self._tasks)
            self._tasks.append(task)

    def download(self):
        """Fetch data/training or data/evaluation from fchollet/ARC-AGI (main) into `dataset_path`."""
        download_from_github(
            "fchollet",
            "ARC-AGI",
            f"data/{'training' if self._train else 'evaluation'}",
            "main",
            self._dataset_path,
        )

    def __contains__(self, task_id: str) -> bool:
        return task_id in self._tasks_map

    def get(self, task_id: str) -> arc.Task:
        """Return the task with *task_id*.

        Raises:
            KeyError: If no loaded task has that id.
        """
        index = self._tasks_map.get(task_id)
        if index is None:
            raise KeyError(f"Task with task id: {task_id} is not in this dataset. ")
        return self._tasks[index]

    def __getitem__(self, i: int) -> arc.Task:
        return self._tasks[i]

    def __len__(self) -> int:
        return len(self._tasks)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class ARC2(ARC1):
    """ARC-AGI-2 dataset.

    Loads exactly like `ARC1`; only the download source differs
    (arcprize/ARC-AGI-2 instead of fchollet/ARC-AGI).
    """

    def download(self):
        """Fetch the selected split's task files from GitHub into the dataset path."""
        split_dir = "training" if self._train else "evaluation"
        download_from_github(
            repo_owner="arcprize",
            repo_name="ARC-AGI-2",
            path=f"data/{split_dir}",
            branch="main",
            destination=self._dataset_path,
        )
|
core/main.py
ADDED
core/pyproject.toml
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["pdm-backend"]
|
|
3
|
+
build-backend = "pdm.backend"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "arc-core"
|
|
7
|
+
version = "0.2.0"
|
|
8
|
+
description = "Core package of arc. "
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.13"
|
|
11
|
+
dependencies = [
|
|
12
|
+
"numpy>=2.2.4",
|
|
13
|
+
]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
[[tool.uv.index]]
|
|
17
|
+
name = "testpypi"
|
|
18
|
+
url = "https://test.pypi.org/simple/"
|
|
19
|
+
publish-url = "https://test.pypi.org/legacy/"
|
|
20
|
+
explicit = true
|
core/utils.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
import re
|
|
2
|
+
from typing import List, Literal, Any
|
|
3
|
+
import os
|
|
4
|
+
import requests
|
|
5
|
+
import zipfile
|
|
6
|
+
import io
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# Matches ANSI SGR sequences such as "\x1b[48;5;20m" (set colour) and "\x1b[0m" (reset).
ANSI_ESCAPE_PATTERN = re.compile(r"\x1b\[[0-9;]*m")


def strip_ansi(text: str) -> str:
    """Removes ANSI escape codes from the string for correct width calculation."""
    return ANSI_ESCAPE_PATTERN.sub("", text)


def ansi_width(text: str) -> int:
    """Returns the number of characters in *text* once ANSI escape sequences are removed.

    Every remaining code point counts as width 1 (wide characters are not special-cased).
    """
    return len(strip_ansi(text))


def align_lines(
    lines: List[str],
    target_width: int,
    align: Literal["start", "center", "end"],
) -> List[str]:
    """Pads each line with spaces to *target_width* visible columns, preserving ANSI codes.

    Lines already at least *target_width* wide are returned unchanged (never truncated).

    Args:
        lines: Lines to pad; may contain ANSI escape sequences.
        target_width: Desired visible width, measured with `ansi_width`.
        align: "start" pads on the right, "end" pads on the left, "center" splits
            the padding, putting the odd extra space on the right.

    Returns:
        A new list of padded lines.

    Raises:
        ValueError: If *align* is not "start", "center" or "end".
    """
    if align not in ("start", "center", "end"):
        raise ValueError(
            f"Invalid align {align!r}; expected 'start', 'center' or 'end'."
        )

    aligned = []
    for line in lines:
        padding = max(target_width - ansi_width(line), 0)
        if align == "start":
            aligned.append(line + " " * padding)
        elif align == "center":
            left = padding // 2
            aligned.append(" " * left + line + " " * (padding - left))
        else:
            aligned.append(" " * padding + line)
    return aligned


class Layout:
    """Arranges text blocks side by side ("horizontal") or stacked ("vertical").

    `repr()` renders non-string elements with `repr()`; `str()` renders every
    element with `str()`. Widths ignore ANSI escape sequences.

    Args:
        *elements: Blocks to arrange; multi-line text is split into lines.
        direction: "horizontal" places blocks left to right, "vertical" top to bottom.
        align: How each block is padded within its slot ("start", "center", "end").
        show_divider: Insert " | " between horizontal blocks, or a dashed rule
            between vertical ones.
        min_width: Minimum visible width every rendered line is padded to.
    """

    def __init__(
        self,
        *elements: Any,
        direction: Literal["horizontal", "vertical"] = "horizontal",
        align: Literal["start", "center", "end"] = "start",
        show_divider: bool = False,
        min_width: int = 0,
    ):
        self.elements = elements
        self.direction = direction
        self.align = align
        self.show_divider = show_divider
        self.min_width = min_width

    @property
    def width(self):
        """Visible width of the widest rendered line (0 when nothing is rendered)."""
        return max((ansi_width(line) for line in repr(self).splitlines()), default=0)

    @property
    def height(self):
        """Number of rendered lines."""
        return len(repr(self).splitlines())

    def __repr__(self) -> str:
        return self._render(
            lambda element: element if isinstance(element, str) else repr(element)
        )

    def __str__(self) -> str:
        return self._render(str)

    def _render(self, to_text) -> str:
        """Lays out the elements after converting each to text with *to_text*.

        Raises:
            ValueError: If `direction` is neither "horizontal" nor "vertical".
        """
        blocks = [to_text(element).splitlines() for element in self.elements]
        if self.direction == "horizontal":
            return self._render_horizontal(blocks)
        if self.direction == "vertical":
            return self._render_vertical(blocks)
        raise ValueError(
            f"Invalid direction {self.direction!r}; expected 'horizontal' or 'vertical'."
        )

    def _render_horizontal(self, blocks: List[List[str]]) -> str:
        """Joins blocks column-wise; shorter blocks are padded with blank lines at the bottom."""
        height = max((len(block) for block in blocks), default=0)
        columns = [
            align_lines(
                block + [""] * (height - len(block)),
                max((ansi_width(line) for line in block), default=0),
                self.align,
            )
            for block in blocks
        ]
        divider = " | " if self.show_divider else ""
        rows = [divider.join(cells) for cells in zip(*columns)]
        return "\n".join(align_lines(rows, self.min_width, self.align))

    def _render_vertical(self, blocks: List[List[str]]) -> str:
        """Stacks blocks, padding every line to the widest line (or `min_width`)."""
        width = max(
            (ansi_width(line) for block in blocks for line in block), default=0
        )
        width = max(width, self.min_width)
        aligned = [align_lines(block, width, self.align) for block in blocks]
        divider = "\n" + "-" * width if self.show_divider else ""
        return f"{divider}\n".join("\n".join(block) for block in aligned)
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def download_from_github(
    repo_owner, repo_name, path, branch="main", destination="./downloaded", timeout=60
):
    """
    Downloads and extracts a specific path from a GitHub repository.

    The branch is fetched as a zip archive and every file under *path* is written
    to *destination*, keeping its sub-path relative to *path*.

    :param repo_owner: GitHub username or organization.
    :param repo_name: Repository name.
    :param path: Path to the folder inside the repo to extract.
    :param branch: Repo branch (default: "main").
    :param destination: Local destination folder for saving files.
    :param timeout: Seconds `requests` waits on the connection before giving up
        (default: 60). Without a timeout a stalled connection hangs forever.
    :return: Number of files extracted (0 on HTTP failure or when nothing matched).
    """
    import shutil

    url = f"https://github.com/{repo_owner}/{repo_name}/archive/refs/heads/{branch}.zip"
    print(f"Downloading from {url}")
    response = requests.get(url, timeout=timeout)

    if response.status_code != 200:
        print(f"❌ Failed to download: HTTP {response.status_code}")
        return 0

    # Archive entries are prefixed with "<repo>-<branch>/"; strip slashes from
    # *path* so "data/training/" and "data/training" match the same entries.
    target_folder = f"{repo_name}-{branch}/{str(path).strip('/')}/"
    dest_root = os.path.realpath(destination)
    extracted_files = 0

    with zipfile.ZipFile(io.BytesIO(response.content)) as zip_file:
        for member in zip_file.namelist():
            if not member.startswith(target_folder) or member.endswith("/"):
                continue

            relative_path = member[len(target_folder) :]
            save_path = os.path.realpath(os.path.join(dest_root, relative_path))
            # Guard against "zip slip": an entry containing ".." must not be
            # written outside the destination folder.
            if os.path.commonpath([dest_root, save_path]) != dest_root:
                print(f"⚠️ Skipping archive entry outside destination: '{member}'.")
                continue

            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            with zip_file.open(member) as source, open(save_path, "wb") as target:
                shutil.copyfileobj(source, target)

            extracted_files += 1

    if extracted_files > 0:
        print(
            f"✅ Successfully downloaded '{path}' from {repo_owner}/{repo_name} into '{destination}'."
        )
    else:
        print(
            f"⚠️ No files extracted. Check if the path '{path}' exists in the repository."
        )
    return extracted_files
|
torch/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .test import A
|
torch/pyproject.toml
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["pdm-backend"]
|
|
3
|
+
build-backend = "pdm.backend"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "arc-torch"
|
|
7
|
+
version = "0.2.0"
|
|
8
|
+
description = "arc torch"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.13"
|
|
11
|
+
dependencies = [
|
|
12
|
+
"numpy>=2.2.4",
|
|
13
|
+
]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
[[tool.uv.index]]
|
|
17
|
+
name = "testpypi"
|
|
18
|
+
url = "https://test.pypi.org/simple/"
|
|
19
|
+
publish-url = "https://test.pypi.org/legacy/"
|
|
20
|
+
explicit = true
|
torch/test.py
ADDED