@sqldoc/templates 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/package.json +161 -0
- package/src/__tests__/dedent.test.ts +45 -0
- package/src/__tests__/docker-templates.test.ts +134 -0
- package/src/__tests__/go-structs.test.ts +184 -0
- package/src/__tests__/naming.test.ts +48 -0
- package/src/__tests__/python-dataclasses.test.ts +185 -0
- package/src/__tests__/rust-structs.test.ts +176 -0
- package/src/__tests__/tags-helpers.test.ts +72 -0
- package/src/__tests__/type-mapping.test.ts +332 -0
- package/src/__tests__/typescript.test.ts +202 -0
- package/src/cobol-copybook/index.ts +220 -0
- package/src/cobol-copybook/test/.gitignore +6 -0
- package/src/cobol-copybook/test/Dockerfile +7 -0
- package/src/csharp-records/index.ts +131 -0
- package/src/csharp-records/test/.gitignore +6 -0
- package/src/csharp-records/test/Dockerfile +6 -0
- package/src/diesel/index.ts +247 -0
- package/src/diesel/test/.gitignore +6 -0
- package/src/diesel/test/Dockerfile +16 -0
- package/src/drizzle/index.ts +255 -0
- package/src/drizzle/test/.gitignore +6 -0
- package/src/drizzle/test/Dockerfile +8 -0
- package/src/drizzle/test/test.ts +71 -0
- package/src/efcore/index.ts +190 -0
- package/src/efcore/test/.gitignore +6 -0
- package/src/efcore/test/Dockerfile +7 -0
- package/src/go-structs/index.ts +119 -0
- package/src/go-structs/test/.gitignore +6 -0
- package/src/go-structs/test/Dockerfile +13 -0
- package/src/go-structs/test/test.go +71 -0
- package/src/gorm/index.ts +134 -0
- package/src/gorm/test/.gitignore +6 -0
- package/src/gorm/test/Dockerfile +13 -0
- package/src/gorm/test/test.go +65 -0
- package/src/helpers/atlas.ts +43 -0
- package/src/helpers/enrich.ts +396 -0
- package/src/helpers/naming.ts +19 -0
- package/src/helpers/tags.ts +63 -0
- package/src/index.ts +24 -0
- package/src/java-records/index.ts +179 -0
- package/src/java-records/test/.gitignore +6 -0
- package/src/java-records/test/Dockerfile +11 -0
- package/src/java-records/test/Test.java +93 -0
- package/src/jpa/index.ts +279 -0
- package/src/jpa/test/.gitignore +6 -0
- package/src/jpa/test/Dockerfile +14 -0
- package/src/jpa/test/Test.java +111 -0
- package/src/json-schema/index.ts +351 -0
- package/src/json-schema/test/.gitignore +6 -0
- package/src/json-schema/test/Dockerfile +18 -0
- package/src/knex/index.ts +168 -0
- package/src/knex/test/.gitignore +6 -0
- package/src/knex/test/Dockerfile +7 -0
- package/src/knex/test/test.ts +75 -0
- package/src/kotlin-data/index.ts +147 -0
- package/src/kotlin-data/test/.gitignore +6 -0
- package/src/kotlin-data/test/Dockerfile +14 -0
- package/src/kotlin-data/test/Test.kt +82 -0
- package/src/kysely/index.ts +165 -0
- package/src/kysely/test/.gitignore +6 -0
- package/src/kysely/test/Dockerfile +8 -0
- package/src/kysely/test/test.ts +82 -0
- package/src/prisma/index.ts +387 -0
- package/src/prisma/test/.gitignore +6 -0
- package/src/prisma/test/Dockerfile +7 -0
- package/src/protobuf/index.ts +219 -0
- package/src/protobuf/test/.gitignore +6 -0
- package/src/protobuf/test/Dockerfile +6 -0
- package/src/pydantic/index.ts +272 -0
- package/src/pydantic/test/.gitignore +6 -0
- package/src/pydantic/test/Dockerfile +8 -0
- package/src/pydantic/test/test.py +63 -0
- package/src/python-dataclasses/index.ts +217 -0
- package/src/python-dataclasses/test/.gitignore +6 -0
- package/src/python-dataclasses/test/Dockerfile +8 -0
- package/src/python-dataclasses/test/test.py +63 -0
- package/src/rust-structs/index.ts +152 -0
- package/src/rust-structs/test/.gitignore +6 -0
- package/src/rust-structs/test/Dockerfile +22 -0
- package/src/rust-structs/test/test.rs +82 -0
- package/src/sqlalchemy/index.ts +258 -0
- package/src/sqlalchemy/test/.gitignore +6 -0
- package/src/sqlalchemy/test/Dockerfile +8 -0
- package/src/sqlalchemy/test/test.py +61 -0
- package/src/sqlc/index.ts +148 -0
- package/src/sqlc/test/.gitignore +6 -0
- package/src/sqlc/test/Dockerfile +13 -0
- package/src/sqlc/test/test.go +91 -0
- package/src/tags/dedent.ts +28 -0
- package/src/tags/index.ts +14 -0
- package/src/types/index.ts +8 -0
- package/src/types/pg-to-csharp.ts +136 -0
- package/src/types/pg-to-go.ts +120 -0
- package/src/types/pg-to-java.ts +141 -0
- package/src/types/pg-to-kotlin.ts +119 -0
- package/src/types/pg-to-python.ts +120 -0
- package/src/types/pg-to-rust.ts +121 -0
- package/src/types/pg-to-ts.ts +173 -0
- package/src/typescript/index.ts +168 -0
- package/src/typescript/test/.gitignore +6 -0
- package/src/typescript/test/Dockerfile +8 -0
- package/src/typescript/test/test.ts +89 -0
- package/src/xsd/index.ts +191 -0
- package/src/xsd/test/.gitignore +6 -0
- package/src/xsd/test/Dockerfile +6 -0
- package/src/zod/index.ts +289 -0
- package/src/zod/test/.gitignore +6 -0
- package/src/zod/test/Dockerfile +6 -0
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
import { defineTemplate } from '@sqldoc/ns-codegen'
|
|
2
|
+
import { activeTables, enrichRealm } from '../helpers/enrich.ts'
|
|
3
|
+
import { toPascalCase } from '../helpers/naming.ts'
|
|
4
|
+
import { pgToRust } from '../types/pg-to-rust.ts'
|
|
5
|
+
|
|
6
|
+
export default defineTemplate({
|
|
7
|
+
name: 'Rust Structs',
|
|
8
|
+
description: 'Generate Rust structs with serde derives from SQL schema',
|
|
9
|
+
language: 'rust',
|
|
10
|
+
|
|
11
|
+
generate(ctx) {
|
|
12
|
+
const schema = enrichRealm(ctx)
|
|
13
|
+
const allImports = new Set<string>()
|
|
14
|
+
allImports.add('serde::{Serialize, Deserialize}')
|
|
15
|
+
const structs: string[] = []
|
|
16
|
+
|
|
17
|
+
// Enums
|
|
18
|
+
for (const e of schema.enums) {
|
|
19
|
+
const enumName = toPascalCase(e.name)
|
|
20
|
+
const variants = e.values.map((v) => {
|
|
21
|
+
const variantName = toPascalCase(v)
|
|
22
|
+
return ` #[serde(rename = "${v}")]\n ${variantName},`
|
|
23
|
+
})
|
|
24
|
+
structs.push(`#[derive(Debug, Clone, Serialize, Deserialize)]\npub enum ${enumName} {\n${variants.join('\n')}\n}`)
|
|
25
|
+
}
|
|
26
|
+
|
|
27
|
+
// Composite types as structs
|
|
28
|
+
const composites = new Map<string, Array<{ name: string; type: string }>>()
|
|
29
|
+
for (const table of schema.tables) {
|
|
30
|
+
for (const col of table.columns) {
|
|
31
|
+
if (col.category === 'composite' && col.compositeFields?.length && !composites.has(col.pgType)) {
|
|
32
|
+
composites.set(col.pgType, col.compositeFields)
|
|
33
|
+
}
|
|
34
|
+
}
|
|
35
|
+
}
|
|
36
|
+
for (const [name, fields] of composites) {
|
|
37
|
+
const structName = toPascalCase(name)
|
|
38
|
+
const rustFields = fields.map((f) => {
|
|
39
|
+
const mapped = pgToRust(f.type, false)
|
|
40
|
+
for (const imp of mapped.imports) allImports.add(imp)
|
|
41
|
+
return ` pub ${f.name}: ${mapped.type},`
|
|
42
|
+
})
|
|
43
|
+
structs.push(
|
|
44
|
+
`#[derive(Debug, Clone, Serialize, Deserialize)]\npub struct ${structName} {\n${rustFields.join('\n')}\n}`,
|
|
45
|
+
)
|
|
46
|
+
}
|
|
47
|
+
|
|
48
|
+
for (const table of activeTables(schema)) {
|
|
49
|
+
const fields: string[] = []
|
|
50
|
+
for (const col of table.columns) {
|
|
51
|
+
let rustType: string
|
|
52
|
+
if (col.typeOverride) {
|
|
53
|
+
rustType = col.nullable ? `Option<${col.typeOverride}>` : col.typeOverride
|
|
54
|
+
} else if (col.category === 'enum' && col.enumValues?.length) {
|
|
55
|
+
const enumType = toPascalCase(col.pgType)
|
|
56
|
+
rustType = col.nullable ? `Option<${enumType}>` : enumType
|
|
57
|
+
} else if (col.category === 'composite' && col.compositeFields?.length) {
|
|
58
|
+
const compositeType = toPascalCase(col.pgType)
|
|
59
|
+
rustType = col.nullable ? `Option<${compositeType}>` : compositeType
|
|
60
|
+
} else {
|
|
61
|
+
const mapped = pgToRust(col.pgType, col.nullable, col.category)
|
|
62
|
+
rustType = mapped.type
|
|
63
|
+
for (const imp of mapped.imports) allImports.add(imp)
|
|
64
|
+
}
|
|
65
|
+
|
|
66
|
+
// snake_case is the Rust convention, which matches PostgreSQL column names
|
|
67
|
+
const fieldName = col.name
|
|
68
|
+
fields.push(` pub ${fieldName}: ${rustType},`)
|
|
69
|
+
}
|
|
70
|
+
|
|
71
|
+
structs.push('#[derive(Debug, Clone, Serialize, Deserialize)]')
|
|
72
|
+
structs.push(`pub struct ${table.pascalName} {`)
|
|
73
|
+
structs.push(fields.join('\n'))
|
|
74
|
+
structs.push('}')
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
// Views (read-only)
|
|
78
|
+
for (const view of schema.views.filter((v) => !v.skipped)) {
|
|
79
|
+
const fields: string[] = []
|
|
80
|
+
for (const col of view.columns) {
|
|
81
|
+
let rustType: string
|
|
82
|
+
if (col.typeOverride) {
|
|
83
|
+
rustType = col.nullable ? `Option<${col.typeOverride}>` : col.typeOverride
|
|
84
|
+
} else if (col.category === 'enum' && col.enumValues?.length) {
|
|
85
|
+
const enumType = toPascalCase(col.pgType)
|
|
86
|
+
rustType = col.nullable ? `Option<${enumType}>` : enumType
|
|
87
|
+
} else if (col.category === 'composite' && col.compositeFields?.length) {
|
|
88
|
+
const compositeType = toPascalCase(col.pgType)
|
|
89
|
+
rustType = col.nullable ? `Option<${compositeType}>` : compositeType
|
|
90
|
+
} else {
|
|
91
|
+
const mapped = pgToRust(col.pgType, col.nullable, col.category)
|
|
92
|
+
rustType = mapped.type
|
|
93
|
+
for (const imp of mapped.imports) allImports.add(imp)
|
|
94
|
+
}
|
|
95
|
+
|
|
96
|
+
fields.push(` pub ${col.name}: ${rustType},`)
|
|
97
|
+
}
|
|
98
|
+
|
|
99
|
+
structs.push(`/// Read-only (from view)`)
|
|
100
|
+
structs.push('#[derive(Debug, Clone, Serialize, Deserialize)]')
|
|
101
|
+
structs.push(`pub struct ${view.pascalName} {`)
|
|
102
|
+
structs.push(fields.join('\n'))
|
|
103
|
+
structs.push('}')
|
|
104
|
+
}
|
|
105
|
+
|
|
106
|
+
// Functions (skip trigger functions)
|
|
107
|
+
for (const fn of schema.functions) {
|
|
108
|
+
const retRaw = fn.returnType?.type?.toLowerCase() ?? ''
|
|
109
|
+
if (retRaw === 'trigger') continue
|
|
110
|
+
|
|
111
|
+
let retType: string
|
|
112
|
+
if (retRaw.startsWith('setof ')) {
|
|
113
|
+
const tableName = retRaw.replace('setof ', '')
|
|
114
|
+
const table = schema.tables.find((t) => t.name === tableName)
|
|
115
|
+
retType = table ? `Vec<${table.pascalName}>` : `Vec<${toPascalCase(tableName)}>`
|
|
116
|
+
} else if (fn.returnType) {
|
|
117
|
+
const mapped = pgToRust(fn.returnType.type, false, fn.returnType.category)
|
|
118
|
+
retType = mapped.type
|
|
119
|
+
for (const imp of mapped.imports) allImports.add(imp)
|
|
120
|
+
} else {
|
|
121
|
+
retType = '()'
|
|
122
|
+
}
|
|
123
|
+
|
|
124
|
+
const argTypes = fn.args
|
|
125
|
+
.filter((a) => !a.name?.startsWith('_') && (a as any).mode !== 'OUT')
|
|
126
|
+
.map((a) => {
|
|
127
|
+
const mapped = pgToRust(a.type, false, a.category)
|
|
128
|
+
for (const imp of mapped.imports) allImports.add(imp)
|
|
129
|
+
return mapped.type
|
|
130
|
+
})
|
|
131
|
+
|
|
132
|
+
structs.push(`pub type ${fn.pascalName} = fn(${argTypes.join(', ')}) -> ${retType};`)
|
|
133
|
+
}
|
|
134
|
+
|
|
135
|
+
if (structs.length === 0) {
|
|
136
|
+
return { files: [] }
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
const sortedImports = [...allImports].sort()
|
|
140
|
+
const useLines = sortedImports.map((imp) => `use ${imp};`)
|
|
141
|
+
|
|
142
|
+
const parts: string[] = []
|
|
143
|
+
parts.push(useLines.join('\n'))
|
|
144
|
+
parts.push('')
|
|
145
|
+
parts.push(structs.join('\n\n'))
|
|
146
|
+
parts.push('')
|
|
147
|
+
|
|
148
|
+
return {
|
|
149
|
+
files: [{ path: 'models.rs', content: parts.join('\n') }],
|
|
150
|
+
}
|
|
151
|
+
},
|
|
152
|
+
})
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
# Build-and-run image for the generated Rust models' integration test.
FROM rust:1.85-slim
# pkg-config + OpenSSL headers are needed to build native dependencies; clear
# the apt cache in the same layer to keep the image small.
RUN apt-get update && apt-get install -y pkg-config libssl-dev && rm -rf /var/lib/apt/lists/*
WORKDIR /app
# Fresh cargo binary project; the dependency list is appended below.
RUN cargo init --name typecheck .
# NOTE(review): the heredoc form requires BuildKit; assumes `cargo init` leaves
# Cargo.toml ending with an empty [dependencies] section — TODO confirm.
RUN cat >> Cargo.toml <<'TOML'
serde = { version = "1", features = ["derive"] }
serde_json = "1"
chrono = { version = "0.4", features = ["serde"] }
uuid = { version = "1", features = ["serde"] }
bigdecimal = { version = "0.4", features = ["serde"] }
tokio = { version = "1", features = ["full"] }
tokio-postgres = "0.7"
TOML
# Replace the stub main with the generated models plus the test driver.
RUN rm src/main.rs
COPY models.rs src/models.rs
COPY test.rs src/main.rs
# Prepend mod declaration to main.rs
RUN sed -i '1s/^/mod models;\n/' src/main.rs
# Step 1: compile (typechecks models + test)
RUN cargo build --release
# Step 2: run integration test against real DB
CMD ["./target/release/typecheck"]
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
// Integration test for @sqldoc/templates/rust-structs
|
|
2
|
+
// Connects to real Postgres, verifies generated structs work with actual data.
|
|
3
|
+
//
|
|
4
|
+
// Note: `mod models;` is prepended by the Dockerfile.
|
|
5
|
+
|
|
6
|
+
use models::{Users, Posts};
|
|
7
|
+
|
|
8
|
+
use std::sync::atomic::{AtomicUsize, Ordering};
|
|
9
|
+
|
|
10
|
+
// Process-wide failure counter; every mismatched assertion bumps it and
// main() inspects it before exiting.
static FAILED: AtomicUsize = AtomicUsize::new(0);

/// Compare two values of any debuggable, comparable type; print the outcome
/// and record a failure (without aborting) when they differ.
fn assert_eq_val<T: PartialEq + std::fmt::Debug>(actual: T, expected: T, msg: &str) {
    if actual == expected {
        println!(" ok: {}", msg);
    } else {
        eprintln!("FAIL: {} (got {:?}, expected {:?})", msg, actual, expected);
        FAILED.fetch_add(1, Ordering::SeqCst);
    }
}
|
|
20
|
+
|
|
21
|
+
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // The harness (docker-compose) is expected to inject the connection string.
    let db_url = std::env::var("DATABASE_URL").expect("DATABASE_URL not set");

    let (client, connection) = tokio_postgres::connect(&db_url, tokio_postgres::NoTls).await?;

    // tokio-postgres splits client and connection: the connection object
    // drives the socket and must be polled on its own task.
    tokio::spawn(async move {
        if let Err(e) = connection.await {
            eprintln!("connection error: {}", e);
        }
    });

    println!("--- rust-structs integration test ---");

    // 1. Query user and populate generated struct
    let row = client
        .query_one("SELECT id, email, name, age, is_active FROM users WHERE id = 1", &[])
        .await?;
    // Only the SELECTed columns come from the row; the remaining struct fields
    // are filled with placeholders (None / now) since only the selected values
    // are asserted below. This also typechecks the full generated field list.
    let user = Users {
        id: row.get("id"),
        email: row.get("email"),
        name: row.get("name"),
        age: row.get("age"),
        is_active: row.get("is_active"),
        metadata: None,
        address: None,
        created_at: chrono::Utc::now(),
        tags: None,
        avatar: None,
        balance: None,
        external_id: None,
    };
    // Values below must match the seed data loaded by the test harness.
    assert_eq_val(user.email.as_str(), "test@example.com", "user.email matches");
    assert_eq_val(user.name, Some("Test User".to_string()), "user.name matches");
    assert_eq_val(user.age, Some(30), "user.age matches");
    assert_eq_val(user.is_active, true, "user.is_active matches");

    // 2. Query post and populate generated struct
    let row = client
        .query_one("SELECT id, user_id, title, body, view_count FROM posts WHERE id = 1", &[])
        .await?;
    let post = Posts {
        id: row.get("id"),
        user_id: row.get("user_id"),
        title: row.get("title"),
        body: row.get("body"),
        published_at: None,
        view_count: row.get("view_count"),
        rating: None,
    };
    assert_eq_val(post.title.as_str(), "Hello World", "post.title matches");
    assert_eq_val(post.user_id, 1i64, "post.user_id matches");
    assert_eq_val(post.view_count, 42i32, "post.view_count matches");

    // Exit non-zero if any assert_eq_val recorded a failure.
    let fail_count = FAILED.load(Ordering::SeqCst);
    if fail_count > 0 {
        eprintln!("\n{} assertion(s) failed", fail_count);
        std::process::exit(1);
    }
    println!("\nAll assertions passed!");
    Ok(())
}
|
|
@@ -0,0 +1,258 @@
|
|
|
1
|
+
import { defineTemplate } from '@sqldoc/ns-codegen'
|
|
2
|
+
import { activeTables, enrichRealm } from '../helpers/enrich.ts'
|
|
3
|
+
import { toPascalCase, toScreamingSnake } from '../helpers/naming.ts'
|
|
4
|
+
import { pgToPython } from '../types/pg-to-python.ts'
|
|
5
|
+
|
|
6
|
+
/** Map PostgreSQL types to SQLAlchemy Column types */
const PG_TO_SA: Record<string, string> = {
  // Integers — serial aliases map to their backing integer type.
  smallint: 'SmallInteger',
  int2: 'SmallInteger',
  integer: 'Integer',
  int: 'Integer',
  int4: 'Integer',
  bigint: 'BigInteger',
  int8: 'BigInteger',
  serial: 'Integer',
  serial4: 'Integer',
  bigserial: 'BigInteger',
  serial8: 'BigInteger',
  smallserial: 'SmallInteger',
  serial2: 'SmallInteger',
  // Floating point and exact numerics (money is treated as Numeric).
  real: 'Float',
  float4: 'Float',
  'double precision': 'Float',
  float8: 'Float',
  numeric: 'Numeric',
  decimal: 'Numeric',
  money: 'Numeric',
  // Text-like types.
  text: 'Text',
  varchar: 'String',
  'character varying': 'String',
  char: 'String',
  character: 'String',
  name: 'String',
  citext: 'Text',
  // Booleans.
  boolean: 'Boolean',
  bool: 'Boolean',
  // Date/time — tz-awareness is not distinguished here.
  timestamp: 'DateTime',
  'timestamp without time zone': 'DateTime',
  timestamptz: 'DateTime',
  'timestamp with time zone': 'DateTime',
  date: 'Date',
  time: 'Time',
  'time without time zone': 'Time',
  timetz: 'Time',
  'time with time zone': 'Time',
  interval: 'Interval',
  // Binary and JSON.
  bytea: 'LargeBinary',
  json: 'JSON',
  jsonb: 'JSON',
  // Identifier/network types are stored as plain strings.
  uuid: 'String',
  inet: 'String',
  cidr: 'String',
  macaddr: 'String',
  macaddr8: 'String',
}
|
|
56
|
+
|
|
57
|
+
/** Python imports needed for specific types (used by composite dataclasses) */
const TYPE_IMPORTS: Record<string, string> = {
  datetime: 'from datetime import datetime',
  date: 'from datetime import date',
  time: 'from datetime import time',
  timedelta: 'from datetime import timedelta',
  Decimal: 'from decimal import Decimal',
  UUID: 'from uuid import UUID',
  Any: 'from typing import Any',
}

/** Attribute names reserved by SQLAlchemy's declarative base */
// Columns with these names are emitted with a trailing underscore so they do
// not shadow the declarative machinery.
const SA_RESERVED = new Set(['metadata', 'registry', 'query', 'query_class'])
|
|
70
|
+
|
|
71
|
+
export default defineTemplate({
|
|
72
|
+
name: 'SQLAlchemy Models',
|
|
73
|
+
description: 'Generate SQLAlchemy ORM model classes with Column definitions from SQL schema',
|
|
74
|
+
language: 'python',
|
|
75
|
+
|
|
76
|
+
generate(ctx) {
|
|
77
|
+
const schema = enrichRealm(ctx)
|
|
78
|
+
const saTypes = new Set<string>()
|
|
79
|
+
const needsForeignKey = { value: false }
|
|
80
|
+
const needsARRAY = { value: false }
|
|
81
|
+
const needsEnum = { value: false }
|
|
82
|
+
const needsDataclass = { value: false }
|
|
83
|
+
const extraImports = new Set<string>()
|
|
84
|
+
const modelBlocks: string[] = []
|
|
85
|
+
const enumBlocks: string[] = []
|
|
86
|
+
const compositeBlocks: string[] = []
|
|
87
|
+
|
|
88
|
+
// Enums
|
|
89
|
+
for (const e of schema.enums) {
|
|
90
|
+
needsEnum.value = true
|
|
91
|
+
const className = toPascalCase(e.name)
|
|
92
|
+
const members = e.values.map((v) => ` ${toScreamingSnake(v)} = "${v}"`).join('\n')
|
|
93
|
+
enumBlocks.push(`class ${className}(str, enum.Enum):\n${members}`)
|
|
94
|
+
}
|
|
95
|
+
|
|
96
|
+
// Composite types (collected from columns, rendered as dataclasses)
|
|
97
|
+
const composites = new Map<string, Array<{ name: string; type: string }>>()
|
|
98
|
+
for (const table of schema.tables) {
|
|
99
|
+
for (const col of table.columns) {
|
|
100
|
+
if (col.category === 'composite' && col.compositeFields?.length && !composites.has(col.pgType)) {
|
|
101
|
+
composites.set(col.pgType, col.compositeFields)
|
|
102
|
+
}
|
|
103
|
+
}
|
|
104
|
+
}
|
|
105
|
+
for (const [name, fields] of composites) {
|
|
106
|
+
needsDataclass.value = true
|
|
107
|
+
const typeName = toPascalCase(name)
|
|
108
|
+
const fieldLines: string[] = []
|
|
109
|
+
for (const f of fields) {
|
|
110
|
+
const pyType = pgToPython(f.type, false)
|
|
111
|
+
collectTypeImports(pyType, extraImports)
|
|
112
|
+
fieldLines.push(` ${f.name}: ${pyType}`)
|
|
113
|
+
}
|
|
114
|
+
compositeBlocks.push(`@dataclass\nclass ${typeName}:\n${fieldLines.join('\n')}`)
|
|
115
|
+
}
|
|
116
|
+
|
|
117
|
+
for (const table of activeTables(schema)) {
|
|
118
|
+
const lines: string[] = []
|
|
119
|
+
lines.push(`class ${table.pascalName}(Base):`)
|
|
120
|
+
lines.push(` __tablename__ = '${table.name}'`)
|
|
121
|
+
lines.push('')
|
|
122
|
+
|
|
123
|
+
for (const col of table.columns) {
|
|
124
|
+
let saType: string
|
|
125
|
+
|
|
126
|
+
if (col.typeOverride) {
|
|
127
|
+
saType = col.typeOverride
|
|
128
|
+
} else if (col.category === 'enum' && col.enumValues?.length) {
|
|
129
|
+
saType = `Enum(${toPascalCase(col.pgType)})`
|
|
130
|
+
saTypes.add('Enum')
|
|
131
|
+
} else {
|
|
132
|
+
saType = mapToSAType(col.pgType, saTypes, needsARRAY)
|
|
133
|
+
}
|
|
134
|
+
|
|
135
|
+
const columnArgs: string[] = [saType]
|
|
136
|
+
|
|
137
|
+
// Foreign key
|
|
138
|
+
if (col.foreignKey) {
|
|
139
|
+
columnArgs.push(`ForeignKey('${col.foreignKey.table}.${col.foreignKey.column}')`)
|
|
140
|
+
needsForeignKey.value = true
|
|
141
|
+
}
|
|
142
|
+
|
|
143
|
+
// Column kwargs
|
|
144
|
+
const kwargs: string[] = []
|
|
145
|
+
if (col.isPrimaryKey) kwargs.push('primary_key=True')
|
|
146
|
+
if (!col.nullable && !col.isPrimaryKey) kwargs.push('nullable=False')
|
|
147
|
+
|
|
148
|
+
const allArgs = [...columnArgs, ...kwargs].join(', ')
|
|
149
|
+
const attr = SA_RESERVED.has(col.name) ? `${col.name}_` : col.name
|
|
150
|
+
const colName = attr !== col.name ? `'${col.name}', ` : ''
|
|
151
|
+
lines.push(` ${attr} = Column(${colName}${allArgs})`)
|
|
152
|
+
}
|
|
153
|
+
|
|
154
|
+
modelBlocks.push(lines.join('\n'))
|
|
155
|
+
}
|
|
156
|
+
|
|
157
|
+
// Views (read-only, plain Table objects — no ORM PK requirement)
|
|
158
|
+
const needsTable = { value: false }
|
|
159
|
+
const viewBlocks: string[] = []
|
|
160
|
+
for (const view of schema.views.filter((v) => !v.skipped)) {
|
|
161
|
+
needsTable.value = true
|
|
162
|
+
const colDefs: string[] = []
|
|
163
|
+
for (const col of view.columns) {
|
|
164
|
+
let saType: string
|
|
165
|
+
if (col.typeOverride) {
|
|
166
|
+
saType = col.typeOverride
|
|
167
|
+
} else if (col.category === 'enum' && col.enumValues?.length) {
|
|
168
|
+
saType = `Enum(${toPascalCase(col.pgType)})`
|
|
169
|
+
saTypes.add('Enum')
|
|
170
|
+
} else {
|
|
171
|
+
saType = mapToSAType(col.pgType, saTypes, needsARRAY)
|
|
172
|
+
}
|
|
173
|
+
colDefs.push(` Column('${col.name}', ${saType}),`)
|
|
174
|
+
}
|
|
175
|
+
viewBlocks.push(`${view.name} = Table(\n '${view.name}',\n Base.metadata,\n${colDefs.join('\n')}\n)`)
|
|
176
|
+
}
|
|
177
|
+
|
|
178
|
+
// Collect needed SA type imports
|
|
179
|
+
const saImportTypes = ['Column', ...saTypes].sort()
|
|
180
|
+
if (needsForeignKey.value) saImportTypes.push('ForeignKey')
|
|
181
|
+
if (needsARRAY.value) saImportTypes.push('ARRAY')
|
|
182
|
+
if (needsTable.value) saImportTypes.push('Table')
|
|
183
|
+
|
|
184
|
+
const importLines: string[] = [
|
|
185
|
+
`from sqlalchemy import ${saImportTypes.join(', ')}`,
|
|
186
|
+
'from sqlalchemy.orm import declarative_base',
|
|
187
|
+
]
|
|
188
|
+
if (needsDataclass.value) {
|
|
189
|
+
importLines.push('from dataclasses import dataclass')
|
|
190
|
+
}
|
|
191
|
+
if (needsEnum.value) {
|
|
192
|
+
importLines.push('import enum')
|
|
193
|
+
}
|
|
194
|
+
for (const imp of [...extraImports].sort()) {
|
|
195
|
+
importLines.push(imp)
|
|
196
|
+
}
|
|
197
|
+
|
|
198
|
+
const preModelBlocks = [...enumBlocks, ...compositeBlocks]
|
|
199
|
+
|
|
200
|
+
const content = `# Generated by @sqldoc/templates/sqlalchemy -- DO NOT EDIT
|
|
201
|
+
|
|
202
|
+
${importLines.join('\n')}
|
|
203
|
+
|
|
204
|
+
Base = declarative_base()
|
|
205
|
+
|
|
206
|
+
${preModelBlocks.length > 0 ? `\n${preModelBlocks.join('\n\n\n')}\n\n` : ''}
|
|
207
|
+
${modelBlocks.join('\n\n\n')}
|
|
208
|
+
${viewBlocks.length > 0 ? `\n\n# Views\n${viewBlocks.join('\n\n\n')}\n` : ''}
|
|
209
|
+
`
|
|
210
|
+
|
|
211
|
+
return {
|
|
212
|
+
files: [{ path: 'models.py', content }],
|
|
213
|
+
}
|
|
214
|
+
},
|
|
215
|
+
})
|
|
216
|
+
|
|
217
|
+
function collectTypeImports(pyType: string, imports: Set<string>): void {
|
|
218
|
+
const match = pyType.match(/^Optional\[(.+)\]$/)
|
|
219
|
+
const baseType = match ? match[1] : pyType
|
|
220
|
+
const listMatch = baseType.match(/^list\[(.+)\]$/)
|
|
221
|
+
const innerType = listMatch ? listMatch[1] : baseType
|
|
222
|
+
|
|
223
|
+
if (TYPE_IMPORTS[innerType]) {
|
|
224
|
+
imports.add(TYPE_IMPORTS[innerType])
|
|
225
|
+
}
|
|
226
|
+
}
|
|
227
|
+
|
|
228
|
+
function mapToSAType(pgType: string, saTypes: Set<string>, needsARRAY: { value: boolean }): string {
|
|
229
|
+
const normalized = pgType.toLowerCase().trim()
|
|
230
|
+
|
|
231
|
+
// Handle arrays
|
|
232
|
+
if (normalized.endsWith('[]') || normalized.startsWith('_')) {
|
|
233
|
+
const base = normalized.endsWith('[]') ? normalized.slice(0, -2) : normalized.slice(1)
|
|
234
|
+
const innerType = mapToSAType(base, saTypes, needsARRAY)
|
|
235
|
+
needsARRAY.value = true
|
|
236
|
+
return `ARRAY(${innerType})`
|
|
237
|
+
}
|
|
238
|
+
|
|
239
|
+
// Handle varchar(n)
|
|
240
|
+
const lengthMatch = normalized.match(/^(?:varchar|character varying|char|character)\((\d+)\)$/)
|
|
241
|
+
if (lengthMatch) {
|
|
242
|
+
saTypes.add('String')
|
|
243
|
+
return `String(${lengthMatch[1]})`
|
|
244
|
+
}
|
|
245
|
+
|
|
246
|
+
// Handle numeric(p,s)
|
|
247
|
+
const numericMatch = normalized.match(/^(?:numeric|decimal)\((\d+),\s*(\d+)\)$/)
|
|
248
|
+
if (numericMatch) {
|
|
249
|
+
saTypes.add('Numeric')
|
|
250
|
+
return `Numeric(${numericMatch[1]}, ${numericMatch[2]})`
|
|
251
|
+
}
|
|
252
|
+
|
|
253
|
+
// Strip remaining length specifiers for lookup
|
|
254
|
+
const baseType = normalized.replace(/\(\d+(?:,\s*\d+)?\)/, '').trim()
|
|
255
|
+
const saType = PG_TO_SA[baseType] ?? 'String'
|
|
256
|
+
saTypes.add(saType)
|
|
257
|
+
return saType
|
|
258
|
+
}
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Integration test for @sqldoc/templates/sqlalchemy
|
|
3
|
+
Connects to real Postgres, verifies generated SQLAlchemy models work with actual data.
|
|
4
|
+
"""
|
|
5
|
+
import os
|
|
6
|
+
import sys
|
|
7
|
+
|
|
8
|
+
from sqlalchemy import create_engine, select
|
|
9
|
+
from sqlalchemy.orm import Session
|
|
10
|
+
|
|
11
|
+
import models
|
|
12
|
+
|
|
13
|
+
DATABASE_URL = os.environ.get("DATABASE_URL")
|
|
14
|
+
if not DATABASE_URL:
|
|
15
|
+
print("DATABASE_URL not set", file=sys.stderr)
|
|
16
|
+
sys.exit(1)
|
|
17
|
+
|
|
18
|
+
failed = 0
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def assert_eq(actual, expected, msg):
    """Compare two values; print the outcome and, on mismatch, report to
    stderr and bump the module-level failure counter instead of raising."""
    global failed
    if actual == expected:
        print(f" ok: {msg}")
    else:
        print(f"FAIL: {msg} (got {actual!r}, expected {expected!r})", file=sys.stderr)
        failed += 1
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def main():
    """Run assertions against a live database; exit non-zero on any failure."""
    global failed
    engine = create_engine(DATABASE_URL)

    print("--- sqlalchemy integration test ---")

    with Session(engine) as session:
        # 1. Query known seeded user via ORM model
        # Expected values below must match the harness seed data.
        user = session.query(models.Users).filter_by(id=1).one()
        assert_eq(user.email, "test@example.com", "user email matches")
        assert_eq(user.name, "Test User", "user name matches")
        assert_eq(user.age, 30, "user age matches")
        assert_eq(user.is_active, True, "user is_active matches")

        # 2. Query known seeded post via ORM model
        post = session.query(models.Posts).filter_by(id=1).one()
        assert_eq(post.title, "Hello World", "post title matches")

        # 3. Query view via Table object
        # NOTE(review): assumes the seeded user is active so the view has at
        # least one row, with the seeded user first — confirm against seed SQL.
        rows = session.execute(select(models.active_users)).fetchall()
        assert_eq(len(rows) >= 1, True, "active_users view returns rows")
        row = rows[0]
        assert_eq(row.email, "test@example.com", "view email matches")

    if failed > 0:
        print(f"\n{failed} assertion(s) failed", file=sys.stderr)
        sys.exit(1)
    print("\nAll assertions passed!")


if __name__ == "__main__":
    main()
|