neopg 2.0.1 → 2.0.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/lib/ModelChain.js CHANGED
@@ -2,9 +2,8 @@
2
2
 
3
3
  const makeId = require('./makeId.js')
4
4
  const makeTimestamp = require('./makeTimestamp.js')
5
- const TransactionScope = require('./TransactionScope.js')
6
5
 
7
- // [优化 1] 提取常量定义到类外部,提升性能
6
+ // 提取常量定义
8
7
  const INT_TYPES = new Set([
9
8
  'int', 'integer', 'smallint', 'bigint',
10
9
  'serial', 'bigserial', 'smallserial',
@@ -16,16 +15,7 @@ const FLOAT_TYPES = new Set([
16
15
  'money', 'double precision', 'float4', 'float8'
17
16
  ])
18
17
 
19
- /**
20
- * ModelChain - 链式查询构建器
21
- * 负责运行时的查询构建、SQL 拼装和结果处理。
22
- */
23
18
  class ModelChain {
24
- /**
25
- * @param {Object} ctx - 上下文 (NeoPG 实例 或 TransactionScope 实例)
26
- * @param {ModelDef} def - 模型元数据
27
- * @param {string} schema - 数据库 schema
28
- */
29
19
  constructor(ctx, def, schema = 'public') {
30
20
  this.ctx = ctx
31
21
  this.def = def
@@ -34,102 +24,103 @@ class ModelChain {
34
24
  this.tableName = def.tableName
35
25
  this.schema = schema
36
26
 
37
- // --- 查询状态 (AST Lite) ---
27
+ // --- 查询状态 ---
38
28
  this._conditions = []
39
29
  this._order = []
40
30
  this._limit = null
41
31
  this._offset = null
42
32
  this._columns = null
43
- this._group = null
44
- this._lock = null
45
33
  this._returning = null
46
34
  this._joins = []
47
35
  this._group = []
36
+ this._lock = null
48
37
 
49
- // 内部状态标记
50
38
  this._isRaw = !!def.isRaw
39
+ this._executed = false
51
40
  }
52
41
 
53
- // --- 静态构造器 ---
54
42
  static from(schemaObject) {
55
43
  return class AnonymousModel extends ModelChain {
56
44
  static schema = schemaObject
57
45
  }
58
46
  }
59
47
 
60
- // --- 核心:链式调用 API ---
48
+ // --- 内部状态管理 ---
61
49
 
62
- /**
63
- * 添加 WHERE 条件
64
- * 1. .where(sql`age > ${age}`) -> 原生片段 (Query)
65
- * 2. .where({ a: 1 }) -> a=1
66
- * 3. .where('age', '>', 18) -> age > 18 (兼容)
67
- */
68
- where(arg1, arg2, arg3) {
69
- if (!arg1) return this
70
- // 1. Fragment 检测
71
- if (arg1.constructor && arg1.constructor.name === 'Query') {
72
- this._conditions.push(arg1)
73
- return this
74
- }
75
-
76
- // 2. Object 写法
77
- if (typeof arg1 === 'object' && !Array.isArray(arg1)) {
78
- for (const k of Object.keys(arg1)) {
79
- const v = arg1[k]
80
-
81
- if (v === undefined) continue
50
+ _ensureActive() {
51
+ if (this._executed) {
52
+ throw new Error(
53
+ `[NeoPG] ModelChain for '${this.tableName}' has already been executed. ` +
54
+ `Do NOT reuse the chain variable. Use .clone() if you need to fork queries.`
55
+ )
56
+ }
57
+ }
82
58
 
83
- if (v === null) this._conditions.push(this.sql`${this.sql(k)} IS NULL`)
84
- else if (Array.isArray(v)) this._conditions.push(this.sql`${this.sql(k)} IN ${this.sql(v)}`)
85
- else this._conditions.push(this.sql`${this.sql(k)} = ${v}`)
86
- }
87
-
88
- return this
89
- }
59
+ _destroy() {
60
+ this._executed = true
61
+ this.def = null
62
+ this.ctx = null
63
+ // this.sql = null // 可选:保留引用以便 debug,或者释放
64
+ }
90
65
 
91
- // 3. String 写法
92
- if (typeof arg1 === 'string') {
93
- // .where('age', '>', 18)
94
- if (arg3 !== undefined) {
95
- this._conditions.push(this.sql`${this.sql(arg1)} ${this.sql.unsafe(arg2)} ${arg3}`)
96
- return this
97
- }
98
- // .where('age', 18) -> age = 18
99
- if (arg2 !== undefined) {
100
- this._conditions.push(this.sql`${this.sql(arg1)} = ${arg2}`)
101
- return this
102
- }
103
- // .where('id = ?', 123)
104
- if (arg1.includes('?') && arg2 !== undefined) {
105
- const p = arg1.split('?');
106
- if(p.length===2) {
107
- this._conditions.push(this.sql`${this.sql.unsafe(p[0])}${arg2}${this.sql.unsafe(p[1])}`)
108
- return this
109
- }
110
- }
111
- // .where('1=1') -> Raw SQL
112
- // 注意:这里必须用 unsafe,否则 '1=1' 会被当成字符串值处理
113
- this._conditions.push(this.sql.unsafe(arg1))
114
- }
66
+ clone() {
67
+ this._ensureActive()
68
+ const copy = new ModelChain(this.ctx, this.def, this.schema)
69
+
70
+ // 拷贝状态
71
+ copy._conditions = [...this._conditions]
72
+ copy._joins = [...this._joins]
73
+ copy._group = [...this._group]
74
+ copy._order = [...this._order]
75
+
76
+ copy._limit = this._limit
77
+ copy._offset = this._offset
78
+ copy._lock = this._lock
79
+ if (this._columns) copy._columns = [...this._columns]
80
+ if (this._returning) copy._returning = [...this._returning]
115
81
 
116
- return this
82
+ return copy
117
83
  }
118
84
 
119
- whereIf(condition, arg1, arg2, arg3) {
120
- if (condition) return this.where(arg1, arg2, arg3)
121
- return this
122
- }
85
+ // --- 构建方法 (无检测,高性能) ---
123
86
 
124
87
  select(columns) {
125
88
  if (!columns) return this
126
-
127
89
  if (typeof columns === 'string') {
128
90
  this._columns = columns.split(',').map(s => s.trim())
129
91
  } else if (Array.isArray(columns)) {
130
92
  this._columns = columns
131
93
  }
94
+ return this
95
+ }
132
96
 
97
+ limit(count, offset = 0) {
98
+ this._limit = count
99
+ this._offset = offset
100
+ return this
101
+ }
102
+
103
+ page(pageIndex, pageSize) {
104
+ return this.limit(pageSize, (pageIndex - 1) * pageSize)
105
+ }
106
+
107
+ forUpdate() {
108
+ this._lock = this.sql`FOR UPDATE`
109
+ return this
110
+ }
111
+
112
+ forShare() {
113
+ this._lock = this.sql`FOR SHARE`
114
+ return this
115
+ }
116
+
117
+ returning(cols) {
118
+ if (!cols) return this
119
+ if (typeof cols === 'string') {
120
+ this._returning = cols.split(',').map(s => s.trim()).filter(s => s)
121
+ } else if (Array.isArray(cols)) {
122
+ this._returning = cols
123
+ }
133
124
  return this
134
125
  }
135
126
 
@@ -145,7 +136,6 @@ class ModelChain {
145
136
  for(const k in a) {
146
137
  this._order.push(this.sql`${this.sql(k)} ${this.sql.unsafe(a[k].toUpperCase())}`)
147
138
  }
148
-
149
139
  return this
150
140
  }
151
141
 
@@ -154,104 +144,131 @@ class ModelChain {
154
144
  if(a.includes(' ')) this._order.push(this.sql.unsafe(a));
155
145
  else this._order.push(this.sql`${this.sql(a)} ${this.sql.unsafe(d)}`);
156
146
  }
157
-
158
147
  return this
159
148
  }
160
149
 
161
- limit(count, offset = 0) {
162
- this._limit = count
163
- this._offset = offset
150
+ group(arg) {
151
+ if (!arg) return this
152
+ if (arg.constructor && arg.constructor.name === 'Query') {
153
+ this._group.push(arg)
154
+ return this
155
+ }
156
+ if (Array.isArray(arg)) {
157
+ arg.forEach(f => this.group(f))
158
+ return this
159
+ }
160
+ if (typeof arg === 'string') {
161
+ if (arg.includes(',')) {
162
+ arg.split(',').map(s => s.trim()).filter(s=>s).forEach(s => {
163
+ this._group.push(this.sql(s))
164
+ })
165
+ } else {
166
+ this._group.push(this.sql(arg))
167
+ }
168
+ }
164
169
  return this
165
170
  }
166
171
 
167
- page(pageIndex, pageSize) {
168
- return this.limit(pageSize, (pageIndex - 1) * pageSize)
169
- }
172
+ // --- 特殊构建:Where (需要检测) ---
170
173
 
171
- forUpdate() {
172
- this._lock = this.sql`FOR UPDATE`
173
- return this
174
- }
174
+ where(arg1, arg2, arg3) {
175
+ // ⚠️ 只有 where 需要提前检测,防止条件污染
176
+ this._ensureActive()
175
177
 
176
- forShare() {
177
- this._lock = this.sql`FOR SHARE`
178
- return this
179
- }
178
+ if (!arg1) return this
180
179
 
181
- returning(cols) {
182
- if (!cols) return this
180
+ if (arg1.constructor && arg1.constructor.name === 'Query') {
181
+ this._conditions.push(arg1)
182
+ return this
183
+ }
184
+
185
+ if (typeof arg1 === 'object' && !Array.isArray(arg1)) {
186
+ for (const k of Object.keys(arg1)) {
187
+ const v = arg1[k]
188
+ if (v === undefined) continue
189
+ if (v === null) this._conditions.push(this.sql`${this.sql(k)} IS NULL`)
190
+ else if (Array.isArray(v)) this._conditions.push(this.sql`${this.sql(k)} IN ${this.sql(v)}`)
191
+ else this._conditions.push(this.sql`${this.sql(k)} = ${v}`)
192
+ }
193
+ return this
194
+ }
183
195
 
184
- if (typeof cols === 'string') {
185
- // 支持 'id, name' 写法
186
- this._returning = cols.split(',').map(s => s.trim()).filter(s => s)
187
- } else if (Array.isArray(cols)) {
188
- this._returning = cols
189
- }
196
+ if (typeof arg1 === 'string') {
197
+ if (arg3 !== undefined) {
198
+ this._conditions.push(this.sql`${this.sql(arg1)} ${this.sql.unsafe(arg2)} ${arg3}`)
199
+ return this
200
+ }
201
+ if (arg2 !== undefined) {
202
+ this._conditions.push(this.sql`${this.sql(arg1)} = ${arg2}`)
203
+ return this
204
+ }
205
+ if (arg1.includes('?') && arg2 !== undefined) {
206
+ const p = arg1.split('?');
207
+ if(p.length===2) {
208
+ this._conditions.push(this.sql`${this.sql.unsafe(p[0])}${arg2}${this.sql.unsafe(p[1])}`)
209
+ return this
210
+ }
211
+ }
212
+ this._conditions.push(this.sql.unsafe(arg1))
213
+ }
214
+ return this
215
+ }
190
216
 
217
+ whereIf(condition, arg1, arg2, arg3) {
218
+ if (condition) return this.where(arg1, arg2, arg3)
191
219
  return this
192
220
  }
193
221
 
194
- // --- 构建 RETURNING 片段 ---
195
- _buildReturning() {
196
- // 如果没有设置 returning,默认不返回数据 (节省性能)
197
- // 注意:这意味着默认 insert/update 返回的是 Result 对象(包含 count),而不是行数据
198
- if (!this._returning || this._returning.length === 0) {
199
- return this.sql``
200
- }
201
-
202
- // 特殊处理 '*': 用户显式要求 returning('*')
203
- // 如果直接用 this.sql(['*']) 会被转义为 "*",导致错误
204
- if (this._returning.length === 1 && this._returning[0] === '*') {
205
- return this.sql`RETURNING *`
206
- }
207
-
208
- // 普通字段:利用 postgres.js 自动转义标识符
209
- return this.sql`RETURNING ${this.sql(this._returning)}`
210
- }
222
+ // --- SQL 构建辅助 (内部方法) ---
211
223
 
212
- // --- 辅助:构建 Where 片段 (修复 Bug 的核心) ---
213
224
  _buildWhere() {
214
225
  const len = this._conditions.length
215
226
  if (len === 0) return this.sql``
216
-
217
- // 只有一个条件,直接返回,零开销
218
- if (len === 1) {
219
- return this.sql`WHERE ${this._conditions[0]}`
220
- }
227
+ if (len === 1) return this.sql`WHERE ${this._conditions[0]}`
221
228
 
222
- // 预分配数组:N个条件需要 N-1 个 'AND',总长 2N-1
223
- // 使用 new Array 预分配内存,比 push 更快
224
229
  const parts = new Array(len * 2 - 1)
225
230
  const AND = this.sql.unsafe(' AND ')
226
-
227
231
  for (let i = 0; i < len; i++) {
228
- // 偶数位放条件
229
232
  parts[i * 2] = this._conditions[i]
230
- // 奇数位放 AND (除了最后一位)
231
- if (i < len - 1) {
232
- parts[i * 2 + 1] = AND
233
- }
233
+ if (i < len - 1) parts[i * 2 + 1] = AND
234
234
  }
235
-
236
- // postgres.js 会自动展开这个扁平数组,性能极高
237
235
  return this.sql`WHERE ${parts}`
238
236
  }
239
237
 
240
- // --- 辅助:构建 Order 片段 (修复 Bug) ---
238
+ _buildReturning() {
239
+ if (!this._returning || this._returning.length === 0) return this.sql``
240
+ if (this._returning.length === 1 && this._returning[0] === '*') return this.sql`RETURNING *`
241
+ return this.sql`RETURNING ${this.sql(this._returning)}`
242
+ }
243
+
241
244
  _buildOrder() {
242
245
  if (this._order.length === 0) return this.sql``
243
- // 数组直接传入模板,postgres.js 默认用逗号连接,这正是 ORDER BY 需要的
244
- // 不能用 this.sql(this._order),那样会试图转义为标识符
245
246
  return this.sql`ORDER BY ${this._order}`
246
247
  }
247
248
 
248
- // --- 核心:执行方法 ---
249
+ _buildJoins() {
250
+ const len = this._joins.length
251
+ if (len === 0) return this.sql``
252
+ if (len === 1) return this._joins[0]
249
253
 
250
- async find() {
254
+ const parts = new Array(len * 2 - 1)
255
+ const SPACE = this.sql.unsafe(' ')
256
+ for (let i = 0; i < len; i++) {
257
+ parts[i * 2] = this._joins[i]
258
+ if (i < len - 1) parts[i * 2 + 1] = SPACE
259
+ }
260
+ return this.sql`${parts}`
261
+ }
262
+
263
+ _buildGroup() {
264
+ if (this._group.length === 0) return this.sql``
265
+ return this.sql`GROUP BY ${this._group}`
266
+ }
267
+
268
+ _buildSelectQuery() {
251
269
  const t = this.sql(this.tableName)
252
270
  const c = this._columns ? this.sql(this._columns) : this.sql`*`
253
271
 
254
- // 修复:使用新方法构建
255
272
  const w = this._buildWhere()
256
273
  const o = this._buildOrder()
257
274
  const j = this._buildJoins()
@@ -262,343 +279,244 @@ class ModelChain {
262
279
  const lck = this._lock || this.sql``
263
280
  const ft = this.sql`${this.sql(this.schema)}.${t}`
264
281
 
265
- return await this.sql`SELECT ${c} FROM ${ft} ${j} ${w} ${g} ${o} ${l} ${off} ${lck}`
282
+ return this.sql`SELECT ${c} FROM ${ft} ${j} ${w} ${g} ${o} ${l} ${off} ${lck}`
283
+ }
284
+
285
+ // --- 执行动作 (Executors) ---
286
+ // 必须检测 _ensureActive 并在 finally 中销毁
287
+
288
+ async find() {
289
+ this._ensureActive()
290
+ try {
291
+ return await this._buildSelectQuery()
292
+ } finally {
293
+ this._destroy()
294
+ }
266
295
  }
267
296
 
268
297
  async get() {
298
+ // get 依赖 find,但因为 get 修改了 limit 状态,虽然 limit 方法没检测,
299
+ // 但最终调用的 find 会检测。为了保险起见,这里也可以不加 try/finally,
300
+ // 让 find 去处理销毁。
269
301
  this.limit(1)
270
- const rows = await this.find()
302
+ const rows = await this.find() // find 会负责 destroy
271
303
  return rows.length > 0 ? rows[0] : null
272
304
  }
273
305
 
274
- async count() {
275
- const t = this.sql(this.tableName)
276
-
277
- const w = this._buildWhere()
278
- const ft = this.sql`${this.sql(this.schema)}.${t}`
279
- const j = this._buildJoins()
280
-
281
- const r = await this.sql`SELECT count(*) as total FROM ${ft} ${j} ${w}`
282
-
283
- if (r.length === 0) return 0
284
-
285
- return parseInt(r[0].total)
286
- }
287
-
288
- async insert(data) {
289
- const isArray = Array.isArray(data)
290
- const inputs = isArray ? data : [data]
291
- if (inputs.length === 0) throw new Error('[NeoPG] Insert data cannot be empty')
292
-
293
- if (this.def) {
294
- this._prepareDataForInsert(inputs)
295
- }
296
-
297
- const fullTable = this.sql`${this.sql(this.schema)}.${this.sql(this.tableName)}`
306
+ async findAndCount() {
307
+ this._ensureActive()
308
+ try {
309
+ // 1. 数据查询
310
+ const dataQuery = this._buildSelectQuery()
298
311
 
299
- // [修改] 动态构建 returning
300
- const retFragment = this._buildReturning()
312
+ // 2. 总数查询
313
+ const t = this.sql(this.tableName)
314
+ const w = this._buildWhere()
315
+ const j = this._buildJoins()
316
+ const g = this._buildGroup()
317
+ const ft = this.sql`${this.sql(this.schema)}.${t}`
318
+
319
+ let countPromise
301
320
 
302
- const result = await this.sql`INSERT INTO ${fullTable} ${this.sql(inputs)} ${retFragment}`
321
+ if (this._group.length > 0) {
322
+ countPromise = this.sql`SELECT count(*) as total FROM (SELECT 1 FROM ${ft} ${j} ${w} ${g}) as temp`
323
+ } else {
324
+ countPromise = this.sql`SELECT count(*) as total FROM ${ft} ${j} ${w}`
325
+ }
303
326
 
304
- // 如果有 returning 数据,result 是数组(包含行);否则 result 是 Result 对象(包含 count)
305
- // 逻辑保持兼容:如果用户请求了数据,且是单条插入,返回对象;否则返回数组
306
- if (this._returning && this._returning.length > 0) {
307
- if (!isArray && result.length === 1) {
308
- return result[0]
309
- }
327
+ const [data, countResult] = await Promise.all([dataQuery, countPromise])
310
328
 
311
- return result
329
+ return {
330
+ data,
331
+ total: parseInt(countResult[0]?.total || 0, 10)
332
+ }
333
+ } finally {
334
+ this._destroy()
312
335
  }
313
-
314
- // 如果没有 returning,返回 postgres 原生结果 (包含 count 等信息)
315
- // 测试发现如果没有returning则返回的是空数组
316
- return result
317
336
  }
318
337
 
319
- async update(data) {
320
- if (!data || Object.keys(data).length === 0) throw new Error('[NeoPG] Update data cannot be empty')
321
- if (this.def) { this._prepareDataForUpdate(data) }
322
-
323
- if (this._conditions.length === 0) throw new Error('[NeoPG] UPDATE requires a WHERE condition')
324
-
325
- const fullTable = this.sql`${this.sql(this.schema)}.${this.sql(this.tableName)}`
326
- // 修复:使用新方法构建
327
- const whereFragment = this._buildWhere()
328
-
329
- // [修改] 动态构建 returning
330
- const retFragment = this._buildReturning()
331
-
332
- const result = await this.sql`UPDATE ${fullTable} SET ${this.sql(data)} ${whereFragment} ${retFragment}`
333
-
334
- if (this._returning && this._returning.length > 0) {
335
- if (result.length === 1) return result[0]
336
-
337
- return result
338
- }
339
-
340
- return result
341
- }
342
-
343
- async delete() {
344
- if (this._conditions.length === 0) throw new Error('[NeoPG] DELETE requires a WHERE condition')
345
- const fullTable = this.sql`${this.sql(this.schema)}.${this.sql(this.tableName)}`
346
- // 修复:使用新方法构建
347
- const whereFragment = this._buildWhere()
348
-
349
- const retFragment = this._buildReturning()
338
+ async count() {
339
+ this._ensureActive()
340
+ try {
341
+ const t = this.sql(this.tableName)
342
+ const w = this._buildWhere()
343
+ const j = this._buildJoins()
344
+ const g = this._buildGroup()
345
+ const ft = this.sql`${this.sql(this.schema)}.${t}`
350
346
 
351
- return await this.sql`DELETE FROM ${fullTable} ${whereFragment} ${retFragment}`
352
- }
353
-
354
- async transaction(callback) {
355
- return this.ctx.transaction(callback)
356
- /* return await this.sql.begin(async (trxSql) => {
357
- const scope = new TransactionScope(this.ctx, trxSql)
358
- return await callback(scope)
359
- }) */
360
- }
361
-
362
- begin(callback) {
363
- return this.ctx.transaction(callback)
364
- }
365
-
366
- /**
367
- * 内部通用 Join 添加器
368
- * @param {string} type - 'INNER JOIN', 'LEFT JOIN', 'RIGHT JOIN'
369
- * @param {string|Object} table - 表名 或 sql`fragment`
370
- * @param {string|Object} on - 条件字符串 或 sql`fragment`
371
- */
372
- _addJoin(type, table, on) {
373
- let tableFragment
374
- let onFragment
375
-
376
- // 1. 处理 Table
377
- if (table.constructor && table.constructor.name === 'Query') {
378
- tableFragment = table
379
- } else {
380
- // 默认作为当前 Schema 下的表名处理 "public"."table"
381
- // 如果需要跨 Schema (e.g. "other.table"),请用户传入 sql`other.table`
382
- tableFragment = this.sql(table)
383
- }
347
+ let query;
348
+ if (this._group.length > 0) {
349
+ query = this.sql`SELECT count(*) as total FROM (SELECT 1 FROM ${ft} ${j} ${w} ${g}) as temp`
350
+ } else {
351
+ query = this.sql`SELECT count(*) as total FROM ${ft} ${j} ${w}`
352
+ }
384
353
 
385
- // 2. 处理 ON 条件
386
- if (on.constructor && on.constructor.name === 'Query') {
387
- onFragment = on;
388
- } else {
389
- // 字符串情况,视为 Raw SQL (e.g. "u.id = p.uid")
390
- // 因为 ON 条件通常包含操作符,无法简单参数化,必须 unsafe
391
- onFragment = this.sql.unsafe(on);
354
+ const r = await query
355
+ if (r.length === 0) return 0
356
+ return parseInt(r[0].total)
357
+ } finally {
358
+ this._destroy()
392
359
  }
393
-
394
- // 3. 构建单个 Join 片段
395
- // 格式: TYPE + table + ON + condition
396
- const joinFragment = this.sql`${this.sql.unsafe(type)} ${tableFragment} ON ${onFragment}`
397
-
398
- this._joins.push(joinFragment)
399
- return this
400
- }
401
-
402
- join(table, on) {
403
- return this._addJoin('INNER JOIN', table, on)
404
- }
405
-
406
- innerJoin(table, on) {
407
- return this._addJoin('INNER JOIN', table, on)
408
- }
409
-
410
- leftJoin(table, on) {
411
- return this._addJoin('LEFT JOIN', table, on)
412
360
  }
413
361
 
414
- rightJoin(table, on) {
415
- return this._addJoin('RIGHT JOIN', table, on)
416
- }
417
-
418
- fullJoin(table, on) {
419
- return this._addJoin('FULL OUTER JOIN', table, on)
420
- }
421
-
422
- _buildJoins() {
423
- const len = this._joins.length
424
- if (len === 0) return this.sql``
362
+ async insert(data) {
363
+ this._ensureActive()
364
+ try {
365
+ const isArray = Array.isArray(data)
366
+ const inputs = isArray ? data : [data]
367
+ if (inputs.length === 0) throw new Error('[NeoPG] Insert data cannot be empty')
368
+
369
+ if (this.def) {
370
+ this._prepareDataForInsert(inputs)
371
+ }
425
372
 
426
- // 只有一个 Join,直接返回
427
- if (len === 1) {
428
- return this._joins[0]
429
- }
373
+ const fullTable = this.sql`${this.sql(this.schema)}.${this.sql(this.tableName)}`
374
+ const retFragment = this._buildReturning()
430
375
 
431
- // 多个 Join,必须用空格连接,不能用逗号
432
- // 采用“平铺数组”高性能方案
433
- const parts = new Array(len * 2 - 1)
434
- const SPACE = this.sql.unsafe(' ')
376
+ const result = await this.sql`INSERT INTO ${fullTable} ${this.sql(inputs)} ${retFragment}`
435
377
 
436
- for (let i = 0; i < len; i++) {
437
- parts[i * 2] = this._joins[i]
438
- if (i < len - 1) {
439
- parts[i * 2 + 1] = SPACE
378
+ if (this._returning && this._returning.length > 0) {
379
+ if (!isArray && result.length === 1) return result[0]
380
+ return result
440
381
  }
382
+ return result
383
+ } finally {
384
+ this._destroy()
441
385
  }
442
-
443
- return this.sql`${parts}`
444
386
  }
445
387
 
446
- /**
447
- * 添加 Group By 条件
448
- * .group('category_id')
449
- * .group('category_id, type')
450
- * .group(['id', 'name'])
451
- */
452
- group(arg) {
453
- if (!arg) return this
454
-
455
- // 1. Fragment
456
- if (arg.constructor && arg.constructor.name === 'Query') {
457
- this._group.push(arg)
458
- return this
459
- }
460
-
461
- // 2. Array
462
- if (Array.isArray(arg)) {
463
- arg.forEach(f => this.group(f))
464
- return this
465
- }
388
+ async update(data) {
389
+ this._ensureActive()
390
+ try {
391
+ if (!data || Object.keys(data).length === 0) throw new Error('[NeoPG] Update data cannot be empty')
392
+ if (this.def) { this._prepareDataForUpdate(data) }
393
+ if (this._conditions.length === 0) throw new Error('[NeoPG] UPDATE requires a WHERE condition')
394
+
395
+ const fullTable = this.sql`${this.sql(this.schema)}.${this.sql(this.tableName)}`
396
+ const whereFragment = this._buildWhere()
397
+ const retFragment = this._buildReturning()
398
+
399
+ const result = await this.sql`UPDATE ${fullTable} SET ${this.sql(data)} ${whereFragment} ${retFragment}`
466
400
 
467
- // 3. String
468
- if (typeof arg === 'string') {
469
- if (arg.includes(',')) {
470
- // 'id, name' -> 拆分
471
- arg.split(',').map(s => s.trim()).filter(s=>s).forEach(s => {
472
- this._group.push(this.sql(s))
473
- })
474
- } else {
475
- // 单个字段
476
- this._group.push(this.sql(arg))
401
+ if (this._returning && this._returning.length > 0) {
402
+ if (result.length === 1) return result[0]
403
+ return result
477
404
  }
405
+ return result
406
+ } finally {
407
+ this._destroy()
478
408
  }
479
-
480
- return this
481
409
  }
482
410
 
483
- // 构建 Group 片段
484
- _buildGroup() {
485
- if (this._group.length === 0) return this.sql``
486
-
487
- // postgres.js 模板数组默认用逗号连接,正好符合 GROUP BY 语法
488
- return this.sql`GROUP BY ${this._group}`
489
- }
490
-
491
- // --- 聚合函数 ---
492
-
493
- async min(field) {
494
- return this._aggregate('MIN', field)
495
- }
496
-
497
- async max(field) {
498
- return this._aggregate('MAX', field)
499
- }
500
-
501
- async sum(field) {
502
- return this._aggregate('SUM', field)
411
+ async delete() {
412
+ this._ensureActive()
413
+ try {
414
+ if (this._conditions.length === 0) throw new Error('[NeoPG] DELETE requires a WHERE condition')
415
+ const fullTable = this.sql`${this.sql(this.schema)}.${this.sql(this.tableName)}`
416
+ const whereFragment = this._buildWhere()
417
+ const retFragment = this._buildReturning()
418
+
419
+ return await this.sql`DELETE FROM ${fullTable} ${whereFragment} ${retFragment}`
420
+ } finally {
421
+ this._destroy()
422
+ }
503
423
  }
504
424
 
505
- async avg(field) {
506
- return this._aggregate('AVG', field)
425
+ async transaction(callback) {
426
+ // 事务通常开启新的 Scope,不需要销毁当前 Chain
427
+ return this.ctx.transaction(callback)
507
428
  }
508
429
 
509
- /**
510
- * 通用聚合执行器
511
- * @param {string} func - MIN, MAX, SUM, AVG
512
- * @param {string} field - 列名
513
- */
514
430
  async _aggregate(func, field) {
515
- if (!field) throw new Error(`[NeoPG] ${func} requires a field name.`)
516
-
517
- const t = this.sql(this.tableName)
518
- const w = this._buildWhere()
519
- const j = this._buildJoins()
520
- const ft = this.sql`${this.sql(this.schema)}.${t}`
521
-
522
- // 处理字段名 (可能是 'age' 也可能是 'users.age')
523
- let colFragment;
524
- if (field.constructor && field.constructor.name === 'Query') {
525
- colFragment = field
526
- } else {
527
- colFragment = this.sql(field)
431
+ this._ensureActive()
432
+ try {
433
+ if (!field) throw new Error(`[NeoPG] ${func} requires a field name.`)
434
+
435
+ const t = this.sql(this.tableName)
436
+ const w = this._buildWhere()
437
+ const j = this._buildJoins()
438
+ const ft = this.sql`${this.sql(this.schema)}.${t}`
439
+
440
+ let colFragment;
441
+ if (field.constructor && field.constructor.name === 'Query') {
442
+ colFragment = field
443
+ } else {
444
+ colFragment = this.sql(field)
445
+ }
446
+
447
+ const query = this.sql`
448
+ SELECT ${this.sql.unsafe(func)}(${colFragment}) as val
449
+ FROM ${ft} ${j} ${w}
450
+ `
451
+
452
+ const result = await query
453
+ const val = result.length > 0 ? result[0].val : null
454
+ if (val === null) return null;
455
+ return this._convertAggregateValue(val, field, func)
456
+ } finally {
457
+ this._destroy()
528
458
  }
529
-
530
- // SELECT MIN(age) as val ...
531
- const query = this.sql`
532
- SELECT ${this.sql.unsafe(func)}(${colFragment}) as val
533
- FROM ${ft} ${j} ${w}
534
- `
535
-
536
- const result = await query
537
- const val = result.length > 0 ? result[0].val : null
538
-
539
- if (val === null) return null;
540
-
541
- // 智能类型转换
542
- return this._convertAggregateValue(val, field, func)
543
459
  }
544
460
 
545
- /**
546
- * 智能转换聚合结果类型
547
- * Postgres 对于 SUM/AVG/COUNT 经常返回字符串 (BigInt/Numeric),我们需要转回 Number
548
- */
549
- _convertAggregateValue(val, field, func) {
550
- // 1. AVG 始终是浮点数
551
- if (func === 'AVG') {
552
- return parseFloat(val)
553
- }
461
+ async min(field) { return this._aggregate('MIN', field) }
462
+ async max(field) { return this._aggregate('MAX', field) }
463
+ async sum(field) { return this._aggregate('SUM', field) }
464
+ async avg(field) { return this._aggregate('AVG', field) }
554
465
 
555
- // 如果是 Raw Fragment,无法推断类型,直接返回原值(通常是 String)
466
+ _convertAggregateValue(val, field, func) {
467
+ if (func === 'AVG') return parseFloat(val)
556
468
  if (typeof field !== 'string') return val
557
469
 
558
- // 2. 尝试从 ModelDef 获取列定义
559
- // field 可能是 'age' 也可能是 'u.age' (别名暂不支持自动推断,这里只处理简单列名)
560
470
  const colDef = this.def && this.def.columns ? this.def.columns[field] : null
561
-
562
- // 如果不知道列定义,尝试尽力猜测
563
471
  if (!colDef) {
564
- // 如果 val 是字符串且长得像数字
565
472
  if (typeof val === 'string' && !isNaN(val)) {
566
- // SUM 默认为数字
567
473
  if (func === 'SUM') return parseFloat(val)
568
474
  }
569
-
570
475
  return val
571
476
  }
572
477
 
573
- // 5. [优化] 精确类型匹配
574
- // 处理 'numeric(10,2)' -> 'numeric'
575
- // 处理 'integer' -> 'integer'
576
478
  const rawType = colDef.type.toLowerCase()
577
479
  const parenIndex = rawType.indexOf('(')
578
480
  const baseType = parenIndex > 0 ? rawType.substring(0, parenIndex).trim() : rawType
579
481
 
580
- // 整数匹配
581
- if (INT_TYPES.has(baseType)) {
582
- return parseInt(val, 10)
482
+ if (INT_TYPES.has(baseType)) return parseInt(val, 10)
483
+ if (FLOAT_TYPES.has(baseType)) return parseFloat(val)
484
+ return val
485
+ }
486
+
487
+ join(table, on) { return this._addJoin('INNER JOIN', table, on) }
488
+ innerJoin(table, on) { return this._addJoin('INNER JOIN', table, on) }
489
+ leftJoin(table, on) { return this._addJoin('LEFT JOIN', table, on) }
490
+ rightJoin(table, on) { return this._addJoin('RIGHT JOIN', table, on) }
491
+ fullJoin(table, on) { return this._addJoin('FULL OUTER JOIN', table, on) }
492
+
493
+ _addJoin(type, table, on) {
494
+ // 内部方法,不需要检测,提升性能
495
+ let tableFragment
496
+ let onFragment
497
+
498
+ if (table.constructor && table.constructor.name === 'Query') {
499
+ tableFragment = table
500
+ } else {
501
+ tableFragment = this.sql(table)
583
502
  }
584
503
 
585
- // 浮点数匹配
586
- if (FLOAT_TYPES.has(baseType)) {
587
- return parseFloat(val)
504
+ if (on.constructor && on.constructor.name === 'Query') {
505
+ onFragment = on;
506
+ } else {
507
+ onFragment = this.sql.unsafe(on);
588
508
  }
589
509
 
590
- // 其他 (Date, String, Boolean) 原样返回
591
- return val
510
+ this._joins.push(this.sql`${this.sql.unsafe(type)} ${tableFragment} ON ${onFragment}`)
511
+ return this
592
512
  }
593
513
 
594
- // --- 内部辅助方法 ---
595
-
514
+ // --- 数据预处理 ---
596
515
  _prepareDataForInsert(rows) {
597
516
  const pk = this.def.primaryKey
598
517
  const autoId = this.def.autoId
599
518
  const pkLen = this.def.pkLen
600
519
  const ts = this.def.timestamps
601
- const defaults = this.def.defaults
602
520
 
603
521
  let make_timestamp = ts.insert && ts.insert.length > 0
604
522
 
@@ -606,13 +524,9 @@ class ModelChain {
606
524
  if (autoId && row[pk] === undefined) {
607
525
  row[pk] = this.def.makeId(pkLen)
608
526
  }
609
-
610
527
  if (make_timestamp) {
611
- for (const t of ts.insert) {
612
- makeTimestamp(row, t)
613
- }
528
+ for (const t of ts.insert) makeTimestamp(row, t)
614
529
  }
615
-
616
530
  for (const key in row) {
617
531
  this.def.validateField(key, row[key])
618
532
  }
@@ -621,17 +535,13 @@ class ModelChain {
621
535
 
622
536
  _prepareDataForUpdate(row) {
623
537
  const ts = this.def.timestamps
624
-
625
538
  if (ts.update && ts.update.length > 0) {
626
- for (const t of ts.update) {
627
- makeTimestamp(row, t)
628
- }
539
+ for (const t of ts.update) makeTimestamp(row, t)
629
540
  }
630
-
631
541
  for (const key in row) {
632
542
  this.def.validateField(key, row[key])
633
543
  }
634
544
  }
635
545
  }
636
546
 
637
- module.exports = ModelChain
547
+ module.exports = ModelChain
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "neopg",
3
- "version": "2.0.1",
3
+ "version": "2.0.2",
4
4
  "description": "orm for postgres",
5
5
  "keywords": [
6
6
  "neopg",
package/test/test-db.js CHANGED
@@ -213,7 +213,10 @@ db.add(User);
213
213
 
214
214
  if (n > 7) {
215
215
  console.error('\x1b[7;5m随机测试:将会让事物执行失败\x1b[0m')
216
- await tx.model('User').insert({username: 'Neo'})
216
+ //await tx.model('User').insert({username: 'Neo'})
217
+ let subt = tx.model('User')
218
+ console.log('count', await subt.count())
219
+ console.log('count', await subt.count())
217
220
  }
218
221
 
219
222
  console.log('test count',