@prisma-next/sql-runtime 0.4.1 → 0.5.0-dev.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,28 @@
1
+ import type { DraftPlan, SqlMiddleware, SqlMiddlewareContext } from './sql-middleware';
2
+
3
+ export async function runBeforeCompileChain(
4
+ middleware: readonly SqlMiddleware[],
5
+ initial: DraftPlan,
6
+ ctx: SqlMiddlewareContext,
7
+ ): Promise<DraftPlan> {
8
+ let current = initial;
9
+ for (const mw of middleware) {
10
+ if (!mw.beforeCompile) {
11
+ continue;
12
+ }
13
+ const result = await mw.beforeCompile(current, ctx);
14
+ if (result === undefined) {
15
+ continue;
16
+ }
17
+ if (result.ast === current.ast) {
18
+ continue;
19
+ }
20
+ ctx.log.debug?.({
21
+ event: 'middleware.rewrite',
22
+ middleware: mw.name,
23
+ lane: current.meta.lane,
24
+ });
25
+ current = result;
26
+ }
27
+ return current;
28
+ }
@@ -1,11 +1,8 @@
1
1
  import type { ExecutionPlan } from '@prisma-next/contract/types';
2
2
  import { type RuntimeErrorEnvelope, runtimeError } from '@prisma-next/framework-components/runtime';
3
- import type {
4
- AfterExecuteResult,
5
- Middleware,
6
- MiddlewareContext,
7
- } from '@prisma-next/runtime-executor';
3
+ import type { AfterExecuteResult } from '@prisma-next/runtime-executor';
8
4
  import { isQueryAst, type SelectAst } from '@prisma-next/sql-relational-core/ast';
5
+ import type { SqlMiddleware, SqlMiddlewareContext } from './sql-middleware';
9
6
 
10
7
  export interface BudgetsOptions {
11
8
  readonly maxRows?: number;
@@ -77,7 +74,7 @@ function hasDetectableLimitFromHeuristics(plan: ExecutionPlan): boolean {
77
74
  function emitBudgetViolation(
78
75
  error: RuntimeErrorEnvelope,
79
76
  shouldBlock: boolean,
80
- ctx: MiddlewareContext<unknown>,
77
+ ctx: SqlMiddlewareContext,
81
78
  ): void {
82
79
  if (shouldBlock) {
83
80
  throw error;
@@ -89,7 +86,7 @@ function emitBudgetViolation(
89
86
  });
90
87
  }
91
88
 
92
- export function budgets<TContract = unknown>(options?: BudgetsOptions): Middleware<TContract> {
89
+ export function budgets(options?: BudgetsOptions): SqlMiddleware {
93
90
  const maxRows = options?.maxRows ?? 10_000;
94
91
  const defaultTableRows = options?.defaultTableRows ?? 10_000;
95
92
  const tableRows = options?.tableRows ?? {};
@@ -102,7 +99,7 @@ export function budgets<TContract = unknown>(options?: BudgetsOptions): Middlewa
102
99
  name: 'budgets',
103
100
  familyId: 'sql' as const,
104
101
 
105
- async beforeExecute(plan: ExecutionPlan, ctx: MiddlewareContext<TContract>) {
102
+ async beforeExecute(plan: ExecutionPlan, ctx: SqlMiddlewareContext) {
106
103
  observedRowsByPlan.set(plan, { count: 0 });
107
104
 
108
105
  if (isQueryAst(plan.ast)) {
@@ -115,11 +112,7 @@ export function budgets<TContract = unknown>(options?: BudgetsOptions): Middlewa
115
112
  return evaluateWithHeuristics(plan, ctx);
116
113
  },
117
114
 
118
- async onRow(
119
- _row: Record<string, unknown>,
120
- plan: ExecutionPlan,
121
- _ctx: MiddlewareContext<TContract>,
122
- ) {
115
+ async onRow(_row: Record<string, unknown>, plan: ExecutionPlan, _ctx: SqlMiddlewareContext) {
123
116
  const state = observedRowsByPlan.get(plan);
124
117
  if (!state) return;
125
118
  state.count += 1;
@@ -135,7 +128,7 @@ export function budgets<TContract = unknown>(options?: BudgetsOptions): Middlewa
135
128
  async afterExecute(
136
129
  _plan: ExecutionPlan,
137
130
  result: AfterExecuteResult,
138
- ctx: MiddlewareContext<TContract>,
131
+ ctx: SqlMiddlewareContext,
139
132
  ) {
140
133
  const latencyMs = result.latencyMs;
141
134
  if (latencyMs > maxLatencyMs) {
@@ -146,17 +139,13 @@ export function budgets<TContract = unknown>(options?: BudgetsOptions): Middlewa
146
139
  maxLatencyMs,
147
140
  }),
148
141
  shouldBlock,
149
- ctx as MiddlewareContext<unknown>,
142
+ ctx,
150
143
  );
151
144
  }
152
145
  },
153
146
  });
154
147
 
155
- function evaluateSelectAst(
156
- plan: ExecutionPlan,
157
- ast: SelectAst,
158
- ctx: MiddlewareContext<TContract>,
159
- ) {
148
+ function evaluateSelectAst(plan: ExecutionPlan, ast: SelectAst, ctx: SqlMiddlewareContext) {
160
149
  const hasAggNoGroup = hasAggregateWithoutGroupBy(ast);
161
150
  const estimated = estimateRowsFromAst(
162
151
  ast,
@@ -177,7 +166,7 @@ export function budgets<TContract = unknown>(options?: BudgetsOptions): Middlewa
177
166
  maxRows,
178
167
  }),
179
168
  shouldBlock,
180
- ctx as MiddlewareContext<unknown>,
169
+ ctx,
181
170
  );
182
171
  return;
183
172
  }
@@ -188,7 +177,7 @@ export function budgets<TContract = unknown>(options?: BudgetsOptions): Middlewa
188
177
  maxRows,
189
178
  }),
190
179
  shouldBlock,
191
- ctx as MiddlewareContext<unknown>,
180
+ ctx,
192
181
  );
193
182
  return;
194
183
  }
@@ -201,12 +190,12 @@ export function budgets<TContract = unknown>(options?: BudgetsOptions): Middlewa
201
190
  maxRows,
202
191
  }),
203
192
  shouldBlock,
204
- ctx as MiddlewareContext<unknown>,
193
+ ctx,
205
194
  );
206
195
  }
207
196
  }
208
197
 
209
- async function evaluateWithHeuristics(plan: ExecutionPlan, ctx: MiddlewareContext<TContract>) {
198
+ async function evaluateWithHeuristics(plan: ExecutionPlan, ctx: SqlMiddlewareContext) {
210
199
  const estimated = estimateRowsFromHeuristics(plan, tableRows, defaultTableRows);
211
200
  const isUnbounded = !hasDetectableLimitFromHeuristics(plan);
212
201
  const sqlUpper = plan.sql.trimStart().toUpperCase();
@@ -222,7 +211,7 @@ export function budgets<TContract = unknown>(options?: BudgetsOptions): Middlewa
222
211
  maxRows,
223
212
  }),
224
213
  shouldBlock,
225
- ctx as MiddlewareContext<unknown>,
214
+ ctx,
226
215
  );
227
216
  return;
228
217
  }
@@ -233,7 +222,7 @@ export function budgets<TContract = unknown>(options?: BudgetsOptions): Middlewa
233
222
  maxRows,
234
223
  }),
235
224
  shouldBlock,
236
- ctx as MiddlewareContext<unknown>,
225
+ ctx,
237
226
  );
238
227
  return;
239
228
  }
@@ -247,7 +236,7 @@ export function budgets<TContract = unknown>(options?: BudgetsOptions): Middlewa
247
236
  maxRows,
248
237
  }),
249
238
  shouldBlock,
250
- ctx as MiddlewareContext<unknown>,
239
+ ctx,
251
240
  );
252
241
  }
253
242
  return;
@@ -1,6 +1,5 @@
1
1
  import type { ExecutionPlan } from '@prisma-next/contract/types';
2
2
  import { runtimeError } from '@prisma-next/framework-components/runtime';
3
- import type { Middleware, MiddlewareContext } from '@prisma-next/runtime-executor';
4
3
  import { evaluateRawGuardrails } from '@prisma-next/runtime-executor';
5
4
  import {
6
5
  type AnyFromSource,
@@ -8,6 +7,7 @@ import {
8
7
  isQueryAst,
9
8
  } from '@prisma-next/sql-relational-core/ast';
10
9
  import { ifDefined } from '@prisma-next/utils/defined';
10
+ import type { SqlMiddleware, SqlMiddlewareContext } from './sql-middleware';
11
11
 
12
12
  export interface LintsOptions {
13
13
  readonly severities?: {
@@ -138,14 +138,14 @@ function getConfiguredSeverity(code: string, options?: LintsOptions): 'warn' | '
138
138
  * Fallback: When ast is missing, `fallbackWhenAstMissing: 'raw'` uses heuristic
139
139
  * SQL parsing; `'skip'` skips all lints. Default is `'raw'`.
140
140
  */
141
- export function lints<TContract = unknown>(options?: LintsOptions): Middleware<TContract> {
141
+ export function lints(options?: LintsOptions): SqlMiddleware {
142
142
  const fallback = options?.fallbackWhenAstMissing ?? 'raw';
143
143
 
144
144
  return Object.freeze({
145
145
  name: 'lints',
146
146
  familyId: 'sql' as const,
147
147
 
148
- async beforeExecute(plan: ExecutionPlan, ctx: MiddlewareContext<TContract>) {
148
+ async beforeExecute(plan: ExecutionPlan, ctx: SqlMiddlewareContext) {
149
149
  if (isQueryAst(plan.ast)) {
150
150
  const findings = evaluateAstLints(plan.ast);
151
151
 
@@ -1,17 +1,46 @@
1
- import type { Contract, ExecutionPlan } from '@prisma-next/contract/types';
1
+ import type { Contract, ExecutionPlan, PlanMeta } from '@prisma-next/contract/types';
2
2
  import type {
3
3
  AfterExecuteResult,
4
4
  RuntimeMiddleware,
5
5
  RuntimeMiddlewareContext,
6
6
  } from '@prisma-next/framework-components/runtime';
7
7
  import type { SqlStorage } from '@prisma-next/sql-contract/types';
8
+ import type { AnyQueryAst } from '@prisma-next/sql-relational-core/ast';
8
9
 
9
10
  export interface SqlMiddlewareContext extends RuntimeMiddlewareContext {
10
11
  readonly contract: Contract<SqlStorage>;
11
12
  }
12
13
 
14
+ /**
15
+ * Pre-lowering query view passed to `beforeCompile`. Carries the typed SQL
16
+ * AST and plan metadata; `sql`/`params` are produced later by the adapter.
17
+ */
18
+ export interface DraftPlan {
19
+ readonly ast: AnyQueryAst;
20
+ readonly meta: PlanMeta;
21
+ }
22
+
13
23
  export interface SqlMiddleware extends RuntimeMiddleware {
14
- readonly familyId: 'sql';
24
+ readonly familyId?: 'sql';
25
+ /**
26
+ * Rewrite the query AST before it is lowered to SQL. Middlewares run in
27
+ * registration order; each sees the predecessor's output, so rewrites
28
+ * compose (e.g. soft-delete + tenant isolation).
29
+ *
30
+ * Return `undefined` (or a draft whose `ast` reference equals the input's)
31
+ * to pass through. Return a draft with a new `ast` reference to replace it;
32
+ * the runtime emits a `middleware.rewrite` debug log event and continues
33
+ * with the new draft. `adapter.lower()` runs once after the chain.
34
+ *
35
+ * Use `AstRewriter` / `SelectAst.withWhere` / `AndExpr.of` etc. to build
36
+ * the rewritten AST. Predicates and literals go through parameterized
37
+ * constructors by default — no SQL-injection surface is added. **Warning:**
38
+ * constructing `LiteralExpr.of(userInput)` from untrusted input bypasses
39
+ * that guarantee; use `ParamRef.of(userInput, ...)` instead.
40
+ *
41
+ * See `docs/architecture docs/subsystems/4. Runtime & Middleware Framework.md`.
42
+ */
43
+ beforeCompile?(draft: DraftPlan, ctx: SqlMiddlewareContext): Promise<DraftPlan | undefined>;
15
44
  beforeExecute?(plan: ExecutionPlan, ctx: SqlMiddlewareContext): Promise<void>;
16
45
  onRow?(
17
46
  row: Record<string, unknown>,
@@ -6,7 +6,6 @@ import type {
6
6
  import { checkMiddlewareCompatibility } from '@prisma-next/framework-components/runtime';
7
7
  import type {
8
8
  Log,
9
- Middleware,
10
9
  RuntimeCore,
11
10
  RuntimeCoreOptions,
12
11
  RuntimeTelemetryEvent,
@@ -33,6 +32,8 @@ import { decodeRow } from './codecs/decoding';
33
32
  import { encodeParams } from './codecs/encoding';
34
33
  import { validateCodecRegistryCompleteness } from './codecs/validation';
35
34
  import { lowerSqlPlan } from './lower-sql-plan';
35
+ import { runBeforeCompileChain } from './middleware/before-compile-chain';
36
+ import type { SqlMiddleware } from './middleware/sql-middleware';
36
37
  import type {
37
38
  ExecutionContext,
38
39
  SqlRuntimeAdapterInstance,
@@ -45,7 +46,7 @@ export interface RuntimeOptions<TContract extends Contract<SqlStorage> = Contrac
45
46
  readonly adapter: Adapter<AnyQueryAst, Contract<SqlStorage>, LoweredStatement>;
46
47
  readonly driver: SqlDriver<unknown>;
47
48
  readonly verify: RuntimeVerifyOptions;
48
- readonly middleware?: readonly Middleware<TContract>[];
49
+ readonly middleware?: readonly SqlMiddleware[];
49
50
  readonly mode?: 'strict' | 'permissive';
50
51
  readonly log?: Log;
51
52
  }
@@ -64,7 +65,7 @@ export interface CreateRuntimeOptions<
64
65
  readonly context: ExecutionContext<TContract>;
65
66
  readonly driver: SqlDriver<unknown>;
66
67
  readonly verify: RuntimeVerifyOptions;
67
- readonly middleware?: readonly Middleware<TContract>[];
68
+ readonly middleware?: readonly SqlMiddleware[];
68
69
  readonly mode?: 'strict' | 'permissive';
69
70
  readonly log?: Log;
70
71
  }
@@ -125,7 +126,7 @@ export type { RuntimeTelemetryEvent, RuntimeVerifyOptions, TelemetryOutcome };
125
126
  class SqlRuntimeImpl<TContract extends Contract<SqlStorage> = Contract<SqlStorage>>
126
127
  implements Runtime
127
128
  {
128
- private readonly core: RuntimeCore<TContract, SqlDriver<unknown>>;
129
+ private readonly core: RuntimeCore<TContract, SqlDriver<unknown>, SqlMiddleware>;
129
130
  private readonly contract: TContract;
130
131
  private readonly adapter: Adapter<AnyQueryAst, Contract<SqlStorage>, LoweredStatement>;
131
132
  private readonly codecRegistry: CodecRegistry;
@@ -148,7 +149,7 @@ class SqlRuntimeImpl<TContract extends Contract<SqlStorage> = Contract<SqlStorag
148
149
 
149
150
  const familyAdapter = new SqlFamilyAdapter(context.contract, adapter.profile);
150
151
 
151
- const coreOptions: RuntimeCoreOptions<TContract, SqlDriver<unknown>> = {
152
+ const coreOptions: RuntimeCoreOptions<TContract, SqlDriver<unknown>, SqlMiddleware> = {
152
153
  familyAdapter,
153
154
  driver,
154
155
  verify,
@@ -172,12 +173,27 @@ class SqlRuntimeImpl<TContract extends Contract<SqlStorage> = Contract<SqlStorag
172
173
  }
173
174
  }
174
175
 
175
- private toExecutionPlan<Row>(plan: ExecutionPlan<Row> | SqlQueryPlan<Row>): ExecutionPlan<Row> {
176
+ private async toExecutionPlan<Row>(
177
+ plan: ExecutionPlan<Row> | SqlQueryPlan<Row>,
178
+ ): Promise<ExecutionPlan<Row>> {
176
179
  const isSqlQueryPlan = (p: ExecutionPlan<Row> | SqlQueryPlan<Row>): p is SqlQueryPlan<Row> => {
177
180
  return 'ast' in p && !('sql' in p);
178
181
  };
179
182
 
180
- return isSqlQueryPlan(plan) ? lowerSqlPlan(this.adapter, this.contract, plan) : plan;
183
+ if (!isSqlQueryPlan(plan)) {
184
+ return plan;
185
+ }
186
+
187
+ const rewrittenDraft = await runBeforeCompileChain(
188
+ this.core.middleware,
189
+ { ast: plan.ast, meta: plan.meta },
190
+ this.core.middlewareContext,
191
+ );
192
+
193
+ const planToLower: SqlQueryPlan<Row> =
194
+ rewrittenDraft.ast === plan.ast ? plan : { ...plan, ast: rewrittenDraft.ast };
195
+
196
+ return lowerSqlPlan(this.adapter, this.contract, planToLower);
181
197
  }
182
198
 
183
199
  private executeAgainstQueryable<Row = Record<string, unknown>>(
@@ -185,11 +201,11 @@ class SqlRuntimeImpl<TContract extends Contract<SqlStorage> = Contract<SqlStorag
185
201
  queryable: CoreQueryable,
186
202
  ): AsyncIterableResult<Row> {
187
203
  this.ensureCodecRegistryValidated(this.contract);
188
- const executablePlan = this.toExecutionPlan(plan);
189
204
 
190
205
  const iterator = async function* (
191
206
  self: SqlRuntimeImpl<TContract>,
192
207
  ): AsyncGenerator<Row, void, unknown> {
208
+ const executablePlan = await self.toExecutionPlan(plan);
193
209
  const encodedParams = encodeParams(executablePlan, self.codecRegistry);
194
210
  const planWithEncodedParams: ExecutionPlan<Row> = {
195
211
  ...executablePlan,
@@ -0,0 +1,223 @@
1
+ import type { Contract, PlanMeta } from '@prisma-next/contract/types';
2
+ import type { SqlStorage } from '@prisma-next/sql-contract/types';
3
+ import {
4
+ AndExpr,
5
+ BinaryExpr,
6
+ ColumnRef,
7
+ LiteralExpr,
8
+ SelectAst,
9
+ TableSource,
10
+ } from '@prisma-next/sql-relational-core/ast';
11
+ import { timeouts } from '@prisma-next/test-utils';
12
+ import { describe, expect, it, vi } from 'vitest';
13
+ import { runBeforeCompileChain } from '../src/middleware/before-compile-chain';
14
+ import type {
15
+ DraftPlan,
16
+ SqlMiddleware,
17
+ SqlMiddlewareContext,
18
+ } from '../src/middleware/sql-middleware';
19
+
20
+ function createContext(): SqlMiddlewareContext & {
21
+ log: { debug: ReturnType<typeof vi.fn> };
22
+ } {
23
+ const debug = vi.fn();
24
+ return {
25
+ contract: {} as Contract<SqlStorage>,
26
+ mode: 'strict' as const,
27
+ now: () => 0,
28
+ log: {
29
+ info: vi.fn(),
30
+ warn: vi.fn(),
31
+ error: vi.fn(),
32
+ debug,
33
+ },
34
+ };
35
+ }
36
+
37
+ const meta: PlanMeta = {
38
+ target: 'postgres',
39
+ storageHash: 'sha256:test',
40
+ lane: 'dsl',
41
+ paramDescriptors: [],
42
+ };
43
+
44
+ function createDraft(): DraftPlan {
45
+ const users = TableSource.named('users');
46
+ return {
47
+ ast: SelectAst.from(users).withProjection([]),
48
+ meta,
49
+ };
50
+ }
51
+
52
+ describe('runBeforeCompileChain', () => {
53
+ it(
54
+ 'returns the initial draft unchanged when no middleware rewrites',
55
+ async () => {
56
+ const draft = createDraft();
57
+ const ctx = createContext();
58
+ const mw: SqlMiddleware = {
59
+ name: 'noop',
60
+ familyId: 'sql',
61
+ async beforeCompile() {
62
+ return undefined;
63
+ },
64
+ };
65
+
66
+ const result = await runBeforeCompileChain([mw], draft, ctx);
67
+
68
+ expect(result).toBe(draft);
69
+ expect(ctx.log.debug).not.toHaveBeenCalled();
70
+ },
71
+ timeouts.default,
72
+ );
73
+
74
+ it(
75
+ 'treats a returned draft with same ast reference as passthrough',
76
+ async () => {
77
+ const draft = createDraft();
78
+ const ctx = createContext();
79
+ const mw: SqlMiddleware = {
80
+ name: 'sameRef',
81
+ familyId: 'sql',
82
+ async beforeCompile(d) {
83
+ return { ...d };
84
+ },
85
+ };
86
+
87
+ const result = await runBeforeCompileChain([mw], draft, ctx);
88
+
89
+ expect(result.ast).toBe(draft.ast);
90
+ expect(ctx.log.debug).not.toHaveBeenCalled();
91
+ },
92
+ timeouts.default,
93
+ );
94
+
95
+ it(
96
+ 'replaces the current draft when a middleware returns a new ast ref',
97
+ async () => {
98
+ const draft = createDraft();
99
+ const ctx = createContext();
100
+ const addWhere = BinaryExpr.eq(ColumnRef.of('users', 'deleted_at'), LiteralExpr.of(null));
101
+ const mw: SqlMiddleware = {
102
+ name: 'softDelete',
103
+ familyId: 'sql',
104
+ async beforeCompile(d) {
105
+ if (d.ast.kind !== 'select') return;
106
+ return { ...d, ast: d.ast.withWhere(addWhere) };
107
+ },
108
+ };
109
+
110
+ const result = await runBeforeCompileChain([mw], draft, ctx);
111
+
112
+ expect(result.ast).not.toBe(draft.ast);
113
+ expect(result.ast.kind).toBe('select');
114
+ expect((result.ast as SelectAst).where).toBe(addWhere);
115
+ },
116
+ timeouts.default,
117
+ );
118
+
119
+ it(
120
+ 'chains rewrites in registration order',
121
+ async () => {
122
+ const draft = createDraft();
123
+ const ctx = createContext();
124
+ const order: string[] = [];
125
+
126
+ const predA = BinaryExpr.eq(ColumnRef.of('users', 'a'), LiteralExpr.of(1));
127
+ const predB = BinaryExpr.eq(ColumnRef.of('users', 'b'), LiteralExpr.of(2));
128
+
129
+ const mwA: SqlMiddleware = {
130
+ name: 'addA',
131
+ familyId: 'sql',
132
+ async beforeCompile(d) {
133
+ order.push('A');
134
+ if (d.ast.kind !== 'select') return;
135
+ return { ...d, ast: d.ast.withWhere(predA) };
136
+ },
137
+ };
138
+ const mwB: SqlMiddleware = {
139
+ name: 'addB',
140
+ familyId: 'sql',
141
+ async beforeCompile(d) {
142
+ order.push('B');
143
+ if (d.ast.kind !== 'select') return;
144
+ const current = d.ast.where;
145
+ const combined = current ? AndExpr.of([current, predB]) : predB;
146
+ return { ...d, ast: d.ast.withWhere(combined) };
147
+ },
148
+ };
149
+
150
+ const result = await runBeforeCompileChain([mwA, mwB], draft, ctx);
151
+
152
+ expect(order).toEqual(['A', 'B']);
153
+ expect(result.ast.kind).toBe('select');
154
+ const where = (result.ast as SelectAst).where;
155
+ expect(where?.kind).toBe('and');
156
+ },
157
+ timeouts.default,
158
+ );
159
+
160
+ it(
161
+ 'emits a debug log event per rewrite with middleware name and lane',
162
+ async () => {
163
+ const draft = createDraft();
164
+ const ctx = createContext();
165
+ const pred = BinaryExpr.eq(ColumnRef.of('users', 'a'), LiteralExpr.of(1));
166
+ const mw: SqlMiddleware = {
167
+ name: 'rewriteOne',
168
+ familyId: 'sql',
169
+ async beforeCompile(d) {
170
+ if (d.ast.kind !== 'select') return;
171
+ return { ...d, ast: d.ast.withWhere(pred) };
172
+ },
173
+ };
174
+
175
+ await runBeforeCompileChain([mw, mw], draft, ctx);
176
+
177
+ expect(ctx.log.debug).toHaveBeenCalledTimes(2);
178
+ expect(ctx.log.debug).toHaveBeenCalledWith({
179
+ event: 'middleware.rewrite',
180
+ middleware: 'rewriteOne',
181
+ lane: 'dsl',
182
+ });
183
+ },
184
+ timeouts.default,
185
+ );
186
+
187
+ it(
188
+ 'skips middleware without beforeCompile',
189
+ async () => {
190
+ const draft = createDraft();
191
+ const ctx = createContext();
192
+ const observerOnly: SqlMiddleware = {
193
+ name: 'observer',
194
+ familyId: 'sql',
195
+ async beforeExecute() {},
196
+ };
197
+
198
+ const result = await runBeforeCompileChain([observerOnly], draft, ctx);
199
+
200
+ expect(result).toBe(draft);
201
+ expect(ctx.log.debug).not.toHaveBeenCalled();
202
+ },
203
+ timeouts.default,
204
+ );
205
+
206
+ it(
207
+ 'propagates errors thrown inside beforeCompile',
208
+ async () => {
209
+ const draft = createDraft();
210
+ const ctx = createContext();
211
+ const mw: SqlMiddleware = {
212
+ name: 'thrower',
213
+ familyId: 'sql',
214
+ async beforeCompile() {
215
+ throw new Error('boom');
216
+ },
217
+ };
218
+
219
+ await expect(runBeforeCompileChain([mw], draft, ctx)).rejects.toThrow('boom');
220
+ },
221
+ timeouts.default,
222
+ );
223
+ });
@@ -1,5 +1,6 @@
1
- import type { ExecutionPlan, PlanMeta } from '@prisma-next/contract/types';
2
- import type { AfterExecuteResult, MiddlewareContext } from '@prisma-next/runtime-executor';
1
+ import type { Contract, ExecutionPlan, PlanMeta } from '@prisma-next/contract/types';
2
+ import type { AfterExecuteResult } from '@prisma-next/runtime-executor';
3
+ import type { SqlStorage } from '@prisma-next/sql-contract/types';
3
4
  import {
4
5
  AggregateExpr,
5
6
  ColumnRef,
@@ -11,15 +12,14 @@ import {
11
12
  import { timeouts } from '@prisma-next/test-utils';
12
13
  import { describe, expect, it, vi } from 'vitest';
13
14
  import { budgets } from '../src/middleware/budgets';
15
+ import type { SqlMiddlewareContext } from '../src/middleware/sql-middleware';
14
16
 
15
17
  const userTable = TableSource.named('user');
16
18
  const idCol = ColumnRef.of('user', 'id');
17
19
 
18
- function createMiddlewareContext(
19
- overrides?: Partial<MiddlewareContext<unknown>>,
20
- ): MiddlewareContext<unknown> {
20
+ function createMiddlewareContext(overrides?: Partial<SqlMiddlewareContext>): SqlMiddlewareContext {
21
21
  return {
22
- contract: {},
22
+ contract: {} as Contract<SqlStorage>,
23
23
  mode: 'strict' as const,
24
24
  now: () => Date.now(),
25
25
  log: {
@@ -41,10 +41,7 @@ function createStubAdapterDescriptor(): SqlRuntimeAdapterDescriptor<'postgres'>
41
41
  readMarkerStatement: () => ({ sql: '', params: [] }),
42
42
  },
43
43
  lower() {
44
- return {
45
- profileId: 'test-profile',
46
- body: Object.freeze({ sql: '', params: [] }),
47
- };
44
+ return Object.freeze({ sql: '', params: [] });
48
45
  },
49
46
  },
50
47
  );
@@ -1,5 +1,5 @@
1
- import type { ExecutionPlan, PlanMeta } from '@prisma-next/contract/types';
2
- import type { MiddlewareContext } from '@prisma-next/runtime-executor';
1
+ import type { Contract, ExecutionPlan, PlanMeta } from '@prisma-next/contract/types';
2
+ import type { SqlStorage } from '@prisma-next/sql-contract/types';
3
3
  import {
4
4
  BinaryExpr,
5
5
  ColumnRef,
@@ -14,10 +14,11 @@ import {
14
14
  import { timeouts } from '@prisma-next/test-utils';
15
15
  import { describe, expect, it, vi } from 'vitest';
16
16
  import { lints } from '../src/middleware/lints';
17
+ import type { SqlMiddlewareContext } from '../src/middleware/sql-middleware';
17
18
 
18
- function createMiddlewareContext(): MiddlewareContext<unknown> {
19
+ function createMiddlewareContext(): SqlMiddlewareContext {
19
20
  return {
20
- contract: {},
21
+ contract: {} as Contract<SqlStorage>,
21
22
  mode: 'strict' as const,
22
23
  now: () => Date.now(),
23
24
  log: {