catform 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,377 @@
1
+ # This file is automatically @generated by Cargo.
2
+ # It is not intended for manual editing.
3
+ version = 4
4
+
5
+ [[package]]
6
+ name = "aho-corasick"
7
+ version = "1.1.4"
8
+ source = "registry+https://github.com/rust-lang/crates.io-index"
9
+ checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301"
10
+ dependencies = [
11
+ "memchr",
12
+ ]
13
+
14
+ [[package]]
15
+ name = "autocfg"
16
+ version = "1.5.0"
17
+ source = "registry+https://github.com/rust-lang/crates.io-index"
18
+ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
19
+
20
+ [[package]]
21
+ name = "catform"
22
+ version = "0.1.0"
23
+ source = "git+https://github.com/tensorigami/catform.git#e9d43d17277f9a44c3ef0e44aedc87f755e2fd9e"
24
+ dependencies = [
25
+ "indexmap",
26
+ "regex",
27
+ "serde",
28
+ "serde_json",
29
+ ]
30
+
31
+ [[package]]
32
+ name = "catform-bridge"
33
+ version = "0.1.0"
34
+ dependencies = [
35
+ "catform",
36
+ "pyo3",
37
+ "serde",
38
+ "serde_json",
39
+ "toml",
40
+ ]
41
+
42
+ [[package]]
43
+ name = "cfg-if"
44
+ version = "1.0.4"
45
+ source = "registry+https://github.com/rust-lang/crates.io-index"
46
+ checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
47
+
48
+ [[package]]
49
+ name = "equivalent"
50
+ version = "1.0.2"
51
+ source = "registry+https://github.com/rust-lang/crates.io-index"
52
+ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
53
+
54
+ [[package]]
55
+ name = "hashbrown"
56
+ version = "0.16.1"
57
+ source = "registry+https://github.com/rust-lang/crates.io-index"
58
+ checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100"
59
+
60
+ [[package]]
61
+ name = "heck"
62
+ version = "0.5.0"
63
+ source = "registry+https://github.com/rust-lang/crates.io-index"
64
+ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
65
+
66
+ [[package]]
67
+ name = "indexmap"
68
+ version = "2.13.0"
69
+ source = "registry+https://github.com/rust-lang/crates.io-index"
70
+ checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017"
71
+ dependencies = [
72
+ "equivalent",
73
+ "hashbrown",
74
+ "serde",
75
+ "serde_core",
76
+ ]
77
+
78
+ [[package]]
79
+ name = "indoc"
80
+ version = "2.0.7"
81
+ source = "registry+https://github.com/rust-lang/crates.io-index"
82
+ checksum = "79cf5c93f93228cf8efb3ba362535fb11199ac548a09ce117c9b1adc3030d706"
83
+ dependencies = [
84
+ "rustversion",
85
+ ]
86
+
87
+ [[package]]
88
+ name = "itoa"
89
+ version = "1.0.17"
90
+ source = "registry+https://github.com/rust-lang/crates.io-index"
91
+ checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2"
92
+
93
+ [[package]]
94
+ name = "libc"
95
+ version = "0.2.183"
96
+ source = "registry+https://github.com/rust-lang/crates.io-index"
97
+ checksum = "b5b646652bf6661599e1da8901b3b9522896f01e736bad5f723fe7a3a27f899d"
98
+
99
+ [[package]]
100
+ name = "memchr"
101
+ version = "2.8.0"
102
+ source = "registry+https://github.com/rust-lang/crates.io-index"
103
+ checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79"
104
+
105
+ [[package]]
106
+ name = "memoffset"
107
+ version = "0.9.1"
108
+ source = "registry+https://github.com/rust-lang/crates.io-index"
109
+ checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
110
+ dependencies = [
111
+ "autocfg",
112
+ ]
113
+
114
+ [[package]]
115
+ name = "once_cell"
116
+ version = "1.21.4"
117
+ source = "registry+https://github.com/rust-lang/crates.io-index"
118
+ checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50"
119
+
120
+ [[package]]
121
+ name = "portable-atomic"
122
+ version = "1.13.1"
123
+ source = "registry+https://github.com/rust-lang/crates.io-index"
124
+ checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49"
125
+
126
+ [[package]]
127
+ name = "proc-macro2"
128
+ version = "1.0.106"
129
+ source = "registry+https://github.com/rust-lang/crates.io-index"
130
+ checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934"
131
+ dependencies = [
132
+ "unicode-ident",
133
+ ]
134
+
135
+ [[package]]
136
+ name = "pyo3"
137
+ version = "0.24.2"
138
+ source = "registry+https://github.com/rust-lang/crates.io-index"
139
+ checksum = "e5203598f366b11a02b13aa20cab591229ff0a89fd121a308a5df751d5fc9219"
140
+ dependencies = [
141
+ "cfg-if",
142
+ "indoc",
143
+ "libc",
144
+ "memoffset",
145
+ "once_cell",
146
+ "portable-atomic",
147
+ "pyo3-build-config",
148
+ "pyo3-ffi",
149
+ "pyo3-macros",
150
+ "unindent",
151
+ ]
152
+
153
+ [[package]]
154
+ name = "pyo3-build-config"
155
+ version = "0.24.2"
156
+ source = "registry+https://github.com/rust-lang/crates.io-index"
157
+ checksum = "99636d423fa2ca130fa5acde3059308006d46f98caac629418e53f7ebb1e9999"
158
+ dependencies = [
159
+ "once_cell",
160
+ "target-lexicon",
161
+ ]
162
+
163
+ [[package]]
164
+ name = "pyo3-ffi"
165
+ version = "0.24.2"
166
+ source = "registry+https://github.com/rust-lang/crates.io-index"
167
+ checksum = "78f9cf92ba9c409279bc3305b5409d90db2d2c22392d443a87df3a1adad59e33"
168
+ dependencies = [
169
+ "libc",
170
+ "pyo3-build-config",
171
+ ]
172
+
173
+ [[package]]
174
+ name = "pyo3-macros"
175
+ version = "0.24.2"
176
+ source = "registry+https://github.com/rust-lang/crates.io-index"
177
+ checksum = "0b999cb1a6ce21f9a6b147dcf1be9ffedf02e0043aec74dc390f3007047cecd9"
178
+ dependencies = [
179
+ "proc-macro2",
180
+ "pyo3-macros-backend",
181
+ "quote",
182
+ "syn",
183
+ ]
184
+
185
+ [[package]]
186
+ name = "pyo3-macros-backend"
187
+ version = "0.24.2"
188
+ source = "registry+https://github.com/rust-lang/crates.io-index"
189
+ checksum = "822ece1c7e1012745607d5cf0bcb2874769f0f7cb34c4cde03b9358eb9ef911a"
190
+ dependencies = [
191
+ "heck",
192
+ "proc-macro2",
193
+ "pyo3-build-config",
194
+ "quote",
195
+ "syn",
196
+ ]
197
+
198
+ [[package]]
199
+ name = "quote"
200
+ version = "1.0.45"
201
+ source = "registry+https://github.com/rust-lang/crates.io-index"
202
+ checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924"
203
+ dependencies = [
204
+ "proc-macro2",
205
+ ]
206
+
207
+ [[package]]
208
+ name = "regex"
209
+ version = "1.12.3"
210
+ source = "registry+https://github.com/rust-lang/crates.io-index"
211
+ checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276"
212
+ dependencies = [
213
+ "aho-corasick",
214
+ "memchr",
215
+ "regex-automata",
216
+ "regex-syntax",
217
+ ]
218
+
219
+ [[package]]
220
+ name = "regex-automata"
221
+ version = "0.4.14"
222
+ source = "registry+https://github.com/rust-lang/crates.io-index"
223
+ checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f"
224
+ dependencies = [
225
+ "aho-corasick",
226
+ "memchr",
227
+ "regex-syntax",
228
+ ]
229
+
230
+ [[package]]
231
+ name = "regex-syntax"
232
+ version = "0.8.10"
233
+ source = "registry+https://github.com/rust-lang/crates.io-index"
234
+ checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a"
235
+
236
+ [[package]]
237
+ name = "rustversion"
238
+ version = "1.0.22"
239
+ source = "registry+https://github.com/rust-lang/crates.io-index"
240
+ checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d"
241
+
242
+ [[package]]
243
+ name = "serde"
244
+ version = "1.0.228"
245
+ source = "registry+https://github.com/rust-lang/crates.io-index"
246
+ checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e"
247
+ dependencies = [
248
+ "serde_core",
249
+ "serde_derive",
250
+ ]
251
+
252
+ [[package]]
253
+ name = "serde_core"
254
+ version = "1.0.228"
255
+ source = "registry+https://github.com/rust-lang/crates.io-index"
256
+ checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad"
257
+ dependencies = [
258
+ "serde_derive",
259
+ ]
260
+
261
+ [[package]]
262
+ name = "serde_derive"
263
+ version = "1.0.228"
264
+ source = "registry+https://github.com/rust-lang/crates.io-index"
265
+ checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79"
266
+ dependencies = [
267
+ "proc-macro2",
268
+ "quote",
269
+ "syn",
270
+ ]
271
+
272
+ [[package]]
273
+ name = "serde_json"
274
+ version = "1.0.149"
275
+ source = "registry+https://github.com/rust-lang/crates.io-index"
276
+ checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86"
277
+ dependencies = [
278
+ "itoa",
279
+ "memchr",
280
+ "serde",
281
+ "serde_core",
282
+ "zmij",
283
+ ]
284
+
285
+ [[package]]
286
+ name = "serde_spanned"
287
+ version = "0.6.9"
288
+ source = "registry+https://github.com/rust-lang/crates.io-index"
289
+ checksum = "bf41e0cfaf7226dca15e8197172c295a782857fcb97fad1808a166870dee75a3"
290
+ dependencies = [
291
+ "serde",
292
+ ]
293
+
294
+ [[package]]
295
+ name = "syn"
296
+ version = "2.0.117"
297
+ source = "registry+https://github.com/rust-lang/crates.io-index"
298
+ checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99"
299
+ dependencies = [
300
+ "proc-macro2",
301
+ "quote",
302
+ "unicode-ident",
303
+ ]
304
+
305
+ [[package]]
306
+ name = "target-lexicon"
307
+ version = "0.13.5"
308
+ source = "registry+https://github.com/rust-lang/crates.io-index"
309
+ checksum = "adb6935a6f5c20170eeceb1a3835a49e12e19d792f6dd344ccc76a985ca5a6ca"
310
+
311
+ [[package]]
312
+ name = "toml"
313
+ version = "0.8.23"
314
+ source = "registry+https://github.com/rust-lang/crates.io-index"
315
+ checksum = "dc1beb996b9d83529a9e75c17a1686767d148d70663143c7854d8b4a09ced362"
316
+ dependencies = [
317
+ "serde",
318
+ "serde_spanned",
319
+ "toml_datetime",
320
+ "toml_edit",
321
+ ]
322
+
323
+ [[package]]
324
+ name = "toml_datetime"
325
+ version = "0.6.11"
326
+ source = "registry+https://github.com/rust-lang/crates.io-index"
327
+ checksum = "22cddaf88f4fbc13c51aebbf5f8eceb5c7c5a9da2ac40a13519eb5b0a0e8f11c"
328
+ dependencies = [
329
+ "serde",
330
+ ]
331
+
332
+ [[package]]
333
+ name = "toml_edit"
334
+ version = "0.22.27"
335
+ source = "registry+https://github.com/rust-lang/crates.io-index"
336
+ checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a"
337
+ dependencies = [
338
+ "indexmap",
339
+ "serde",
340
+ "serde_spanned",
341
+ "toml_datetime",
342
+ "toml_write",
343
+ "winnow",
344
+ ]
345
+
346
+ [[package]]
347
+ name = "toml_write"
348
+ version = "0.1.2"
349
+ source = "registry+https://github.com/rust-lang/crates.io-index"
350
+ checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801"
351
+
352
+ [[package]]
353
+ name = "unicode-ident"
354
+ version = "1.0.24"
355
+ source = "registry+https://github.com/rust-lang/crates.io-index"
356
+ checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75"
357
+
358
+ [[package]]
359
+ name = "unindent"
360
+ version = "0.2.4"
361
+ source = "registry+https://github.com/rust-lang/crates.io-index"
362
+ checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3"
363
+
364
+ [[package]]
365
+ name = "winnow"
366
+ version = "0.7.15"
367
+ source = "registry+https://github.com/rust-lang/crates.io-index"
368
+ checksum = "df79d97927682d2fd8adb29682d1140b343be4ac0f08fd68b7765d9c059d3945"
369
+ dependencies = [
370
+ "memchr",
371
+ ]
372
+
373
+ [[package]]
374
+ name = "zmij"
375
+ version = "1.0.21"
376
+ source = "registry+https://github.com/rust-lang/crates.io-index"
377
+ checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa"
@@ -0,0 +1,19 @@
1
+ [package]
2
+ name = "catform-bridge"
3
+ version = "0.1.0"
4
+ edition = "2024"
5
+
6
+ [lib]
7
+ name = "_catform"
8
+ crate-type = ["cdylib"]
9
+
10
+ [dependencies]
11
+ catform = { git = "https://github.com/tensorigami/catform.git" }
12
+ pyo3 = "0.24"
13
+ serde = { version = "1", features = ["derive"] }
14
+ serde_json = "1"
15
+ toml = "0.8"
16
+
17
+ [features]
18
+ default = []
19
+ extension-module = ["pyo3/extension-module"]
catform-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,5 @@
1
+ Metadata-Version: 2.4
2
+ Name: catform
3
+ Version: 0.1.0
4
+ Requires-Dist: orjson>=3.10
5
+ Requires-Python: >=3.13
@@ -0,0 +1,14 @@
1
+ [build-system]
2
+ requires = ["maturin>=1.7,<2"]
3
+ build-backend = "maturin"
4
+
5
+ [project]
6
+ name = "catform"
7
+ version = "0.1.0"
8
+ requires-python = ">=3.13"
9
+ dependencies = ["orjson>=3.10"]
10
+
11
+ [tool.maturin]
12
+ features = ["extension-module"]
13
+ python-source = "python"
14
+ module-name = "catform._catform"
@@ -0,0 +1,79 @@
1
+ """Catform language toolchain: .cat → dict → flat program.
2
+
3
+ Dict API (JSON transport is internal):
4
+ parse : str → dict
5
+ resolve : dict × dict → dict
6
+ flatten : dict × str → dict
7
+ format_cat : dict → str
8
+ check : dict → list[str]
9
+ load_flat : path × path → dict (hot path: parse + resolve + flatten + weight-key mapping)
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import json
15
+ from pathlib import Path
16
+ from typing import Any
17
+
18
+ import orjson
19
+
20
+ from catform import _catform
21
+
22
+ # ── Dict API ─────────────────────────────────────────────────────────────────
23
+
24
+
25
def parse(source: str) -> dict[str, Any]:
    """Parse catform source text and return the module as a dict."""
    module_json = _catform.parse_to_json(source)
    return orjson.loads(module_json)
27
+
28
+
29
def parse_file(path: str | Path) -> dict[str, Any]:
    """Read a ``.cat`` file and parse it into a module dict.

    The file is decoded as UTF-8 explicitly. ``Path.read_text()`` without an
    encoding falls back to the locale's preferred encoding (for example cp1252
    on Windows), which would mis-decode or reject non-ASCII source.

    Raises:
        OSError: if the file cannot be read.
        UnicodeDecodeError: if the file is not valid UTF-8.
    """
    source = Path(path).read_text(encoding="utf-8")
    return orjson.loads(_catform.parse_to_json(source))
31
+
32
+
33
def resolve(d: dict[str, Any], config: dict[str, int | float]) -> dict[str, Any]:
    """Substitute ``param.*`` references in module ``d`` from ``config``, then resolve it.

    Both dicts are sent to the native layer as JSON text; the result comes back
    as a module dict.
    """
    module_json = json.dumps(d)
    config_json = json.dumps(config)
    return orjson.loads(_catform.resolve_json(module_json, config_json))
35
+
36
+
37
def flatten(d: dict[str, Any], entry: str) -> dict[str, Any]:
    """Flatten module ``d`` starting from the entry function named ``entry``."""
    flat_json = _catform.flatten_json(json.dumps(d), entry)
    return orjson.loads(flat_json)
39
+
40
+
41
def infer_axes(d: dict[str, Any]) -> dict[str, Any]:
    """Run the native resolver on module ``d`` without any config substitution.

    NOTE(review): the native side calls ``resolve`` directly; confirm that
    resolution is where axis inference happens.
    """
    resolved_json = _catform.infer_axes_json(json.dumps(d))
    return orjson.loads(resolved_json)
43
+
44
+
45
def format_cat(d: dict[str, Any], *, width: int = 100) -> str:
    """Render module ``d`` as catform source text using the native formatter.

    ``width`` is forwarded unchanged as the formatter's width setting.
    """
    module_json = json.dumps(d)
    return _catform.fmt_json(module_json, width=width)
47
+
48
+
49
def check(d: dict[str, Any]) -> dict[str, Any]:
    """Run the catform checker on module ``d`` and return its decoded result.

    NOTE(review): the module docstring advertises ``list[str]`` while this
    annotation says ``dict[str, Any]``; confirm the shape the native checker
    actually emits.
    """
    report_json = _catform.check_json(json.dumps(d))
    return orjson.loads(report_json)
51
+
52
+
53
def load_flat(cat_path: str | Path, config_path: str | Path, entry: str = "main") -> dict[str, Any]:
    """Parse, substitute, resolve and flatten a ``.cat`` file in one native call.

    ``config_path`` points at the TOML config that supplies ``param.*`` values
    and the ``[weights]`` key mapping; ``entry`` names the function to flatten from.
    """
    flat_json = _catform.load_flat_json(str(cat_path), str(config_path), entry)
    return orjson.loads(flat_json)
55
+
56
+
57
+ # ── String API (pass-through) ────────────────────────────────────────────────
58
+
59
# These native helpers already take and return plain strings, so they are
# re-exported as-is.
fmt_source, fmt_file = _catform.fmt_source, _catform.fmt_file
61
+
62
+
63
def check_file(path: str | Path) -> dict[str, Any]:
    """Parse and check the ``.cat`` file at ``path`` natively; return the decoded result."""
    report_json = _catform.check_file(str(path))
    return orjson.loads(report_json)
65
+
66
+
67
# Public API of the package, kept in alphabetical order.
__all__ = [
    "check", "check_file", "flatten", "fmt_file", "fmt_source", "format_cat",
    "infer_axes", "load_flat", "parse", "parse_file", "resolve",
]
@@ -0,0 +1,12 @@
1
+ """Type stubs for the native catform bridge."""
2
+
3
def load_flat_json(cat_path: str, config_path: str, entry: str = "main") -> str:
    """Flatten ``cat_path`` from ``entry`` using the TOML at ``config_path``; return JSON."""
def parse_to_json(source: str) -> str:
    """Parse catform source text; return the module as JSON."""
def resolve_json(module_json: str, config_json: str) -> str:
    """Substitute config values into ``module_json`` and resolve it; return JSON."""
def flatten_json(module_json: str, entry: str) -> str:
    """Flatten ``module_json`` from the ``entry`` function; return JSON."""
def infer_axes_json(module_json: str) -> str:
    """Resolve ``module_json`` without config substitution; return JSON."""
def check_json(module_json: str) -> str:
    """Run the checker on ``module_json``; return its result as JSON."""
def fmt_source(source: str, *, width: int = 100) -> str:
    """Parse ``source`` and return it re-rendered by the formatter."""
def fmt_json(module_json: str, *, width: int = 100) -> str:
    """Render ``module_json`` as formatted catform source text."""
def check_file(path: str) -> str:
    """Parse and check the file at ``path``; return the result as JSON."""
def fmt_file(path: str, *, width: int = 100) -> str:
    """Parse the file at ``path`` and return it re-rendered by the formatter."""
File without changes
@@ -0,0 +1,371 @@
1
+ use std::collections::HashMap;
2
+
3
+ use pyo3::prelude::*;
4
+
5
+ use ::catform::ast::*;
6
+
7
+ // ── Config loading (pianola's interpretation of param.X) ─────────────────────
8
+
9
+ #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
10
+ #[serde(untagged)]
11
+ pub enum ConfigValue {
12
+ Int(i64),
13
+ Float(f64),
14
+ Vec(Vec<f64>),
15
+ }
16
+
17
+ fn load_config(path: &str) -> HashMap<String, ConfigValue> {
18
+ let content =
19
+ std::fs::read_to_string(path).unwrap_or_else(|e| panic!("failed to read {path}: {e}"));
20
+ let table: toml::Table = content
21
+ .parse()
22
+ .unwrap_or_else(|e| panic!("failed to parse {path}: {e}"));
23
+
24
+ let mut config = HashMap::new();
25
+ for section in ["shape", "scalar", "vector"] {
26
+ if let Some(toml::Value::Table(t)) = table.get(section) {
27
+ for (k, v) in t {
28
+ let cv = match v {
29
+ toml::Value::Integer(n) => ConfigValue::Int(*n),
30
+ toml::Value::Float(f) => ConfigValue::Float(*f),
31
+ toml::Value::Array(arr) => {
32
+ let vals: Vec<f64> = arr
33
+ .iter()
34
+ .map(|v| match v {
35
+ toml::Value::Float(f) => *f,
36
+ toml::Value::Integer(n) => *n as f64,
37
+ _ => panic!("unexpected value in vector section"),
38
+ })
39
+ .collect();
40
+ ConfigValue::Vec(vals)
41
+ }
42
+ _ => continue,
43
+ };
44
+ config.insert(k.clone(), cv);
45
+ }
46
+ }
47
+ }
48
+ config
49
+ }
50
+
51
+ // ── Weight mapping (hierarchical [weights] → flat dict-path → safetensors key)
52
+
53
+ fn load_weight_mapping(path: &str) -> HashMap<String, String> {
54
+ let content =
55
+ std::fs::read_to_string(path).unwrap_or_else(|e| panic!("failed to read {path}: {e}"));
56
+ let table: toml::Table = content
57
+ .parse()
58
+ .unwrap_or_else(|e| panic!("failed to parse {path}: {e}"));
59
+
60
+ let mut mapping = HashMap::new();
61
+ if let Some(toml::Value::Table(weights)) = table.get("weights") {
62
+ flatten_weight_table(weights, "weights", &mut mapping);
63
+ }
64
+ mapping
65
+ }
66
+
67
+ fn flatten_weight_table(table: &toml::Table, prefix: &str, out: &mut HashMap<String, String>) {
68
+ for (key, value) in table {
69
+ let full_key = format!("{prefix}.{key}");
70
+ match value {
71
+ toml::Value::String(s) => {
72
+ out.insert(full_key, s.clone());
73
+ }
74
+ toml::Value::Table(sub) => {
75
+ flatten_weight_table(sub, &full_key, out);
76
+ }
77
+ _ => {}
78
+ }
79
+ }
80
+ }
81
+
82
/// Resolve a dict path like "weights.transformer.layer.3.attn.q" to a safetensors key.
///
/// An exact match in `mapping` wins. Otherwise the path is split on `.`; purely numeric
/// segments (iteration indices) are pulled out, the remaining segments are re-joined and
/// looked up, and each `{}` placeholder in the mapped pattern is replaced by the numeric
/// segments in order.
///
/// Empty segments (from `..` or a trailing `.`) are *not* treated as indices: they stay
/// part of the lookup key, so a malformed path returns `None` instead of silently
/// resolving with an empty index.
///
/// Returns `None` when neither the exact path nor its index-stripped form is mapped.
fn resolve_dict_path(path: &str, mapping: &HashMap<String, String>) -> Option<String> {
    // Direct lookup first (non-layer weights like weights.embedding).
    if let Some(key) = mapping.get(path) {
        return Some(key.clone());
    }

    // `all` is vacuously true on an empty string, so require a non-empty segment.
    let (numerics, key_parts): (Vec<&str>, Vec<&str>) = path
        .split('.')
        .partition(|seg| !seg.is_empty() && seg.bytes().all(|b| b.is_ascii_digit()));

    let pattern = mapping.get(&key_parts.join("."))?;
    Some(
        numerics
            .iter()
            .fold(pattern.clone(), |acc, num| acc.replacen("{}", num, 1)),
    )
}
117
+
118
+ /// After flattening, resolve all dict-path args to safetensors keys.
119
+ fn resolve_weight_paths(m: &mut Module, mapping: &HashMap<String, String>) {
120
+ for f in m.functions.values_mut() {
121
+ for op in &mut f.ops {
122
+ for arg in &mut op.args {
123
+ if let Atom::Name(n) = arg {
124
+ if n.starts_with("weights.") {
125
+ if let Some(resolved) = resolve_dict_path(n, mapping) {
126
+ *arg = Atom::Name(resolved);
127
+ }
128
+ }
129
+ }
130
+ }
131
+ }
132
+ }
133
+ }
134
+
135
+ // ── Param substitution (pianola convention: param.X → config lookup) ─────────
136
+
137
+ fn sub_dim(d: &Dim, config: &HashMap<String, ConfigValue>) -> Dim {
138
+ match d {
139
+ Dim::Named(name) if name.starts_with("param.") => {
140
+ let key = &name[6..];
141
+ match config.get(key) {
142
+ Some(ConfigValue::Int(n)) => Dim::Concrete(*n),
143
+ Some(ConfigValue::Float(f)) => Dim::Concrete(*f as i64),
144
+ _ => d.clone(),
145
+ }
146
+ }
147
+ _ => d.clone(),
148
+ }
149
+ }
150
+
151
+ fn sub_type(t: &TensorType, config: &HashMap<String, ConfigValue>) -> TensorType {
152
+ if t.dtype == "*" {
153
+ return t.clone(); // dict types have no shape to substitute
154
+ }
155
+ TensorType {
156
+ dtype: t.dtype.clone(),
157
+ shape: t.shape.iter().map(|d| sub_dim(d, config)).collect(),
158
+ }
159
+ }
160
+
161
+ fn sub_op(op: &Op, config: &HashMap<String, ConfigValue>) -> Op {
162
+ let kind = match &op.kind {
163
+ OpKind::View { pattern, axes } => OpKind::View {
164
+ pattern: pattern.clone(),
165
+ axes: axes.clone(),
166
+ },
167
+ OpKind::Map { function } => OpKind::Map {
168
+ function: function.clone(),
169
+ },
170
+ OpKind::Fold { pattern, function } => OpKind::Fold {
171
+ pattern: pattern.clone(),
172
+ function: function.clone(),
173
+ },
174
+ OpKind::Tile { pattern, axes } => OpKind::Tile {
175
+ pattern: pattern.clone(),
176
+ axes: axes.clone(),
177
+ },
178
+ OpKind::Gather { pattern } => OpKind::Gather {
179
+ pattern: pattern.clone(),
180
+ },
181
+ OpKind::Scatter { pattern } => OpKind::Scatter {
182
+ pattern: pattern.clone(),
183
+ },
184
+ OpKind::Contract { pattern } => OpKind::Contract {
185
+ pattern: pattern.clone(),
186
+ },
187
+ OpKind::Literal { value } => OpKind::Literal {
188
+ value: value.clone(),
189
+ },
190
+ OpKind::Random { function } => OpKind::Random {
191
+ function: function.clone(),
192
+ },
193
+ OpKind::Call { target } => OpKind::Call {
194
+ target: target.clone(),
195
+ },
196
+ OpKind::Loop { target, count } => {
197
+ let resolved = match count {
198
+ LoopCount::Concrete(n) => LoopCount::Concrete(*n),
199
+ LoopCount::Named(name) => {
200
+ let key = name.strip_prefix("param.").unwrap_or(name);
201
+ match config.get(key) {
202
+ Some(ConfigValue::Int(n)) => LoopCount::Concrete(*n as usize),
203
+ _ => count.clone(),
204
+ }
205
+ }
206
+ };
207
+ OpKind::Loop {
208
+ target: target.clone(),
209
+ count: resolved,
210
+ }
211
+ }
212
+ };
213
+
214
+ Op {
215
+ kind,
216
+ outputs: op.outputs.clone(),
217
+ output_types: op
218
+ .output_types
219
+ .iter()
220
+ .map(|t| t.as_ref().map(|t| sub_type(t, config)))
221
+ .collect(),
222
+ args: op.args.clone(),
223
+ comments: op.comments.clone(),
224
+ }
225
+ }
226
+
227
+ fn substitute_params(m: &Module, config: &HashMap<String, ConfigValue>) -> Module {
228
+ let functions = m
229
+ .functions
230
+ .iter()
231
+ .map(|(name, f)| {
232
+ let params = f
233
+ .params
234
+ .iter()
235
+ .map(|p| Param {
236
+ name: p.name.clone(),
237
+ ty: sub_type(&p.ty, config),
238
+ })
239
+ .collect();
240
+ let returns = f
241
+ .returns
242
+ .iter()
243
+ .map(|r| Param {
244
+ name: r.name.clone(),
245
+ ty: sub_type(&r.ty, config),
246
+ })
247
+ .collect();
248
+ let ops = f.ops.iter().map(|op| sub_op(op, config)).collect();
249
+ (
250
+ name.clone(),
251
+ Function {
252
+ name: f.name.clone(),
253
+ comments: f.comments.clone(),
254
+ params,
255
+ returns,
256
+ ops,
257
+ },
258
+ )
259
+ })
260
+ .collect();
261
+
262
+ Module {
263
+ header_comments: m.header_comments.clone(),
264
+ functions,
265
+ }
266
+ }
267
+
268
+ // -- JSON-based API -----------------------------------------------------------
269
+
270
+ #[pyfunction]
271
+ #[pyo3(name = "load_flat_json", signature = (cat_path, config_path, entry = "main"))]
272
+ fn py_load_flat_json(cat_path: &str, config_path: &str, entry: &str) -> String {
273
+ let module = ::catform::parse::parse_file(cat_path);
274
+ let cfg = load_config(config_path);
275
+ let substituted = substitute_params(&module, &cfg);
276
+ let resolved = ::catform::resolve::resolve(&substituted);
277
+ let mut flat = ::catform::flatten::flatten(&resolved, entry);
278
+
279
+ // Resolve dict paths to safetensors keys
280
+ let weight_mapping = load_weight_mapping(config_path);
281
+ resolve_weight_paths(&mut flat, &weight_mapping);
282
+
283
+ serde_json::to_string(&flat).unwrap()
284
+ }
285
+
286
+ #[pyfunction]
287
+ #[pyo3(name = "parse_to_json")]
288
+ fn py_parse_to_json(source: &str) -> String {
289
+ let module = ::catform::parse::parse(source);
290
+ serde_json::to_string(&module).unwrap()
291
+ }
292
+
293
+ #[pyfunction]
294
+ #[pyo3(name = "resolve_json")]
295
+ fn py_resolve_json(module_json: &str, config_json: &str) -> String {
296
+ let module: Module = serde_json::from_str(module_json).unwrap();
297
+ let cfg: HashMap<String, ConfigValue> = serde_json::from_str(config_json).unwrap();
298
+ let substituted = substitute_params(&module, &cfg);
299
+ let resolved = ::catform::resolve::resolve(&substituted);
300
+ serde_json::to_string(&resolved).unwrap()
301
+ }
302
+
303
+ #[pyfunction]
304
+ #[pyo3(name = "flatten_json")]
305
+ fn py_flatten_json(module_json: &str, entry: &str) -> String {
306
+ let module: Module = serde_json::from_str(module_json).unwrap();
307
+ let flat = ::catform::flatten::flatten(&module, entry);
308
+ serde_json::to_string(&flat).unwrap()
309
+ }
310
+
311
+ #[pyfunction]
312
+ #[pyo3(name = "infer_axes_json")]
313
+ fn py_infer_axes_json(module_json: &str) -> String {
314
+ let module: Module = serde_json::from_str(module_json).unwrap();
315
+ let result = ::catform::resolve::resolve(&module);
316
+ serde_json::to_string(&result).unwrap()
317
+ }
318
+
319
+ #[pyfunction]
320
+ #[pyo3(name = "check_json")]
321
+ fn py_check_json(module_json: &str) -> String {
322
+ let module: Module = serde_json::from_str(module_json).unwrap();
323
+ let result = ::catform::check::check(&module);
324
+ serde_json::to_string(&result).unwrap()
325
+ }
326
+
327
+ #[pyfunction]
328
+ #[pyo3(name = "fmt_source", signature = (source, *, width = 100))]
329
+ fn py_fmt_source(source: &str, width: usize) -> String {
330
+ let module = ::catform::parse::parse(source);
331
+ ::catform::fmt::format_cat(&module, width)
332
+ }
333
+
334
+ #[pyfunction]
335
+ #[pyo3(name = "fmt_json", signature = (module_json, *, width = 100))]
336
+ fn py_fmt_json(module_json: &str, width: usize) -> String {
337
+ let module: Module = serde_json::from_str(module_json).unwrap();
338
+ ::catform::fmt::format_cat(&module, width)
339
+ }
340
+
341
+ #[pyfunction]
342
+ #[pyo3(name = "check_file")]
343
+ fn py_check_file(path: &str) -> String {
344
+ let module = ::catform::parse::parse_file(path);
345
+ let result = ::catform::check::check(&module);
346
+ serde_json::to_string(&result).unwrap()
347
+ }
348
+
349
+ #[pyfunction]
350
+ #[pyo3(name = "fmt_file", signature = (path, *, width = 100))]
351
+ fn py_fmt_file(path: &str, width: usize) -> String {
352
+ let module = ::catform::parse::parse_file(path);
353
+ ::catform::fmt::format_cat(&module, width)
354
+ }
355
+
356
+ // -- Module registration ------------------------------------------------------
357
+
358
+ #[pymodule]
359
+ fn _catform(m: &Bound<'_, PyModule>) -> PyResult<()> {
360
+ m.add_function(wrap_pyfunction!(py_load_flat_json, m)?)?;
361
+ m.add_function(wrap_pyfunction!(py_parse_to_json, m)?)?;
362
+ m.add_function(wrap_pyfunction!(py_resolve_json, m)?)?;
363
+ m.add_function(wrap_pyfunction!(py_flatten_json, m)?)?;
364
+ m.add_function(wrap_pyfunction!(py_infer_axes_json, m)?)?;
365
+ m.add_function(wrap_pyfunction!(py_check_json, m)?)?;
366
+ m.add_function(wrap_pyfunction!(py_fmt_source, m)?)?;
367
+ m.add_function(wrap_pyfunction!(py_fmt_json, m)?)?;
368
+ m.add_function(wrap_pyfunction!(py_check_file, m)?)?;
369
+ m.add_function(wrap_pyfunction!(py_fmt_file, m)?)?;
370
+ Ok(())
371
+ }
catform-0.1.0/uv.lock ADDED
@@ -0,0 +1,8 @@
1
+ version = 1
2
+ revision = 1
3
+ requires-python = ">=3.13"
4
+
5
+ [[package]]
6
+ name = "catform"
7
+ version = "0.1.0"
8
+ source = { editable = "." }