prisma-sql 1.53.0 → 1.55.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -56,7 +56,7 @@ var require_package = __commonJS({
56
56
  "package.json"(exports$1, module) {
57
57
  module.exports = {
58
58
  name: "prisma-sql",
59
- version: "1.53.0",
59
+ version: "1.55.0",
60
60
  description: "Convert Prisma queries to optimized SQL with type safety. 2-7x faster than Prisma Client.",
61
61
  main: "dist/index.cjs",
62
62
  module: "dist/index.js",
@@ -815,6 +815,9 @@ function assertNoControlChars(label, s) {
815
815
  );
816
816
  }
817
817
  }
818
+ function quoteRawIdent(id) {
819
+ return `"${id.replace(/"/g, '""')}"`;
820
+ }
818
821
  function isIdentCharCode(c) {
819
822
  return c >= 48 && c <= 57 || c >= 65 && c <= 90 || c >= 97 && c <= 122 || c === 95;
820
823
  }
@@ -836,33 +839,33 @@ function parseQuotedPart(input, start) {
836
839
  }
837
840
  if (!sawAny) {
838
841
  throw new Error(
839
- `tableName/tableRef has empty quoted identifier part: ${JSON.stringify(input)}`
842
+ `qualified name has empty quoted identifier part: ${JSON.stringify(input)}`
840
843
  );
841
844
  }
842
845
  return i + 1;
843
846
  }
844
847
  if (c === 10 || c === 13 || c === 0) {
845
848
  throw new Error(
846
- `tableName/tableRef contains invalid characters: ${JSON.stringify(input)}`
849
+ `qualified name contains invalid characters: ${JSON.stringify(input)}`
847
850
  );
848
851
  }
849
852
  sawAny = true;
850
853
  i++;
851
854
  }
852
855
  throw new Error(
853
- `tableName/tableRef has unterminated quoted identifier: ${JSON.stringify(input)}`
856
+ `qualified name has unterminated quoted identifier: ${JSON.stringify(input)}`
854
857
  );
855
858
  }
856
859
  function parseUnquotedPart(input, start) {
857
860
  const n = input.length;
858
861
  let i = start;
859
862
  if (i >= n) {
860
- throw new Error(`tableName/tableRef is invalid: ${JSON.stringify(input)}`);
863
+ throw new Error(`qualified name is invalid: ${JSON.stringify(input)}`);
861
864
  }
862
865
  const c0 = input.charCodeAt(i);
863
866
  if (!isIdentStartCharCode(c0)) {
864
867
  throw new Error(
865
- `tableName/tableRef must use identifiers (or quoted identifiers). Got: ${JSON.stringify(input)}`
868
+ `qualified name must use identifiers (or quoted identifiers). Got: ${JSON.stringify(input)}`
866
869
  );
867
870
  }
868
871
  i++;
@@ -871,15 +874,15 @@ function parseUnquotedPart(input, start) {
871
874
  if (c === 46) break;
872
875
  if (!isIdentCharCode(c)) {
873
876
  throw new Error(
874
- `tableName/tableRef contains invalid identifier characters: ${JSON.stringify(input)}`
877
+ `qualified name contains invalid identifier characters: ${JSON.stringify(input)}`
875
878
  );
876
879
  }
877
880
  i++;
878
881
  }
879
882
  return i;
880
883
  }
881
- function assertSafeQualifiedName(tableRef) {
882
- const raw = String(tableRef);
884
+ function assertSafeQualifiedName(input) {
885
+ const raw = String(input);
883
886
  const trimmed = raw.trim();
884
887
  if (trimmed.length === 0) {
885
888
  throw new Error("tableName/tableRef is required and cannot be empty");
@@ -953,7 +956,7 @@ function quote2(id) {
953
956
  );
954
957
  }
955
958
  if (needsQuoting(id)) {
956
- return `"${id.replace(/"/g, '""')}"`;
959
+ return quoteRawIdent(id);
957
960
  }
958
961
  return id;
959
962
  }
@@ -2592,28 +2595,24 @@ function buildOperator(expr, op, val, ctx, mode, fieldType) {
2592
2595
  function toSafeSqlIdentifier(input) {
2593
2596
  const raw = String(input);
2594
2597
  const n = raw.length;
2598
+ if (n === 0) return "_t";
2595
2599
  let out = "";
2596
2600
  for (let i = 0; i < n; i++) {
2597
2601
  const c = raw.charCodeAt(i);
2598
2602
  const isAZ = c >= 65 && c <= 90 || c >= 97 && c <= 122;
2599
2603
  const is09 = c >= 48 && c <= 57;
2600
2604
  const isUnderscore = c === 95;
2601
- if (isAZ || is09 || isUnderscore) {
2602
- out += raw[i];
2603
- } else {
2604
- out += "_";
2605
- }
2605
+ out += isAZ || is09 || isUnderscore ? raw[i] : "_";
2606
2606
  }
2607
- if (out.length === 0) out = "_t";
2608
2607
  const c0 = out.charCodeAt(0);
2609
2608
  const startsOk = c0 >= 65 && c0 <= 90 || c0 >= 97 && c0 <= 122 || c0 === 95;
2610
- if (!startsOk) out = `_${out}`;
2611
- const lowered = out.toLowerCase();
2609
+ const lowered = (startsOk ? out : `_${out}`).toLowerCase();
2612
2610
  return ALIAS_FORBIDDEN_KEYWORDS.has(lowered) ? `_${lowered}` : lowered;
2613
2611
  }
2614
2612
  function createAliasGenerator(maxAliases = 1e4) {
2615
2613
  let counter = 0;
2616
2614
  const usedAliases = /* @__PURE__ */ new Set();
2615
+ const maxLen = 63;
2617
2616
  return {
2618
2617
  next(baseName) {
2619
2618
  if (usedAliases.size >= maxAliases) {
@@ -2623,14 +2622,13 @@ function createAliasGenerator(maxAliases = 1e4) {
2623
2622
  }
2624
2623
  const base = toSafeSqlIdentifier(baseName);
2625
2624
  const suffix = `_${counter}`;
2626
- const maxLen = 63;
2627
2625
  const baseMax = Math.max(1, maxLen - suffix.length);
2628
2626
  const trimmedBase = base.length > baseMax ? base.slice(0, baseMax) : base;
2629
2627
  const alias = `${trimmedBase}${suffix}`;
2630
2628
  counter += 1;
2631
2629
  if (usedAliases.has(alias)) {
2632
2630
  throw new Error(
2633
- `CRITICAL: Duplicate alias '${alias}' at counter=${counter}. This indicates a bug in alias generation logic.`
2631
+ `CRITICAL: Duplicate alias '${alias}' at counter=${counter}.`
2634
2632
  );
2635
2633
  }
2636
2634
  usedAliases.add(alias);
@@ -2682,24 +2680,19 @@ function normalizeDynamicNameOrThrow(dynamicName, index) {
2682
2680
  }
2683
2681
  return dn;
2684
2682
  }
2685
- function assertUniqueDynamicName(dn, seen) {
2686
- if (seen.has(dn)) {
2687
- throw new Error(`CRITICAL: Duplicate dynamic param name in mappings: ${dn}`);
2688
- }
2689
- seen.add(dn);
2690
- }
2691
- function validateMappingEntry(m, expectedIndex, seenDynamic) {
2692
- assertSequentialIndex(m.index, expectedIndex);
2693
- assertExactlyOneOfDynamicOrValue(m);
2694
- if (typeof m.dynamicName === "string") {
2695
- const dn = normalizeDynamicNameOrThrow(m.dynamicName, m.index);
2696
- assertUniqueDynamicName(dn, seenDynamic);
2697
- }
2698
- }
2699
2683
  function validateMappings(mappings) {
2700
2684
  const seenDynamic = /* @__PURE__ */ new Set();
2701
2685
  for (let i = 0; i < mappings.length; i++) {
2702
- validateMappingEntry(mappings[i], i + 1, seenDynamic);
2686
+ const m = mappings[i];
2687
+ assertSequentialIndex(m.index, i + 1);
2688
+ assertExactlyOneOfDynamicOrValue(m);
2689
+ if (typeof m.dynamicName === "string") {
2690
+ const dn = normalizeDynamicNameOrThrow(m.dynamicName, m.index);
2691
+ if (seenDynamic.has(dn)) {
2692
+ throw new Error(`CRITICAL: Duplicate dynamic param name: ${dn}`);
2693
+ }
2694
+ seenDynamic.add(dn);
2695
+ }
2703
2696
  }
2704
2697
  }
2705
2698
  function validateState(params, mappings, index) {
@@ -2711,16 +2704,19 @@ function validateState(params, mappings, index) {
2711
2704
  }
2712
2705
  function createStoreInternal(startIndex, initialParams = [], initialMappings = []) {
2713
2706
  let index = startIndex;
2714
- const params = initialParams.length > 0 ? [...initialParams] : [];
2715
- const mappings = initialMappings.length > 0 ? [...initialMappings] : [];
2707
+ const params = initialParams.length > 0 ? initialParams.slice() : [];
2708
+ const mappings = initialMappings.length > 0 ? initialMappings.slice() : [];
2716
2709
  const dynamicNameToIndex = /* @__PURE__ */ new Map();
2717
- for (const m of mappings) {
2710
+ for (let i = 0; i < mappings.length; i++) {
2711
+ const m = mappings[i];
2718
2712
  if (typeof m.dynamicName === "string") {
2719
2713
  dynamicNameToIndex.set(m.dynamicName.trim(), m.index);
2720
2714
  }
2721
2715
  }
2722
2716
  let dirty = true;
2723
2717
  let cachedSnapshot = null;
2718
+ let frozenParams = null;
2719
+ let frozenMappings = null;
2724
2720
  function assertCanAdd() {
2725
2721
  if (index > MAX_PARAM_INDEX) {
2726
2722
  throw new Error(
@@ -2772,13 +2768,17 @@ function createStoreInternal(startIndex, initialParams = [], initialMappings = [
2772
2768
  }
2773
2769
  function snapshot() {
2774
2770
  if (!dirty && cachedSnapshot) return cachedSnapshot;
2771
+ if (!frozenParams) frozenParams = Object.freeze(params.slice());
2772
+ if (!frozenMappings) frozenMappings = Object.freeze(mappings.slice());
2775
2773
  const snap = {
2776
2774
  index,
2777
- params,
2778
- mappings
2775
+ params: frozenParams,
2776
+ mappings: frozenMappings
2779
2777
  };
2780
2778
  cachedSnapshot = snap;
2781
2779
  dirty = false;
2780
+ frozenParams = null;
2781
+ frozenMappings = null;
2782
2782
  return snap;
2783
2783
  }
2784
2784
  return {
@@ -2802,11 +2802,11 @@ function createParamStore(startIndex = 1) {
2802
2802
  return createStoreInternal(startIndex);
2803
2803
  }
2804
2804
  function createParamStoreFrom(existingParams, existingMappings, nextIndex) {
2805
- validateState([...existingParams], [...existingMappings], nextIndex);
2805
+ validateState(existingParams, existingMappings, nextIndex);
2806
2806
  return createStoreInternal(
2807
2807
  nextIndex,
2808
- [...existingParams],
2809
- [...existingMappings]
2808
+ existingParams.slice(),
2809
+ existingMappings.slice()
2810
2810
  );
2811
2811
  }
2812
2812
 
@@ -2988,7 +2988,7 @@ function getRelationTableReference(relModel, dialect) {
2988
2988
  dialect
2989
2989
  );
2990
2990
  }
2991
- function resolveRelationOrThrow(model, schemas, schemaByName, relName) {
2991
+ function resolveRelationOrThrow(model, schemaByName, relName) {
2992
2992
  const field = model.fields.find((f) => f.name === relName);
2993
2993
  if (!isNotNullish(field)) {
2994
2994
  throw new Error(
@@ -3042,8 +3042,9 @@ function validateOrderByForModel(model, orderBy) {
3042
3042
  throw new Error("orderBy array entries must have exactly one field");
3043
3043
  }
3044
3044
  const fieldName = String(entries[0][0]).trim();
3045
- if (fieldName.length === 0)
3045
+ if (fieldName.length === 0) {
3046
3046
  throw new Error("orderBy field name cannot be empty");
3047
+ }
3047
3048
  if (!scalarSet.has(fieldName)) {
3048
3049
  throw new Error(
3049
3050
  `orderBy references unknown or non-scalar field '${fieldName}' on model ${model.name}`
@@ -3102,8 +3103,9 @@ function extractRelationPaginationConfig(relArgs) {
3102
3103
  function maybeReverseNegativeTake(takeVal, hasOrderBy, orderByInput) {
3103
3104
  if (typeof takeVal !== "number") return { takeVal, orderByInput };
3104
3105
  if (takeVal >= 0) return { takeVal, orderByInput };
3105
- if (!hasOrderBy)
3106
+ if (!hasOrderBy) {
3106
3107
  throw new Error("Negative take requires orderBy for deterministic results");
3108
+ }
3107
3109
  return {
3108
3110
  takeVal: Math.abs(takeVal),
3109
3111
  orderByInput: reverseOrderByInput(orderByInput)
@@ -3113,9 +3115,7 @@ function finalizeOrderByForInclude(args) {
3113
3115
  if (args.hasOrderBy && isNotNullish(args.orderByInput)) {
3114
3116
  validateOrderByForModel(args.relModel, args.orderByInput);
3115
3117
  }
3116
- if (!args.hasPagination) {
3117
- return args.orderByInput;
3118
- }
3118
+ if (!args.hasPagination) return args.orderByInput;
3119
3119
  return ensureDeterministicOrderByInput({
3120
3120
  orderBy: args.hasOrderBy ? args.orderByInput : void 0,
3121
3121
  model: args.relModel,
@@ -3176,7 +3176,9 @@ function buildOrderBySql(finalOrderByInput, relAlias, dialect, relModel) {
3176
3176
  return isNotNullish(finalOrderByInput) ? buildOrderBy(finalOrderByInput, relAlias, dialect, relModel) : "";
3177
3177
  }
3178
3178
  function buildBaseSql(args) {
3179
- return `${SQL_TEMPLATES.SELECT} ${args.selectExpr} ${SQL_TEMPLATES.FROM} ${args.relTable} ${args.relAlias} ${args.joins} ${SQL_TEMPLATES.WHERE} ${args.joinPredicate}${args.whereClause}`;
3179
+ const joins = args.joins ? ` ${args.joins}` : "";
3180
+ const where = `${SQL_TEMPLATES.WHERE} ${args.joinPredicate}${args.whereClause}`;
3181
+ return `${SQL_TEMPLATES.SELECT} ${args.selectExpr} ${SQL_TEMPLATES.FROM} ${args.relTable} ${args.relAlias}${joins} ` + where;
3180
3182
  }
3181
3183
  function buildOneToOneIncludeSql(args) {
3182
3184
  const objExpr = jsonBuildObject(args.relSelect, args.ctx.dialect);
@@ -3188,9 +3190,7 @@ function buildOneToOneIncludeSql(args) {
3188
3190
  joinPredicate: args.joinPredicate,
3189
3191
  whereClause: args.whereClause
3190
3192
  });
3191
- if (args.orderBySql) {
3192
- sql += ` ${SQL_TEMPLATES.ORDER_BY} ${args.orderBySql}`;
3193
- }
3193
+ if (args.orderBySql) sql += ` ${SQL_TEMPLATES.ORDER_BY} ${args.orderBySql}`;
3194
3194
  if (isNotNullish(args.takeVal)) {
3195
3195
  return appendLimitOffset(
3196
3196
  sql,
@@ -3243,7 +3243,7 @@ function buildListIncludeSpec(args) {
3243
3243
  `include.${args.relName}`
3244
3244
  );
3245
3245
  const selectExpr = jsonAgg("row", args.ctx.dialect);
3246
- const sql = `${SQL_TEMPLATES.SELECT} ${selectExpr} ${SQL_TEMPLATES.FROM} (${base}) ${rowAlias}`;
3246
+ const sql = `${SQL_TEMPLATES.SELECT} ${selectExpr} ${SQL_TEMPLATES.FROM} (${base}) ${SQL_TEMPLATES.AS} ${rowAlias}`;
3247
3247
  return Object.freeze({ name: args.relName, sql, isOneToOne: false });
3248
3248
  }
3249
3249
  function buildSingleInclude(relName, relArgs, field, relModel, ctx) {
@@ -3341,12 +3341,7 @@ function buildIncludeSqlInternal(args, model, schemas, schemaByName, parentAlias
3341
3341
  `Query complexity limit exceeded: ${stats.totalSubqueries} subqueries generated. Maximum allowed: ${MAX_TOTAL_SUBQUERIES}. This indicates exponential include nesting. Stats: depth=${stats.maxDepth}, includes=${stats.totalIncludes}. Path: ${visitPath.join(" -> ")}. Simplify your include structure or split into multiple queries.`
3342
3342
  );
3343
3343
  }
3344
- const resolved = resolveRelationOrThrow(
3345
- model,
3346
- schemas,
3347
- schemaByName,
3348
- relName
3349
- );
3344
+ const resolved = resolveRelationOrThrow(model, schemaByName, relName);
3350
3345
  const relationPath = `${model.name}.${relName}`;
3351
3346
  const currentPath = [...visitPath, relationPath];
3352
3347
  if (visitPath.includes(relationPath)) {
@@ -3402,7 +3397,7 @@ function buildIncludeSql(args, model, schemas, parentAlias, params, dialect) {
3402
3397
  stats
3403
3398
  );
3404
3399
  }
3405
- function resolveCountRelationOrThrow(relName, model, schemas, schemaByName) {
3400
+ function resolveCountRelationOrThrow(relName, model, schemaByName) {
3406
3401
  const relationSet = getRelationFieldSet(model);
3407
3402
  if (!relationSet.has(relName)) {
3408
3403
  throw new Error(
@@ -3410,10 +3405,11 @@ function resolveCountRelationOrThrow(relName, model, schemas, schemaByName) {
3410
3405
  );
3411
3406
  }
3412
3407
  const field = model.fields.find((f) => f.name === relName);
3413
- if (!field)
3408
+ if (!field) {
3414
3409
  throw new Error(
3415
3410
  `_count.${relName} references unknown relation on model ${model.name}`
3416
3411
  );
3412
+ }
3417
3413
  if (!isValidRelationField(field)) {
3418
3414
  throw new Error(
3419
3415
  `_count.${relName} has invalid relation metadata on model ${model.name}`
@@ -3441,8 +3437,9 @@ function defaultReferencesForCount(fkCount) {
3441
3437
  }
3442
3438
  function resolveCountKeyPairs(field) {
3443
3439
  const fkFields = normalizeKeyList(field.foreignKey);
3444
- if (fkFields.length === 0)
3440
+ if (fkFields.length === 0) {
3445
3441
  throw new Error("Relation count requires foreignKey");
3442
+ }
3446
3443
  const refsRaw = field.references;
3447
3444
  const refs = normalizeKeyList(refsRaw);
3448
3445
  const refFields = refs.length > 0 ? refs : defaultReferencesForCount(fkFields.length);
@@ -3518,12 +3515,7 @@ function buildRelationCountSql(countSelect, model, schemas, parentAlias, _params
3518
3515
  for (const m of schemas) schemaByName.set(m.name, m);
3519
3516
  for (const [relName, shouldCount] of Object.entries(countSelect)) {
3520
3517
  if (!shouldCount) continue;
3521
- const resolved = resolveCountRelationOrThrow(
3522
- relName,
3523
- model,
3524
- schemas,
3525
- schemaByName
3526
- );
3518
+ const resolved = resolveCountRelationOrThrow(relName, model, schemaByName);
3527
3519
  const built = buildCountJoinAndPair({
3528
3520
  relName,
3529
3521
  field: resolved.field,
@@ -4953,12 +4945,81 @@ function generateCode(models, queries, dialect, datamodel) {
4953
4945
  }));
4954
4946
  const { mappings, fieldTypes } = extractEnumMappings(datamodel);
4955
4947
  return `// Generated by @prisma-sql/generator - DO NOT EDIT
4956
- import { buildSQL, transformQueryResults, type PrismaMethod, type Model } from 'prisma-sql'
4948
+ import { buildSQL, buildBatchSql, parseBatchResults, buildBatchCountSql, parseBatchCountResults, createTransactionExecutor, transformQueryResults, type PrismaMethod, type Model, type BatchQuery, type BatchCountQuery, type TransactionQuery, type TransactionOptions } from 'prisma-sql'
4949
+
4950
+ class DeferredQuery {
4951
+ constructor(
4952
+ public readonly model: string,
4953
+ public readonly method: PrismaMethod,
4954
+ public readonly args: any,
4955
+ ) {}
4956
+
4957
+ then(onfulfilled?: any, onrejected?: any): any {
4958
+ throw new Error(
4959
+ 'Cannot await a batch query. Batch queries must not be awaited inside the $batch callback.',
4960
+ )
4961
+ }
4962
+ }
4963
+
4964
+ interface BatchProxy {
4965
+ [modelName: string]: {
4966
+ findMany: (args?: any) => DeferredQuery
4967
+ findFirst: (args?: any) => DeferredQuery
4968
+ findUnique: (args?: any) => DeferredQuery
4969
+ count: (args?: any) => DeferredQuery
4970
+ aggregate: (args?: any) => DeferredQuery
4971
+ groupBy: (args?: any) => DeferredQuery
4972
+ }
4973
+ }
4974
+
4975
+ const ACCELERATED_METHODS = new Set<PrismaMethod>([
4976
+ 'findMany',
4977
+ 'findFirst',
4978
+ 'findUnique',
4979
+ 'count',
4980
+ 'aggregate',
4981
+ 'groupBy',
4982
+ ])
4983
+
4984
+ function createBatchProxy(): BatchProxy {
4985
+ return new Proxy(
4986
+ {},
4987
+ {
4988
+ get(_target, modelName: string): any {
4989
+ if (typeof modelName === 'symbol') return undefined
4990
+
4991
+ const model = MODEL_MAP.get(modelName)
4992
+ if (!model) {
4993
+ throw new Error(
4994
+ \`Model '\${modelName}' not found. Available: \${[...MODEL_MAP.keys()].join(', ')}\`,
4995
+ )
4996
+ }
4997
+
4998
+ return new Proxy(
4999
+ {},
5000
+ {
5001
+ get(_target, method: string): (args?: any) => DeferredQuery {
5002
+ if (!ACCELERATED_METHODS.has(method as PrismaMethod)) {
5003
+ throw new Error(
5004
+ \`Method '\${method}' not supported in batch. Supported: \${[...ACCELERATED_METHODS].join(', ')}\`,
5005
+ )
5006
+ }
5007
+
5008
+ return (args?: any): DeferredQuery => {
5009
+ return new DeferredQuery(
5010
+ modelName,
5011
+ method as PrismaMethod,
5012
+ args,
5013
+ )
5014
+ }
5015
+ },
5016
+ },
5017
+ )
5018
+ },
5019
+ },
5020
+ ) as BatchProxy
5021
+ }
4957
5022
 
4958
- /**
4959
- * Normalize values for SQL params.
4960
- * Synced from src/utils/normalize-value.ts
4961
- */
4962
5023
  function normalizeValue(value: unknown, seen = new WeakSet<object>(), depth = 0): unknown {
4963
5024
  const MAX_DEPTH = 20
4964
5025
  if (depth > MAX_DEPTH) {
@@ -5005,9 +5066,6 @@ function normalizeValue(value: unknown, seen = new WeakSet<object>(), depth = 0)
5005
5066
  return value
5006
5067
  }
5007
5068
 
5008
- /**
5009
- * Get nested value from object using dot notation path
5010
- */
5011
5069
  function getByPath(obj: any, path: string): unknown {
5012
5070
  if (!obj || !path) return undefined
5013
5071
  const keys = path.split('.')
@@ -5019,9 +5077,6 @@ function getByPath(obj: any, path: string): unknown {
5019
5077
  return result
5020
5078
  }
5021
5079
 
5022
- /**
5023
- * Normalize all params in array
5024
- */
5025
5080
  function normalizeParams(params: unknown[]): unknown[] {
5026
5081
  return params.map(p => normalizeValue(p))
5027
5082
  }
@@ -5042,6 +5097,8 @@ const QUERIES: Record<string, Record<string, Record<string, {
5042
5097
 
5043
5098
  const DIALECT = ${JSON.stringify(dialect)}
5044
5099
 
5100
+ const MODEL_MAP = new Map(MODELS.map(m => [m.name, m]))
5101
+
5045
5102
  function isDynamicParam(key: string): boolean {
5046
5103
  return key === 'skip' || key === 'take' || key === 'cursor'
5047
5104
  }
@@ -5208,6 +5265,13 @@ async function executeQuery(client: any, sql: string, params: unknown[]): Promis
5208
5265
  return stmt.all(...normalizedParams)
5209
5266
  }
5210
5267
 
5268
+ async function executeRaw(client: any, sql: string, params?: unknown[]): Promise<unknown[]> {
5269
+ if (DIALECT === 'postgres') {
5270
+ return await client.unsafe(sql, (params || []) as any[])
5271
+ }
5272
+ throw new Error('Raw execution for sqlite not supported in transactions')
5273
+ }
5274
+
5211
5275
  export function speedExtension(config: {
5212
5276
  postgres?: any
5213
5277
  sqlite?: any
@@ -5235,6 +5299,14 @@ export function speedExtension(config: {
5235
5299
  }
5236
5300
 
5237
5301
  return (prisma: any) => {
5302
+ const txExecutor = createTransactionExecutor({
5303
+ modelMap: MODEL_MAP,
5304
+ allModels: MODELS,
5305
+ dialect: DIALECT,
5306
+ executeRaw: (sql: string, params?: unknown[]) => executeRaw(client, sql, params),
5307
+ postgresClient: postgres,
5308
+ })
5309
+
5238
5310
  const handleMethod = async function(this: any, method: PrismaMethod, args: any) {
5239
5311
  const modelName = this?.name || this?.$name
5240
5312
  const startTime = Date.now()
@@ -5288,11 +5360,130 @@ export function speedExtension(config: {
5288
5360
  return transformQueryResults(method, results)
5289
5361
  }
5290
5362
 
5363
+ async function batch<T extends Record<string, DeferredQuery>>(
5364
+ callback: (batch: BatchProxy) => T | Promise<T>,
5365
+ ): Promise<{ [K in keyof T]: any }> {
5366
+ const batchProxy = createBatchProxy()
5367
+ const queries = await callback(batchProxy)
5368
+
5369
+ const batchQueries: Record<string, BatchQuery> = {}
5370
+
5371
+ for (const [key, deferred] of Object.entries(queries)) {
5372
+ if (!(deferred instanceof DeferredQuery)) {
5373
+ throw new Error(
5374
+ \`Batch query '\${key}' must be a deferred query. Did you await it?\`,
5375
+ )
5376
+ }
5377
+
5378
+ batchQueries[key] = {
5379
+ model: deferred.model,
5380
+ method: deferred.method,
5381
+ args: deferred.args || {},
5382
+ }
5383
+ }
5384
+
5385
+ const startTime = Date.now()
5386
+ const { sql, params, keys } = buildBatchSql(
5387
+ batchQueries,
5388
+ MODEL_MAP,
5389
+ MODELS,
5390
+ DIALECT,
5391
+ )
5392
+
5393
+ if (debug) {
5394
+ console.log(\`[\${DIALECT}] $batch (\${keys.length} queries)\`)
5395
+ console.log('SQL:', sql)
5396
+ console.log('Params:', params)
5397
+ }
5398
+
5399
+ const normalizedParams = normalizeParams(params)
5400
+ const rows = await client.unsafe(sql, normalizedParams as any[])
5401
+ const row = rows[0] as Record<string, unknown>
5402
+ const results = parseBatchResults(row, keys, batchQueries)
5403
+ const duration = Date.now() - startTime
5404
+
5405
+ onQuery?.({
5406
+ model: '_batch',
5407
+ method: 'batch',
5408
+ sql,
5409
+ params: normalizedParams,
5410
+ duration,
5411
+ prebaked: false,
5412
+ })
5413
+
5414
+ return results as { [K in keyof T]: any }
5415
+ }
5416
+
5417
+ async function batchCount(queries: BatchCountQuery[]): Promise<number[]> {
5418
+ if (queries.length === 0) return []
5419
+
5420
+ const startTime = Date.now()
5421
+ const { sql, params } = buildBatchCountSql(
5422
+ queries,
5423
+ MODEL_MAP,
5424
+ MODELS,
5425
+ DIALECT,
5426
+ )
5427
+
5428
+ if (debug) {
5429
+ console.log(\`[\${DIALECT}] $batchCount (\${queries.length} queries)\`)
5430
+ console.log('SQL:', sql)
5431
+ console.log('Params:', params)
5432
+ }
5433
+
5434
+ const normalizedParams = normalizeParams(params)
5435
+ const rows = await client.unsafe(sql, normalizedParams as any[])
5436
+ const row = rows[0] as Record<string, unknown>
5437
+ const results = parseBatchCountResults(row, queries.length)
5438
+ const duration = Date.now() - startTime
5439
+
5440
+ onQuery?.({
5441
+ model: '_batch',
5442
+ method: 'count',
5443
+ sql,
5444
+ params: normalizedParams,
5445
+ duration,
5446
+ prebaked: false,
5447
+ })
5448
+
5449
+ return results
5450
+ }
5451
+
5452
+ async function transaction(
5453
+ queries: TransactionQuery[],
5454
+ options?: TransactionOptions,
5455
+ ): Promise<unknown[]> {
5456
+ const startTime = Date.now()
5457
+
5458
+ if (debug) {
5459
+ console.log(\`[\${DIALECT}] $transaction (\${queries.length} queries)\`)
5460
+ }
5461
+
5462
+ const results = await txExecutor.execute(queries, options)
5463
+ const duration = Date.now() - startTime
5464
+
5465
+ onQuery?.({
5466
+ model: '_transaction',
5467
+ method: 'count',
5468
+ sql: \`TRANSACTION(\${queries.length})\`,
5469
+ params: [],
5470
+ duration,
5471
+ prebaked: false,
5472
+ })
5473
+
5474
+ return results
5475
+ }
5476
+
5291
5477
  return prisma.$extends({
5292
5478
  name: 'prisma-sql-generated',
5293
5479
 
5294
5480
  client: {
5295
5481
  $original: prisma,
5482
+ $batch: batch as <T extends Record<string, DeferredQuery>>(
5483
+ callback: (batch: BatchProxy) => T | Promise<T>,
5484
+ ) => Promise<{ [K in keyof T]: any }>,
5485
+ $batchCount: batchCount as (...args: any[]) => Promise<number[]>,
5486
+ $transaction: transaction as (...args: any[]) => Promise<unknown[]>,
5296
5487
  },
5297
5488
 
5298
5489
  model: {
@@ -5320,6 +5511,18 @@ export function speedExtension(config: {
5320
5511
  })
5321
5512
  }
5322
5513
  }
5514
+
5515
+ type SpeedExtensionReturn = ReturnType<ReturnType<typeof speedExtension>>
5516
+
5517
+ export type SpeedClient<T> = T & {
5518
+ $batch<T extends Record<string, DeferredQuery>>(
5519
+ callback: (batch: BatchProxy) => T | Promise<T>,
5520
+ ): Promise<{ [K in keyof T]: any }>
5521
+ $batchCount(queries: BatchCountQuery[]): Promise<number[]>
5522
+ $transaction(queries: TransactionQuery[], options?: TransactionOptions): Promise<unknown[]>
5523
+ }
5524
+
5525
+ export type { BatchCountQuery, TransactionQuery, TransactionOptions }
5323
5526
  `;
5324
5527
  }
5325
5528
  function formatQueries(queries) {