prisma-sql 1.53.0 → 1.55.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/generator.js CHANGED
@@ -54,7 +54,7 @@ var require_package = __commonJS({
54
54
  "package.json"(exports$1, module) {
55
55
  module.exports = {
56
56
  name: "prisma-sql",
57
- version: "1.53.0",
57
+ version: "1.55.0",
58
58
  description: "Convert Prisma queries to optimized SQL with type safety. 2-7x faster than Prisma Client.",
59
59
  main: "dist/index.cjs",
60
60
  module: "dist/index.js",
@@ -813,6 +813,9 @@ function assertNoControlChars(label, s) {
813
813
  );
814
814
  }
815
815
  }
816
+ function quoteRawIdent(id) {
817
+ return `"${id.replace(/"/g, '""')}"`;
818
+ }
816
819
  function isIdentCharCode(c) {
817
820
  return c >= 48 && c <= 57 || c >= 65 && c <= 90 || c >= 97 && c <= 122 || c === 95;
818
821
  }
@@ -834,33 +837,33 @@ function parseQuotedPart(input, start) {
834
837
  }
835
838
  if (!sawAny) {
836
839
  throw new Error(
837
- `tableName/tableRef has empty quoted identifier part: ${JSON.stringify(input)}`
840
+ `qualified name has empty quoted identifier part: ${JSON.stringify(input)}`
838
841
  );
839
842
  }
840
843
  return i + 1;
841
844
  }
842
845
  if (c === 10 || c === 13 || c === 0) {
843
846
  throw new Error(
844
- `tableName/tableRef contains invalid characters: ${JSON.stringify(input)}`
847
+ `qualified name contains invalid characters: ${JSON.stringify(input)}`
845
848
  );
846
849
  }
847
850
  sawAny = true;
848
851
  i++;
849
852
  }
850
853
  throw new Error(
851
- `tableName/tableRef has unterminated quoted identifier: ${JSON.stringify(input)}`
854
+ `qualified name has unterminated quoted identifier: ${JSON.stringify(input)}`
852
855
  );
853
856
  }
854
857
  function parseUnquotedPart(input, start) {
855
858
  const n = input.length;
856
859
  let i = start;
857
860
  if (i >= n) {
858
- throw new Error(`tableName/tableRef is invalid: ${JSON.stringify(input)}`);
861
+ throw new Error(`qualified name is invalid: ${JSON.stringify(input)}`);
859
862
  }
860
863
  const c0 = input.charCodeAt(i);
861
864
  if (!isIdentStartCharCode(c0)) {
862
865
  throw new Error(
863
- `tableName/tableRef must use identifiers (or quoted identifiers). Got: ${JSON.stringify(input)}`
866
+ `qualified name must use identifiers (or quoted identifiers). Got: ${JSON.stringify(input)}`
864
867
  );
865
868
  }
866
869
  i++;
@@ -869,15 +872,15 @@ function parseUnquotedPart(input, start) {
869
872
  if (c === 46) break;
870
873
  if (!isIdentCharCode(c)) {
871
874
  throw new Error(
872
- `tableName/tableRef contains invalid identifier characters: ${JSON.stringify(input)}`
875
+ `qualified name contains invalid identifier characters: ${JSON.stringify(input)}`
873
876
  );
874
877
  }
875
878
  i++;
876
879
  }
877
880
  return i;
878
881
  }
879
- function assertSafeQualifiedName(tableRef) {
880
- const raw = String(tableRef);
882
+ function assertSafeQualifiedName(input) {
883
+ const raw = String(input);
881
884
  const trimmed = raw.trim();
882
885
  if (trimmed.length === 0) {
883
886
  throw new Error("tableName/tableRef is required and cannot be empty");
@@ -951,7 +954,7 @@ function quote2(id) {
951
954
  );
952
955
  }
953
956
  if (needsQuoting(id)) {
954
- return `"${id.replace(/"/g, '""')}"`;
957
+ return quoteRawIdent(id);
955
958
  }
956
959
  return id;
957
960
  }
@@ -2590,28 +2593,24 @@ function buildOperator(expr, op, val, ctx, mode, fieldType) {
2590
2593
  function toSafeSqlIdentifier(input) {
2591
2594
  const raw = String(input);
2592
2595
  const n = raw.length;
2596
+ if (n === 0) return "_t";
2593
2597
  let out = "";
2594
2598
  for (let i = 0; i < n; i++) {
2595
2599
  const c = raw.charCodeAt(i);
2596
2600
  const isAZ = c >= 65 && c <= 90 || c >= 97 && c <= 122;
2597
2601
  const is09 = c >= 48 && c <= 57;
2598
2602
  const isUnderscore = c === 95;
2599
- if (isAZ || is09 || isUnderscore) {
2600
- out += raw[i];
2601
- } else {
2602
- out += "_";
2603
- }
2603
+ out += isAZ || is09 || isUnderscore ? raw[i] : "_";
2604
2604
  }
2605
- if (out.length === 0) out = "_t";
2606
2605
  const c0 = out.charCodeAt(0);
2607
2606
  const startsOk = c0 >= 65 && c0 <= 90 || c0 >= 97 && c0 <= 122 || c0 === 95;
2608
- if (!startsOk) out = `_${out}`;
2609
- const lowered = out.toLowerCase();
2607
+ const lowered = (startsOk ? out : `_${out}`).toLowerCase();
2610
2608
  return ALIAS_FORBIDDEN_KEYWORDS.has(lowered) ? `_${lowered}` : lowered;
2611
2609
  }
2612
2610
  function createAliasGenerator(maxAliases = 1e4) {
2613
2611
  let counter = 0;
2614
2612
  const usedAliases = /* @__PURE__ */ new Set();
2613
+ const maxLen = 63;
2615
2614
  return {
2616
2615
  next(baseName) {
2617
2616
  if (usedAliases.size >= maxAliases) {
@@ -2621,14 +2620,13 @@ function createAliasGenerator(maxAliases = 1e4) {
2621
2620
  }
2622
2621
  const base = toSafeSqlIdentifier(baseName);
2623
2622
  const suffix = `_${counter}`;
2624
- const maxLen = 63;
2625
2623
  const baseMax = Math.max(1, maxLen - suffix.length);
2626
2624
  const trimmedBase = base.length > baseMax ? base.slice(0, baseMax) : base;
2627
2625
  const alias = `${trimmedBase}${suffix}`;
2628
2626
  counter += 1;
2629
2627
  if (usedAliases.has(alias)) {
2630
2628
  throw new Error(
2631
- `CRITICAL: Duplicate alias '${alias}' at counter=${counter}. This indicates a bug in alias generation logic.`
2629
+ `CRITICAL: Duplicate alias '${alias}' at counter=${counter}.`
2632
2630
  );
2633
2631
  }
2634
2632
  usedAliases.add(alias);
@@ -2680,24 +2678,19 @@ function normalizeDynamicNameOrThrow(dynamicName, index) {
2680
2678
  }
2681
2679
  return dn;
2682
2680
  }
2683
- function assertUniqueDynamicName(dn, seen) {
2684
- if (seen.has(dn)) {
2685
- throw new Error(`CRITICAL: Duplicate dynamic param name in mappings: ${dn}`);
2686
- }
2687
- seen.add(dn);
2688
- }
2689
- function validateMappingEntry(m, expectedIndex, seenDynamic) {
2690
- assertSequentialIndex(m.index, expectedIndex);
2691
- assertExactlyOneOfDynamicOrValue(m);
2692
- if (typeof m.dynamicName === "string") {
2693
- const dn = normalizeDynamicNameOrThrow(m.dynamicName, m.index);
2694
- assertUniqueDynamicName(dn, seenDynamic);
2695
- }
2696
- }
2697
2681
  function validateMappings(mappings) {
2698
2682
  const seenDynamic = /* @__PURE__ */ new Set();
2699
2683
  for (let i = 0; i < mappings.length; i++) {
2700
- validateMappingEntry(mappings[i], i + 1, seenDynamic);
2684
+ const m = mappings[i];
2685
+ assertSequentialIndex(m.index, i + 1);
2686
+ assertExactlyOneOfDynamicOrValue(m);
2687
+ if (typeof m.dynamicName === "string") {
2688
+ const dn = normalizeDynamicNameOrThrow(m.dynamicName, m.index);
2689
+ if (seenDynamic.has(dn)) {
2690
+ throw new Error(`CRITICAL: Duplicate dynamic param name: ${dn}`);
2691
+ }
2692
+ seenDynamic.add(dn);
2693
+ }
2701
2694
  }
2702
2695
  }
2703
2696
  function validateState(params, mappings, index) {
@@ -2709,16 +2702,19 @@ function validateState(params, mappings, index) {
2709
2702
  }
2710
2703
  function createStoreInternal(startIndex, initialParams = [], initialMappings = []) {
2711
2704
  let index = startIndex;
2712
- const params = initialParams.length > 0 ? [...initialParams] : [];
2713
- const mappings = initialMappings.length > 0 ? [...initialMappings] : [];
2705
+ const params = initialParams.length > 0 ? initialParams.slice() : [];
2706
+ const mappings = initialMappings.length > 0 ? initialMappings.slice() : [];
2714
2707
  const dynamicNameToIndex = /* @__PURE__ */ new Map();
2715
- for (const m of mappings) {
2708
+ for (let i = 0; i < mappings.length; i++) {
2709
+ const m = mappings[i];
2716
2710
  if (typeof m.dynamicName === "string") {
2717
2711
  dynamicNameToIndex.set(m.dynamicName.trim(), m.index);
2718
2712
  }
2719
2713
  }
2720
2714
  let dirty = true;
2721
2715
  let cachedSnapshot = null;
2716
+ let frozenParams = null;
2717
+ let frozenMappings = null;
2722
2718
  function assertCanAdd() {
2723
2719
  if (index > MAX_PARAM_INDEX) {
2724
2720
  throw new Error(
@@ -2770,13 +2766,17 @@ function createStoreInternal(startIndex, initialParams = [], initialMappings = [
2770
2766
  }
2771
2767
  function snapshot() {
2772
2768
  if (!dirty && cachedSnapshot) return cachedSnapshot;
2769
+ if (!frozenParams) frozenParams = Object.freeze(params.slice());
2770
+ if (!frozenMappings) frozenMappings = Object.freeze(mappings.slice());
2773
2771
  const snap = {
2774
2772
  index,
2775
- params,
2776
- mappings
2773
+ params: frozenParams,
2774
+ mappings: frozenMappings
2777
2775
  };
2778
2776
  cachedSnapshot = snap;
2779
2777
  dirty = false;
2778
+ frozenParams = null;
2779
+ frozenMappings = null;
2780
2780
  return snap;
2781
2781
  }
2782
2782
  return {
@@ -2800,11 +2800,11 @@ function createParamStore(startIndex = 1) {
2800
2800
  return createStoreInternal(startIndex);
2801
2801
  }
2802
2802
  function createParamStoreFrom(existingParams, existingMappings, nextIndex) {
2803
- validateState([...existingParams], [...existingMappings], nextIndex);
2803
+ validateState(existingParams, existingMappings, nextIndex);
2804
2804
  return createStoreInternal(
2805
2805
  nextIndex,
2806
- [...existingParams],
2807
- [...existingMappings]
2806
+ existingParams.slice(),
2807
+ existingMappings.slice()
2808
2808
  );
2809
2809
  }
2810
2810
 
@@ -2986,7 +2986,7 @@ function getRelationTableReference(relModel, dialect) {
2986
2986
  dialect
2987
2987
  );
2988
2988
  }
2989
- function resolveRelationOrThrow(model, schemas, schemaByName, relName) {
2989
+ function resolveRelationOrThrow(model, schemaByName, relName) {
2990
2990
  const field = model.fields.find((f) => f.name === relName);
2991
2991
  if (!isNotNullish(field)) {
2992
2992
  throw new Error(
@@ -3040,8 +3040,9 @@ function validateOrderByForModel(model, orderBy) {
3040
3040
  throw new Error("orderBy array entries must have exactly one field");
3041
3041
  }
3042
3042
  const fieldName = String(entries[0][0]).trim();
3043
- if (fieldName.length === 0)
3043
+ if (fieldName.length === 0) {
3044
3044
  throw new Error("orderBy field name cannot be empty");
3045
+ }
3045
3046
  if (!scalarSet.has(fieldName)) {
3046
3047
  throw new Error(
3047
3048
  `orderBy references unknown or non-scalar field '${fieldName}' on model ${model.name}`
@@ -3100,8 +3101,9 @@ function extractRelationPaginationConfig(relArgs) {
3100
3101
  function maybeReverseNegativeTake(takeVal, hasOrderBy, orderByInput) {
3101
3102
  if (typeof takeVal !== "number") return { takeVal, orderByInput };
3102
3103
  if (takeVal >= 0) return { takeVal, orderByInput };
3103
- if (!hasOrderBy)
3104
+ if (!hasOrderBy) {
3104
3105
  throw new Error("Negative take requires orderBy for deterministic results");
3106
+ }
3105
3107
  return {
3106
3108
  takeVal: Math.abs(takeVal),
3107
3109
  orderByInput: reverseOrderByInput(orderByInput)
@@ -3111,9 +3113,7 @@ function finalizeOrderByForInclude(args) {
3111
3113
  if (args.hasOrderBy && isNotNullish(args.orderByInput)) {
3112
3114
  validateOrderByForModel(args.relModel, args.orderByInput);
3113
3115
  }
3114
- if (!args.hasPagination) {
3115
- return args.orderByInput;
3116
- }
3116
+ if (!args.hasPagination) return args.orderByInput;
3117
3117
  return ensureDeterministicOrderByInput({
3118
3118
  orderBy: args.hasOrderBy ? args.orderByInput : void 0,
3119
3119
  model: args.relModel,
@@ -3174,7 +3174,9 @@ function buildOrderBySql(finalOrderByInput, relAlias, dialect, relModel) {
3174
3174
  return isNotNullish(finalOrderByInput) ? buildOrderBy(finalOrderByInput, relAlias, dialect, relModel) : "";
3175
3175
  }
3176
3176
  function buildBaseSql(args) {
3177
- return `${SQL_TEMPLATES.SELECT} ${args.selectExpr} ${SQL_TEMPLATES.FROM} ${args.relTable} ${args.relAlias} ${args.joins} ${SQL_TEMPLATES.WHERE} ${args.joinPredicate}${args.whereClause}`;
3177
+ const joins = args.joins ? ` ${args.joins}` : "";
3178
+ const where = `${SQL_TEMPLATES.WHERE} ${args.joinPredicate}${args.whereClause}`;
3179
+ return `${SQL_TEMPLATES.SELECT} ${args.selectExpr} ${SQL_TEMPLATES.FROM} ${args.relTable} ${args.relAlias}${joins} ` + where;
3178
3180
  }
3179
3181
  function buildOneToOneIncludeSql(args) {
3180
3182
  const objExpr = jsonBuildObject(args.relSelect, args.ctx.dialect);
@@ -3186,9 +3188,7 @@ function buildOneToOneIncludeSql(args) {
3186
3188
  joinPredicate: args.joinPredicate,
3187
3189
  whereClause: args.whereClause
3188
3190
  });
3189
- if (args.orderBySql) {
3190
- sql += ` ${SQL_TEMPLATES.ORDER_BY} ${args.orderBySql}`;
3191
- }
3191
+ if (args.orderBySql) sql += ` ${SQL_TEMPLATES.ORDER_BY} ${args.orderBySql}`;
3192
3192
  if (isNotNullish(args.takeVal)) {
3193
3193
  return appendLimitOffset(
3194
3194
  sql,
@@ -3241,7 +3241,7 @@ function buildListIncludeSpec(args) {
3241
3241
  `include.${args.relName}`
3242
3242
  );
3243
3243
  const selectExpr = jsonAgg("row", args.ctx.dialect);
3244
- const sql = `${SQL_TEMPLATES.SELECT} ${selectExpr} ${SQL_TEMPLATES.FROM} (${base}) ${rowAlias}`;
3244
+ const sql = `${SQL_TEMPLATES.SELECT} ${selectExpr} ${SQL_TEMPLATES.FROM} (${base}) ${SQL_TEMPLATES.AS} ${rowAlias}`;
3245
3245
  return Object.freeze({ name: args.relName, sql, isOneToOne: false });
3246
3246
  }
3247
3247
  function buildSingleInclude(relName, relArgs, field, relModel, ctx) {
@@ -3339,12 +3339,7 @@ function buildIncludeSqlInternal(args, model, schemas, schemaByName, parentAlias
3339
3339
  `Query complexity limit exceeded: ${stats.totalSubqueries} subqueries generated. Maximum allowed: ${MAX_TOTAL_SUBQUERIES}. This indicates exponential include nesting. Stats: depth=${stats.maxDepth}, includes=${stats.totalIncludes}. Path: ${visitPath.join(" -> ")}. Simplify your include structure or split into multiple queries.`
3340
3340
  );
3341
3341
  }
3342
- const resolved = resolveRelationOrThrow(
3343
- model,
3344
- schemas,
3345
- schemaByName,
3346
- relName
3347
- );
3342
+ const resolved = resolveRelationOrThrow(model, schemaByName, relName);
3348
3343
  const relationPath = `${model.name}.${relName}`;
3349
3344
  const currentPath = [...visitPath, relationPath];
3350
3345
  if (visitPath.includes(relationPath)) {
@@ -3400,7 +3395,7 @@ function buildIncludeSql(args, model, schemas, parentAlias, params, dialect) {
3400
3395
  stats
3401
3396
  );
3402
3397
  }
3403
- function resolveCountRelationOrThrow(relName, model, schemas, schemaByName) {
3398
+ function resolveCountRelationOrThrow(relName, model, schemaByName) {
3404
3399
  const relationSet = getRelationFieldSet(model);
3405
3400
  if (!relationSet.has(relName)) {
3406
3401
  throw new Error(
@@ -3408,10 +3403,11 @@ function resolveCountRelationOrThrow(relName, model, schemas, schemaByName) {
3408
3403
  );
3409
3404
  }
3410
3405
  const field = model.fields.find((f) => f.name === relName);
3411
- if (!field)
3406
+ if (!field) {
3412
3407
  throw new Error(
3413
3408
  `_count.${relName} references unknown relation on model ${model.name}`
3414
3409
  );
3410
+ }
3415
3411
  if (!isValidRelationField(field)) {
3416
3412
  throw new Error(
3417
3413
  `_count.${relName} has invalid relation metadata on model ${model.name}`
@@ -3439,8 +3435,9 @@ function defaultReferencesForCount(fkCount) {
3439
3435
  }
3440
3436
  function resolveCountKeyPairs(field) {
3441
3437
  const fkFields = normalizeKeyList(field.foreignKey);
3442
- if (fkFields.length === 0)
3438
+ if (fkFields.length === 0) {
3443
3439
  throw new Error("Relation count requires foreignKey");
3440
+ }
3444
3441
  const refsRaw = field.references;
3445
3442
  const refs = normalizeKeyList(refsRaw);
3446
3443
  const refFields = refs.length > 0 ? refs : defaultReferencesForCount(fkFields.length);
@@ -3516,12 +3513,7 @@ function buildRelationCountSql(countSelect, model, schemas, parentAlias, _params
3516
3513
  for (const m of schemas) schemaByName.set(m.name, m);
3517
3514
  for (const [relName, shouldCount] of Object.entries(countSelect)) {
3518
3515
  if (!shouldCount) continue;
3519
- const resolved = resolveCountRelationOrThrow(
3520
- relName,
3521
- model,
3522
- schemas,
3523
- schemaByName
3524
- );
3516
+ const resolved = resolveCountRelationOrThrow(relName, model, schemaByName);
3525
3517
  const built = buildCountJoinAndPair({
3526
3518
  relName,
3527
3519
  field: resolved.field,
@@ -4951,12 +4943,81 @@ function generateCode(models, queries, dialect, datamodel) {
4951
4943
  }));
4952
4944
  const { mappings, fieldTypes } = extractEnumMappings(datamodel);
4953
4945
  return `// Generated by @prisma-sql/generator - DO NOT EDIT
4954
- import { buildSQL, transformQueryResults, type PrismaMethod, type Model } from 'prisma-sql'
4946
+ import { buildSQL, buildBatchSql, parseBatchResults, buildBatchCountSql, parseBatchCountResults, createTransactionExecutor, transformQueryResults, type PrismaMethod, type Model, type BatchQuery, type BatchCountQuery, type TransactionQuery, type TransactionOptions } from 'prisma-sql'
4947
+
4948
+ class DeferredQuery {
4949
+ constructor(
4950
+ public readonly model: string,
4951
+ public readonly method: PrismaMethod,
4952
+ public readonly args: any,
4953
+ ) {}
4954
+
4955
+ then(onfulfilled?: any, onrejected?: any): any {
4956
+ throw new Error(
4957
+ 'Cannot await a batch query. Batch queries must not be awaited inside the $batch callback.',
4958
+ )
4959
+ }
4960
+ }
4961
+
4962
+ interface BatchProxy {
4963
+ [modelName: string]: {
4964
+ findMany: (args?: any) => DeferredQuery
4965
+ findFirst: (args?: any) => DeferredQuery
4966
+ findUnique: (args?: any) => DeferredQuery
4967
+ count: (args?: any) => DeferredQuery
4968
+ aggregate: (args?: any) => DeferredQuery
4969
+ groupBy: (args?: any) => DeferredQuery
4970
+ }
4971
+ }
4972
+
4973
+ const ACCELERATED_METHODS = new Set<PrismaMethod>([
4974
+ 'findMany',
4975
+ 'findFirst',
4976
+ 'findUnique',
4977
+ 'count',
4978
+ 'aggregate',
4979
+ 'groupBy',
4980
+ ])
4981
+
4982
+ function createBatchProxy(): BatchProxy {
4983
+ return new Proxy(
4984
+ {},
4985
+ {
4986
+ get(_target, modelName: string): any {
4987
+ if (typeof modelName === 'symbol') return undefined
4988
+
4989
+ const model = MODEL_MAP.get(modelName)
4990
+ if (!model) {
4991
+ throw new Error(
4992
+ \`Model '\${modelName}' not found. Available: \${[...MODEL_MAP.keys()].join(', ')}\`,
4993
+ )
4994
+ }
4995
+
4996
+ return new Proxy(
4997
+ {},
4998
+ {
4999
+ get(_target, method: string): (args?: any) => DeferredQuery {
5000
+ if (!ACCELERATED_METHODS.has(method as PrismaMethod)) {
5001
+ throw new Error(
5002
+ \`Method '\${method}' not supported in batch. Supported: \${[...ACCELERATED_METHODS].join(', ')}\`,
5003
+ )
5004
+ }
5005
+
5006
+ return (args?: any): DeferredQuery => {
5007
+ return new DeferredQuery(
5008
+ modelName,
5009
+ method as PrismaMethod,
5010
+ args,
5011
+ )
5012
+ }
5013
+ },
5014
+ },
5015
+ )
5016
+ },
5017
+ },
5018
+ ) as BatchProxy
5019
+ }
4955
5020
 
4956
- /**
4957
- * Normalize values for SQL params.
4958
- * Synced from src/utils/normalize-value.ts
4959
- */
4960
5021
  function normalizeValue(value: unknown, seen = new WeakSet<object>(), depth = 0): unknown {
4961
5022
  const MAX_DEPTH = 20
4962
5023
  if (depth > MAX_DEPTH) {
@@ -5003,9 +5064,6 @@ function normalizeValue(value: unknown, seen = new WeakSet<object>(), depth = 0)
5003
5064
  return value
5004
5065
  }
5005
5066
 
5006
- /**
5007
- * Get nested value from object using dot notation path
5008
- */
5009
5067
  function getByPath(obj: any, path: string): unknown {
5010
5068
  if (!obj || !path) return undefined
5011
5069
  const keys = path.split('.')
@@ -5017,9 +5075,6 @@ function getByPath(obj: any, path: string): unknown {
5017
5075
  return result
5018
5076
  }
5019
5077
 
5020
- /**
5021
- * Normalize all params in array
5022
- */
5023
5078
  function normalizeParams(params: unknown[]): unknown[] {
5024
5079
  return params.map(p => normalizeValue(p))
5025
5080
  }
@@ -5040,6 +5095,8 @@ const QUERIES: Record<string, Record<string, Record<string, {
5040
5095
 
5041
5096
  const DIALECT = ${JSON.stringify(dialect)}
5042
5097
 
5098
+ const MODEL_MAP = new Map(MODELS.map(m => [m.name, m]))
5099
+
5043
5100
  function isDynamicParam(key: string): boolean {
5044
5101
  return key === 'skip' || key === 'take' || key === 'cursor'
5045
5102
  }
@@ -5206,6 +5263,13 @@ async function executeQuery(client: any, sql: string, params: unknown[]): Promis
5206
5263
  return stmt.all(...normalizedParams)
5207
5264
  }
5208
5265
 
5266
+ async function executeRaw(client: any, sql: string, params?: unknown[]): Promise<unknown[]> {
5267
+ if (DIALECT === 'postgres') {
5268
+ return await client.unsafe(sql, (params || []) as any[])
5269
+ }
5270
+ throw new Error('Raw execution for sqlite not supported in transactions')
5271
+ }
5272
+
5209
5273
  export function speedExtension(config: {
5210
5274
  postgres?: any
5211
5275
  sqlite?: any
@@ -5233,6 +5297,14 @@ export function speedExtension(config: {
5233
5297
  }
5234
5298
 
5235
5299
  return (prisma: any) => {
5300
+ const txExecutor = createTransactionExecutor({
5301
+ modelMap: MODEL_MAP,
5302
+ allModels: MODELS,
5303
+ dialect: DIALECT,
5304
+ executeRaw: (sql: string, params?: unknown[]) => executeRaw(client, sql, params),
5305
+ postgresClient: postgres,
5306
+ })
5307
+
5236
5308
  const handleMethod = async function(this: any, method: PrismaMethod, args: any) {
5237
5309
  const modelName = this?.name || this?.$name
5238
5310
  const startTime = Date.now()
@@ -5286,11 +5358,130 @@ export function speedExtension(config: {
5286
5358
  return transformQueryResults(method, results)
5287
5359
  }
5288
5360
 
5361
+ async function batch<T extends Record<string, DeferredQuery>>(
5362
+ callback: (batch: BatchProxy) => T | Promise<T>,
5363
+ ): Promise<{ [K in keyof T]: any }> {
5364
+ const batchProxy = createBatchProxy()
5365
+ const queries = await callback(batchProxy)
5366
+
5367
+ const batchQueries: Record<string, BatchQuery> = {}
5368
+
5369
+ for (const [key, deferred] of Object.entries(queries)) {
5370
+ if (!(deferred instanceof DeferredQuery)) {
5371
+ throw new Error(
5372
+ \`Batch query '\${key}' must be a deferred query. Did you await it?\`,
5373
+ )
5374
+ }
5375
+
5376
+ batchQueries[key] = {
5377
+ model: deferred.model,
5378
+ method: deferred.method,
5379
+ args: deferred.args || {},
5380
+ }
5381
+ }
5382
+
5383
+ const startTime = Date.now()
5384
+ const { sql, params, keys } = buildBatchSql(
5385
+ batchQueries,
5386
+ MODEL_MAP,
5387
+ MODELS,
5388
+ DIALECT,
5389
+ )
5390
+
5391
+ if (debug) {
5392
+ console.log(\`[\${DIALECT}] $batch (\${keys.length} queries)\`)
5393
+ console.log('SQL:', sql)
5394
+ console.log('Params:', params)
5395
+ }
5396
+
5397
+ const normalizedParams = normalizeParams(params)
5398
+ const rows = await client.unsafe(sql, normalizedParams as any[])
5399
+ const row = rows[0] as Record<string, unknown>
5400
+ const results = parseBatchResults(row, keys, batchQueries)
5401
+ const duration = Date.now() - startTime
5402
+
5403
+ onQuery?.({
5404
+ model: '_batch',
5405
+ method: 'batch',
5406
+ sql,
5407
+ params: normalizedParams,
5408
+ duration,
5409
+ prebaked: false,
5410
+ })
5411
+
5412
+ return results as { [K in keyof T]: any }
5413
+ }
5414
+
5415
+ async function batchCount(queries: BatchCountQuery[]): Promise<number[]> {
5416
+ if (queries.length === 0) return []
5417
+
5418
+ const startTime = Date.now()
5419
+ const { sql, params } = buildBatchCountSql(
5420
+ queries,
5421
+ MODEL_MAP,
5422
+ MODELS,
5423
+ DIALECT,
5424
+ )
5425
+
5426
+ if (debug) {
5427
+ console.log(\`[\${DIALECT}] $batchCount (\${queries.length} queries)\`)
5428
+ console.log('SQL:', sql)
5429
+ console.log('Params:', params)
5430
+ }
5431
+
5432
+ const normalizedParams = normalizeParams(params)
5433
+ const rows = await client.unsafe(sql, normalizedParams as any[])
5434
+ const row = rows[0] as Record<string, unknown>
5435
+ const results = parseBatchCountResults(row, queries.length)
5436
+ const duration = Date.now() - startTime
5437
+
5438
+ onQuery?.({
5439
+ model: '_batch',
5440
+ method: 'count',
5441
+ sql,
5442
+ params: normalizedParams,
5443
+ duration,
5444
+ prebaked: false,
5445
+ })
5446
+
5447
+ return results
5448
+ }
5449
+
5450
+ async function transaction(
5451
+ queries: TransactionQuery[],
5452
+ options?: TransactionOptions,
5453
+ ): Promise<unknown[]> {
5454
+ const startTime = Date.now()
5455
+
5456
+ if (debug) {
5457
+ console.log(\`[\${DIALECT}] $transaction (\${queries.length} queries)\`)
5458
+ }
5459
+
5460
+ const results = await txExecutor.execute(queries, options)
5461
+ const duration = Date.now() - startTime
5462
+
5463
+ onQuery?.({
5464
+ model: '_transaction',
5465
+ method: 'count',
5466
+ sql: \`TRANSACTION(\${queries.length})\`,
5467
+ params: [],
5468
+ duration,
5469
+ prebaked: false,
5470
+ })
5471
+
5472
+ return results
5473
+ }
5474
+
5289
5475
  return prisma.$extends({
5290
5476
  name: 'prisma-sql-generated',
5291
5477
 
5292
5478
  client: {
5293
5479
  $original: prisma,
5480
+ $batch: batch as <T extends Record<string, DeferredQuery>>(
5481
+ callback: (batch: BatchProxy) => T | Promise<T>,
5482
+ ) => Promise<{ [K in keyof T]: any }>,
5483
+ $batchCount: batchCount as (...args: any[]) => Promise<number[]>,
5484
+ $transaction: transaction as (...args: any[]) => Promise<unknown[]>,
5294
5485
  },
5295
5486
 
5296
5487
  model: {
@@ -5318,6 +5509,18 @@ export function speedExtension(config: {
5318
5509
  })
5319
5510
  }
5320
5511
  }
5512
+
5513
+ type SpeedExtensionReturn = ReturnType<ReturnType<typeof speedExtension>>
5514
+
5515
+ export type SpeedClient<T> = T & {
5516
+ $batch<T extends Record<string, DeferredQuery>>(
5517
+ callback: (batch: BatchProxy) => T | Promise<T>,
5518
+ ): Promise<{ [K in keyof T]: any }>
5519
+ $batchCount(queries: BatchCountQuery[]): Promise<number[]>
5520
+ $transaction(queries: TransactionQuery[], options?: TransactionOptions): Promise<unknown[]>
5521
+ }
5522
+
5523
+ export type { BatchCountQuery, TransactionQuery, TransactionOptions }
5321
5524
  `;
5322
5525
  }
5323
5526
  function formatQueries(queries) {