agent-sql 0.2.1 → 0.2.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -12,6 +12,8 @@ It ensures that that the needed tenant table is somewhere in the query,
12
12
  and adds a `WHERE` clause ensuring that only values from the supplied ID are returned.
13
13
  Then it checks that the tables and `JOIN`s follow the schema, preventing sneaky joins.
14
14
 
15
+ Function calls also go through a whitelist (configurable).
16
+
15
17
  Finally, we throw in a `LIMIT` clause (configurable) to prevent accidental LLM denial-of-service.
16
18
 
17
19
  Apparently this is how [Trigger.dev does it](https://x.com/mattaitken/status/2033928542975639785).
package/dist/index.d.mts CHANGED
@@ -1,5 +1,8 @@
1
1
  import { i as SelectStatement, n as defineSchema, r as Result, t as Schema } from "./joins-Cu_0yAgN.mjs";
2
2
 
3
+ //#region src/functions.d.ts
4
+ type DbType = "postgres" | "pglite" | "sqlite";
5
+ //#endregion
3
6
  //#region src/guard.d.ts
4
7
  type GuardVal = string | number;
5
8
  interface GuardCol {
@@ -12,10 +15,10 @@ type WhereGuard = GuardCol & {
12
15
  };
13
16
  type OneOrTwoDots<S extends string> = S extends `${infer A}.${infer B}.${infer C}` ? A extends `${string}.${string}` ? never : B extends `${string}.${string}` ? never : C extends `${string}.${string}` ? never : S : S extends `${infer A}.${infer B}` ? A extends `${string}.${string}` ? never : B extends `${string}.${string}` ? never : S : never;
14
17
  type SchemaGuardKeys<T> = { [Table in keyof T & string]: `${Table}.${keyof T[Table] & string}` }[keyof T & string];
15
- declare function applyGuards(ast: SelectStatement, guards: WhereGuard[], limit?: number): Result<SelectStatement>;
18
+ declare function applyGuards(ast: SelectStatement, guards: WhereGuard[], limit: number): Result<SelectStatement>;
16
19
  //#endregion
17
20
  //#region src/output.d.ts
18
- declare function outputSql(ast: SelectStatement): string;
21
+ declare function outputSql(ast: SelectStatement, pretty?: boolean): string;
19
22
  //#endregion
20
23
  //#region src/parse.d.ts
21
24
  declare function parseSql(expr: string): Result<SelectStatement>;
@@ -23,18 +26,30 @@ declare function parseSql(expr: string): Result<SelectStatement>;
23
26
  //#region src/index.d.ts
24
27
  declare function agentSql<S extends string>(sql: string, column: S & OneOrTwoDots<S>, value: GuardVal, {
25
28
  schema,
26
- limit
29
+ limit,
30
+ pretty,
31
+ db,
32
+ allowExtraFunctions
27
33
  }?: {
28
34
  schema?: Schema;
29
35
  limit?: number;
36
+ pretty?: boolean;
37
+ db?: DbType;
38
+ allowExtraFunctions?: string[];
30
39
  }): string;
31
40
  declare function createAgentSql<T extends Schema, S extends SchemaGuardKeys<T>>(schema: T, guards: Record<S, GuardVal>, opts: {
32
41
  limit?: number;
42
+ pretty?: boolean;
33
43
  throws: false;
44
+ db?: DbType;
45
+ allowExtraFunctions?: string[];
34
46
  }): (expr: string) => Result<string>;
35
47
  declare function createAgentSql<T extends Schema, S extends SchemaGuardKeys<T>>(schema: T, guards: Record<S, GuardVal>, opts?: {
36
48
  limit?: number;
49
+ pretty?: boolean;
37
50
  throws?: true;
51
+ db?: DbType;
52
+ allowExtraFunctions?: string[];
38
53
  }): (expr: string) => string;
39
54
  //#endregion
40
- export { agentSql, createAgentSql, defineSchema, outputSql, parseSql, applyGuards as sanitiseSql };
55
+ export { type DbType, agentSql, createAgentSql, defineSchema, outputSql, parseSql, applyGuards as sanitiseSql };
package/dist/index.mjs CHANGED
@@ -10,28 +10,794 @@ var AgentSqlError = class extends Error {
10
10
  type = "agent_sql_error";
11
11
  };
12
12
  //#endregion
13
+ //#region src/result.ts
14
+ function Err(error) {
15
+ return {
16
+ ok: false,
17
+ error,
18
+ unwrap() {
19
+ throw new Error(String(error));
20
+ }
21
+ };
22
+ }
23
+ function Ok(data) {
24
+ return {
25
+ ok: true,
26
+ data,
27
+ unwrap() {
28
+ return data;
29
+ }
30
+ };
31
+ }
32
+ function returnOrThrow(result, throws) {
33
+ if (!throws) return result;
34
+ if (result.ok) return result.data;
35
+ throw result.error;
36
+ }
37
+ //#endregion
38
+ //#region src/functions.ts
39
+ const DEFAULT_DB = "postgres";
40
+ const COMMON_FUNCTIONS = [
41
+ "count",
42
+ "sum",
43
+ "avg",
44
+ "min",
45
+ "max",
46
+ "coalesce",
47
+ "nullif",
48
+ "greatest",
49
+ "least",
50
+ "lower",
51
+ "upper",
52
+ "trim",
53
+ "ltrim",
54
+ "rtrim",
55
+ "length",
56
+ "replace",
57
+ "substr",
58
+ "substring",
59
+ "concat",
60
+ "reverse",
61
+ "repeat",
62
+ "position",
63
+ "left",
64
+ "right",
65
+ "lpad",
66
+ "rpad",
67
+ "translate",
68
+ "char_length",
69
+ "character_length",
70
+ "octet_length",
71
+ "overlay",
72
+ "ascii",
73
+ "chr",
74
+ "starts_with",
75
+ "abs",
76
+ "ceil",
77
+ "ceiling",
78
+ "floor",
79
+ "round",
80
+ "trunc",
81
+ "truncate",
82
+ "mod",
83
+ "power",
84
+ "sqrt",
85
+ "cbrt",
86
+ "log",
87
+ "ln",
88
+ "exp",
89
+ "sign",
90
+ "random",
91
+ "pi",
92
+ "degrees",
93
+ "radians",
94
+ "div",
95
+ "gcd",
96
+ "lcm",
97
+ "cast"
98
+ ];
99
+ const PGVECTOR_FUNCTIONS = [
100
+ "l2_distance",
101
+ "inner_product",
102
+ "cosine_distance",
103
+ "l1_distance",
104
+ "vector_dims",
105
+ "vector_norm",
106
+ "l2_normalize",
107
+ "binary_quantize",
108
+ "subvector",
109
+ "vector_avg"
110
+ ];
111
+ const POSTGIS_FUNCTIONS = [
112
+ "st_geomfromtext",
113
+ "st_geomfromwkb",
114
+ "st_geomfromewkt",
115
+ "st_geomfromewkb",
116
+ "st_geomfromgeojson",
117
+ "st_geogfromtext",
118
+ "st_geogfromwkb",
119
+ "st_makepoint",
120
+ "st_makepoint",
121
+ "st_makepointm",
122
+ "st_makeenvelope",
123
+ "st_makeline",
124
+ "st_makepolygon",
125
+ "st_makebox2d",
126
+ "st_point",
127
+ "st_pointz",
128
+ "st_pointm",
129
+ "st_pointzm",
130
+ "st_polygon",
131
+ "st_linefrommultipoint",
132
+ "st_tileenvelope",
133
+ "st_hexagongrid",
134
+ "st_squaregrid",
135
+ "st_letters",
136
+ "st_collect",
137
+ "st_linemerge",
138
+ "st_buildarea",
139
+ "st_polygonize",
140
+ "st_unaryunion",
141
+ "st_astext",
142
+ "st_asewkt",
143
+ "st_asbinary",
144
+ "st_asewkb",
145
+ "st_asgeojson",
146
+ "st_asgml",
147
+ "st_askml",
148
+ "st_assvg",
149
+ "st_astwkb",
150
+ "st_asmvtgeom",
151
+ "st_asmvt",
152
+ "st_asencodedpolyline",
153
+ "st_ashexewkb",
154
+ "st_aslatlontext",
155
+ "st_x",
156
+ "st_y",
157
+ "st_z",
158
+ "st_m",
159
+ "st_geometrytype",
160
+ "st_srid",
161
+ "st_dimension",
162
+ "st_coorddim",
163
+ "st_numgeometries",
164
+ "st_numpoints",
165
+ "st_npoints",
166
+ "st_nrings",
167
+ "st_numinteriorrings",
168
+ "st_numinteriorring",
169
+ "st_exteriorring",
170
+ "st_interiorringn",
171
+ "st_geometryn",
172
+ "st_pointn",
173
+ "st_startpoint",
174
+ "st_endpoint",
175
+ "st_envelope",
176
+ "st_boundingdiagonal",
177
+ "st_xmin",
178
+ "st_xmax",
179
+ "st_ymin",
180
+ "st_ymax",
181
+ "st_zmin",
182
+ "st_zmax",
183
+ "st_isempty",
184
+ "st_isclosed",
185
+ "st_isring",
186
+ "st_issimple",
187
+ "st_isvalid",
188
+ "st_isvalidreason",
189
+ "st_isvaliddetail",
190
+ "st_hasm",
191
+ "st_hasz",
192
+ "st_ismeasured",
193
+ "st_intersects",
194
+ "st_disjoint",
195
+ "st_contains",
196
+ "st_within",
197
+ "st_covers",
198
+ "st_coveredby",
199
+ "st_crosses",
200
+ "st_overlaps",
201
+ "st_touches",
202
+ "st_equals",
203
+ "st_relate",
204
+ "st_containsproperly",
205
+ "st_dwithin",
206
+ "st_3dintersects",
207
+ "st_3ddwithin",
208
+ "st_orderingequals",
209
+ "st_distance",
210
+ "st_3ddistance",
211
+ "st_maxdistance",
212
+ "st_area",
213
+ "st_length",
214
+ "st_length2d",
215
+ "st_3dlength",
216
+ "st_perimeter",
217
+ "st_azimuth",
218
+ "st_angle",
219
+ "st_hausdorffdistance",
220
+ "st_frechetdistance",
221
+ "st_longestline",
222
+ "st_shortestline",
223
+ "st_transform",
224
+ "st_setsrid",
225
+ "st_force2d",
226
+ "st_force3d",
227
+ "st_force3dz",
228
+ "st_force3dm",
229
+ "st_force4d",
230
+ "st_forcecollection",
231
+ "st_forcepolygoncw",
232
+ "st_forcepolygonccw",
233
+ "st_forcecurve",
234
+ "st_forcesfs",
235
+ "st_multi",
236
+ "st_normalize",
237
+ "st_flipcoordinates",
238
+ "st_translate",
239
+ "st_scale",
240
+ "st_rotate",
241
+ "st_rotatex",
242
+ "st_rotatey",
243
+ "st_rotatez",
244
+ "st_affine",
245
+ "st_transscale",
246
+ "st_snap",
247
+ "st_snaptogrid",
248
+ "st_segmentize",
249
+ "st_simplify",
250
+ "st_simplifypreservetopology",
251
+ "st_simplifyvw",
252
+ "st_chaikinsmoothing",
253
+ "st_seteffectivearea",
254
+ "st_filterbym",
255
+ "st_locatebetween",
256
+ "st_locatebetweenelevations",
257
+ "st_offsetcurve",
258
+ "st_intersection",
259
+ "st_union",
260
+ "st_difference",
261
+ "st_symdifference",
262
+ "st_buffer",
263
+ "st_convexhull",
264
+ "st_concavehull",
265
+ "st_minimumboundingcircle",
266
+ "st_minimumboundingradius",
267
+ "st_orientedenvelope",
268
+ "st_centroid",
269
+ "st_pointonsurface",
270
+ "st_geometricmedian",
271
+ "st_voronoipolygons",
272
+ "st_voronoilines",
273
+ "st_delaunaytriangles",
274
+ "st_subdivide",
275
+ "st_split",
276
+ "st_sharedpaths",
277
+ "st_node",
278
+ "st_clusterdbscan",
279
+ "st_clusterkmeans",
280
+ "st_clusterintersecting",
281
+ "st_clusterwithin",
282
+ "st_makevalid",
283
+ "st_lineinterpolatepoint",
284
+ "st_lineinterpolatepoints",
285
+ "st_linelocatepoint",
286
+ "st_linesubstring",
287
+ "st_addmeasure",
288
+ "st_closestpoint",
289
+ "st_linefromencodedpolyline",
290
+ "box2d",
291
+ "box3d",
292
+ "st_expand",
293
+ "st_estimatedextent",
294
+ "st_extent",
295
+ "st_3dextent",
296
+ "st_memsize",
297
+ "st_distancesphere",
298
+ "st_distancespheroid",
299
+ "st_project",
300
+ "st_memunion",
301
+ "st_polygonize",
302
+ "st_nband",
303
+ "st_numbands",
304
+ "st_summary",
305
+ "st_dump",
306
+ "st_dumppoints",
307
+ "st_dumprings",
308
+ "postgis_version",
309
+ "postgis_full_version",
310
+ "postgis_geos_version",
311
+ "postgis_proj_version",
312
+ "postgis_lib_version",
313
+ "postgis_scripts_installed",
314
+ "postgis_scripts_released",
315
+ "postgis_type_name",
316
+ "populate_geometry_columns",
317
+ "find_srid",
318
+ "updategeometrysrid",
319
+ "addgeometrycolumn",
320
+ "dropgeometrycolumn",
321
+ "geography",
322
+ "geometry"
323
+ ];
324
+ const POSTGRES_FUNCTIONS = [
325
+ ...COMMON_FUNCTIONS,
326
+ "array_agg",
327
+ "string_agg",
328
+ "json_agg",
329
+ "jsonb_agg",
330
+ "json_object_agg",
331
+ "jsonb_object_agg",
332
+ "bool_and",
333
+ "bool_or",
334
+ "every",
335
+ "bit_and",
336
+ "bit_or",
337
+ "bit_xor",
338
+ "corr",
339
+ "covar_pop",
340
+ "covar_samp",
341
+ "regr_avgx",
342
+ "regr_avgy",
343
+ "regr_count",
344
+ "regr_intercept",
345
+ "regr_r2",
346
+ "regr_slope",
347
+ "regr_sxx",
348
+ "regr_sxy",
349
+ "regr_syy",
350
+ "stddev",
351
+ "stddev_pop",
352
+ "stddev_samp",
353
+ "variance",
354
+ "var_pop",
355
+ "var_samp",
356
+ "percentile_cont",
357
+ "percentile_disc",
358
+ "mode",
359
+ "rank",
360
+ "dense_rank",
361
+ "percent_rank",
362
+ "cume_dist",
363
+ "ntile",
364
+ "lag",
365
+ "lead",
366
+ "first_value",
367
+ "last_value",
368
+ "nth_value",
369
+ "row_number",
370
+ "initcap",
371
+ "strpos",
372
+ "encode",
373
+ "decode",
374
+ "md5",
375
+ "sha256",
376
+ "sha224",
377
+ "sha384",
378
+ "sha512",
379
+ "format",
380
+ "concat_ws",
381
+ "regexp_replace",
382
+ "regexp_match",
383
+ "regexp_matches",
384
+ "regexp_split_to_array",
385
+ "regexp_split_to_table",
386
+ "split_part",
387
+ "btrim",
388
+ "bit_length",
389
+ "quote_ident",
390
+ "quote_literal",
391
+ "quote_nullable",
392
+ "to_hex",
393
+ "convert",
394
+ "convert_from",
395
+ "convert_to",
396
+ "string_to_array",
397
+ "array_to_string",
398
+ "now",
399
+ "current_timestamp",
400
+ "current_date",
401
+ "current_time",
402
+ "localtime",
403
+ "localtimestamp",
404
+ "clock_timestamp",
405
+ "statement_timestamp",
406
+ "transaction_timestamp",
407
+ "timeofday",
408
+ "date_trunc",
409
+ "date_part",
410
+ "extract",
411
+ "age",
412
+ "to_char",
413
+ "to_date",
414
+ "to_timestamp",
415
+ "to_number",
416
+ "make_date",
417
+ "make_time",
418
+ "make_timestamp",
419
+ "make_timestamptz",
420
+ "make_interval",
421
+ "justify_days",
422
+ "justify_hours",
423
+ "justify_interval",
424
+ "isfinite",
425
+ "json_extract_path",
426
+ "json_extract_path_text",
427
+ "jsonb_extract_path",
428
+ "jsonb_extract_path_text",
429
+ "json_array_length",
430
+ "jsonb_array_length",
431
+ "json_typeof",
432
+ "jsonb_typeof",
433
+ "json_build_object",
434
+ "jsonb_build_object",
435
+ "json_build_array",
436
+ "jsonb_build_array",
437
+ "to_json",
438
+ "to_jsonb",
439
+ "row_to_json",
440
+ "json_each",
441
+ "json_each_text",
442
+ "jsonb_each",
443
+ "jsonb_each_text",
444
+ "json_object_keys",
445
+ "jsonb_object_keys",
446
+ "json_populate_record",
447
+ "jsonb_populate_record",
448
+ "json_populate_recordset",
449
+ "jsonb_populate_recordset",
450
+ "json_to_record",
451
+ "jsonb_to_record",
452
+ "json_to_recordset",
453
+ "jsonb_to_recordset",
454
+ "json_array_elements",
455
+ "jsonb_array_elements",
456
+ "json_array_elements_text",
457
+ "jsonb_array_elements_text",
458
+ "jsonb_set",
459
+ "jsonb_set_lax",
460
+ "jsonb_insert",
461
+ "jsonb_path_query",
462
+ "jsonb_path_query_array",
463
+ "jsonb_path_query_first",
464
+ "jsonb_path_exists",
465
+ "jsonb_path_match",
466
+ "jsonb_strip_nulls",
467
+ "jsonb_pretty",
468
+ "json_strip_nulls",
469
+ "to_tsvector",
470
+ "to_tsquery",
471
+ "plainto_tsquery",
472
+ "phraseto_tsquery",
473
+ "websearch_to_tsquery",
474
+ "ts_rank",
475
+ "ts_rank_cd",
476
+ "ts_headline",
477
+ "tsvector_to_array",
478
+ "array_to_tsvector",
479
+ "numnode",
480
+ "querytree",
481
+ "ts_rewrite",
482
+ "setweight",
483
+ "strip",
484
+ "ts_debug",
485
+ "ts_lexize",
486
+ "ts_parse",
487
+ "ts_token_type",
488
+ "get_current_ts_config",
489
+ "array_append",
490
+ "array_cat",
491
+ "array_dims",
492
+ "array_fill",
493
+ "array_length",
494
+ "array_lower",
495
+ "array_ndims",
496
+ "array_position",
497
+ "array_positions",
498
+ "array_prepend",
499
+ "array_remove",
500
+ "array_replace",
501
+ "array_upper",
502
+ "cardinality",
503
+ "unnest",
504
+ "generate_subscripts",
505
+ "lower",
506
+ "upper",
507
+ "isempty",
508
+ "lower_inc",
509
+ "lower_inf",
510
+ "upper_inc",
511
+ "upper_inf",
512
+ "range_merge",
513
+ "generate_series",
514
+ "pg_typeof",
515
+ "current_setting",
516
+ "current_database",
517
+ "current_schema",
518
+ "current_schemas",
519
+ "current_user",
520
+ "session_user",
521
+ "inet_client_addr",
522
+ "inet_client_port",
523
+ "version",
524
+ "obj_description",
525
+ "col_description",
526
+ "shobj_description",
527
+ "has_table_privilege",
528
+ "has_column_privilege",
529
+ "has_schema_privilege",
530
+ "txid_current",
531
+ "txid_current_snapshot",
532
+ "area",
533
+ "center",
534
+ "diameter",
535
+ "height",
536
+ "width",
537
+ "isclosed",
538
+ "isopen",
539
+ "npoints",
540
+ "pclose",
541
+ "popen",
542
+ "radius",
543
+ "abbrev",
544
+ "broadcast",
545
+ "family",
546
+ "host",
547
+ "hostmask",
548
+ "masklen",
549
+ "netmask",
550
+ "network",
551
+ "set_masklen",
552
+ "inet_merge",
553
+ "inet_same_family",
554
+ "gen_random_uuid",
555
+ "uuid_generate_v1",
556
+ "uuid_generate_v4",
557
+ ...PGVECTOR_FUNCTIONS,
558
+ ...POSTGIS_FUNCTIONS
559
+ ];
560
+ const SQLITE_FUNCTIONS = [
561
+ ...COMMON_FUNCTIONS,
562
+ "group_concat",
563
+ "total",
564
+ "char",
565
+ "format",
566
+ "glob",
567
+ "hex",
568
+ "unhex",
569
+ "instr",
570
+ "like",
571
+ "ltrim",
572
+ "rtrim",
573
+ "trim",
574
+ "printf",
575
+ "quote",
576
+ "soundex",
577
+ "unicode",
578
+ "zeroblob",
579
+ "acos",
580
+ "acosh",
581
+ "asin",
582
+ "asinh",
583
+ "atan",
584
+ "atan2",
585
+ "atanh",
586
+ "cos",
587
+ "cosh",
588
+ "sin",
589
+ "sinh",
590
+ "tan",
591
+ "tanh",
592
+ "date",
593
+ "time",
594
+ "datetime",
595
+ "julianday",
596
+ "unixepoch",
597
+ "strftime",
598
+ "timediff",
599
+ "typeof",
600
+ "type",
601
+ "last_insert_rowid",
602
+ "changes",
603
+ "total_changes",
604
+ "sqlite_version",
605
+ "json",
606
+ "json_array",
607
+ "json_array_length",
608
+ "json_extract",
609
+ "json_insert",
610
+ "json_object",
611
+ "json_patch",
612
+ "json_remove",
613
+ "json_replace",
614
+ "json_set",
615
+ "json_type",
616
+ "json_valid",
617
+ "json_quote",
618
+ "json_group_array",
619
+ "json_group_object",
620
+ "json_each",
621
+ "json_tree",
622
+ "iif",
623
+ "ifnull",
624
+ "likely",
625
+ "unlikely",
626
+ "max",
627
+ "min",
628
+ "nullif",
629
+ "randomblob",
630
+ "row_number",
631
+ "rank",
632
+ "dense_rank",
633
+ "percent_rank",
634
+ "cume_dist",
635
+ "ntile",
636
+ "lag",
637
+ "lead",
638
+ "first_value",
639
+ "last_value",
640
+ "nth_value"
641
+ ];
642
+ function checkFunctions(ast, db, allowExtraFunctions) {
643
+ const bad = findDisallowedFunction(ast, buildAllowedSet(db, allowExtraFunctions));
644
+ if (bad !== null) return Err(new SanitiseError(`Function '${bad}' is not allowed`));
645
+ return Ok(ast);
646
+ }
647
+ function buildSet(list) {
648
+ return new Set(list.map((f) => f.toLowerCase()));
649
+ }
650
+ const PG_SET = buildSet(POSTGRES_FUNCTIONS);
651
+ const SQLITE_SET = buildSet(SQLITE_FUNCTIONS);
652
+ function getAllowedFunctions(db) {
653
+ return db === "sqlite" ? SQLITE_SET : PG_SET;
654
+ }
655
+ function buildAllowedSet(db, extra) {
656
+ const base = getAllowedFunctions(db);
657
+ if (extra.length === 0) return base;
658
+ const merged = new Set(base);
659
+ for (const f of extra) merged.add(f.toLowerCase());
660
+ return merged;
661
+ }
662
+ function findDisallowedFunction(ast, allowed) {
663
+ for (const col of ast.columns) if (col.expr.kind === "expr" && col.expr.expr) {
664
+ const bad = checkWhereValue(col.expr.expr, allowed);
665
+ if (bad) return bad;
666
+ }
667
+ if (ast.distinct && ast.distinct.type === "distinct_on") for (const val of ast.distinct.columns) {
668
+ const bad = checkWhereValue(val, allowed);
669
+ if (bad) return bad;
670
+ }
671
+ for (const join of ast.joins) if (join.condition && join.condition.type === "join_on") {
672
+ const bad = checkWhereExpr(join.condition.expr, allowed);
673
+ if (bad) return bad;
674
+ }
675
+ if (ast.where) {
676
+ const bad = checkWhereExpr(ast.where.inner, allowed);
677
+ if (bad) return bad;
678
+ }
679
+ if (ast.groupBy) for (const item of ast.groupBy.items) {
680
+ const bad = checkWhereValue(item, allowed);
681
+ if (bad) return bad;
682
+ }
683
+ if (ast.having) {
684
+ const bad = checkWhereExpr(ast.having.expr, allowed);
685
+ if (bad) return bad;
686
+ }
687
+ if (ast.orderBy) for (const item of ast.orderBy.items) {
688
+ const bad = checkWhereValue(item.expr, allowed);
689
+ if (bad) return bad;
690
+ }
691
+ return null;
692
+ }
693
+ function checkFuncCall(func, allowed) {
694
+ if (!allowed.has(func.name.toLowerCase())) return func.name;
695
+ if (func.args.kind === "args") for (const arg of func.args.args) {
696
+ const bad = checkWhereValue(arg, allowed);
697
+ if (bad) return bad;
698
+ }
699
+ return null;
700
+ }
701
+ function checkWhereValue(val, allowed) {
702
+ switch (val.type) {
703
+ case "where_value":
704
+ if (val.kind === "func_call") return checkFuncCall(val.func, allowed);
705
+ return null;
706
+ case "where_arith":
707
+ case "where_jsonb_op":
708
+ case "where_pgvector_op": {
709
+ const l = checkWhereValue(val.left, allowed);
710
+ if (l) return l;
711
+ return checkWhereValue(val.right, allowed);
712
+ }
713
+ case "where_unary_minus": return checkWhereValue(val.expr, allowed);
714
+ case "case_expr":
715
+ if (val.subject) {
716
+ const s = checkWhereValue(val.subject, allowed);
717
+ if (s) return s;
718
+ }
719
+ for (const w of val.whens) {
720
+ const c = checkWhereValue(w.condition, allowed);
721
+ if (c) return c;
722
+ const r = checkWhereValue(w.result, allowed);
723
+ if (r) return r;
724
+ }
725
+ if (val.else) return checkWhereValue(val.else, allowed);
726
+ return null;
727
+ case "cast_expr": return checkWhereValue(val.expr, allowed);
728
+ default: return null;
729
+ }
730
+ }
731
+ function checkWhereExpr(expr, allowed) {
732
+ switch (expr.type) {
733
+ case "where_and":
734
+ case "where_or": {
735
+ const l = checkWhereExpr(expr.left, allowed);
736
+ if (l) return l;
737
+ return checkWhereExpr(expr.right, allowed);
738
+ }
739
+ case "where_not": return checkWhereExpr(expr.expr, allowed);
740
+ case "where_comparison": {
741
+ const l = checkWhereValue(expr.left, allowed);
742
+ if (l) return l;
743
+ return checkWhereValue(expr.right, allowed);
744
+ }
745
+ case "where_is_null": return checkWhereValue(expr.expr, allowed);
746
+ case "where_is_bool": return checkWhereValue(expr.expr, allowed);
747
+ case "where_between": {
748
+ const e = checkWhereValue(expr.expr, allowed);
749
+ if (e) return e;
750
+ const lo = checkWhereValue(expr.low, allowed);
751
+ if (lo) return lo;
752
+ return checkWhereValue(expr.high, allowed);
753
+ }
754
+ case "where_in": {
755
+ const e = checkWhereValue(expr.expr, allowed);
756
+ if (e) return e;
757
+ for (const item of expr.list) {
758
+ const bad = checkWhereValue(item, allowed);
759
+ if (bad) return bad;
760
+ }
761
+ return null;
762
+ }
763
+ case "where_like": {
764
+ const e = checkWhereValue(expr.expr, allowed);
765
+ if (e) return e;
766
+ return checkWhereValue(expr.pattern, allowed);
767
+ }
768
+ case "where_ts_match": {
769
+ const l = checkWhereValue(expr.left, allowed);
770
+ if (l) return l;
771
+ return checkWhereValue(expr.right, allowed);
772
+ }
773
+ default: return null;
774
+ }
775
+ }
776
+ //#endregion
13
777
  //#region src/utils.ts
14
778
  function unreachable(x) {
15
779
  throw new Error(`Unhandled variant: ${JSON.stringify(x)}`);
16
780
  }
17
781
  //#endregion
18
782
  //#region src/output.ts
19
- function outputSql(ast) {
20
- return normaliseWhitespace(r(ast));
783
+ function outputSql(ast, pretty = false) {
784
+ const res = r(ast, pretty);
785
+ if (pretty) return res;
786
+ return normaliseWhitespace(res);
21
787
  }
22
788
  function normaliseWhitespace(s) {
23
789
  return s.replace(/\s+/g, " ").trim();
24
790
  }
25
- function r(node) {
791
+ function r(node, pretty = false) {
26
792
  if (node === null || node === void 0) return "";
27
793
  if (typeof node === "string") return node;
28
794
  if (typeof node === "number") return String(node);
29
795
  switch (node.type) {
30
- case "select": return handleSelect(node);
796
+ case "select": return handleSelect(node, pretty);
31
797
  case "select_from": return handleSelectFrom(node);
32
- case "join": return handleJoin(node);
798
+ case "join": return handleJoin(node, pretty);
33
799
  case "join_on":
34
- case "join_using": return handleJoinCondition(node);
800
+ case "join_using": return handleJoinCondition(node, pretty);
35
801
  case "table_ref": return handleTableRef(node);
36
802
  case "limit": return handleLimit(node);
37
803
  case "offset": return handleOffset(node);
@@ -59,7 +825,7 @@ function r(node) {
59
825
  case "case_expr": return handleCaseExpr(node);
60
826
  case "cast_expr": return handleCastExpr(node);
61
827
  case "where_value": return handleWhereValue(node);
62
- case "column": return handleColumn(node);
828
+ case "column": return handleColumn(node, pretty);
63
829
  case "column_expr": return handleColumnExpr(node);
64
830
  case "column_ref": return handleColumnRef(node);
65
831
  case "alias": return handleAlias(node);
@@ -67,23 +833,37 @@ function r(node) {
67
833
  default: return unreachable(node);
68
834
  }
69
835
  }
70
- function mapR(t) {
71
- return t.map((n) => r(n).trim()).join(", ");
836
+ function rMap(t, pretty = false, sep = ", ") {
837
+ return t.map((n) => r(n, pretty).trimEnd()).join(sep);
72
838
  }
73
- function handleSelect(node) {
74
- const joins = node.joins.map((j) => r(j)).join(" ");
75
- return `SELECT ${node.distinct ? `${r(node.distinct)} ` : ""}${mapR(node.columns)} ${r(node.from)} ${joins} ${r(node.where)} ${r(node.groupBy)} ${r(node.having)} ${r(node.orderBy)} ${r(node.limit)} ${r(node.offset)}`;
839
+ function handleSelect(node, pretty) {
840
+ const distinctStr = node.distinct ? `${r(node.distinct)} ` : "";
841
+ const colSep = pretty ? ",\n" : ", ";
842
+ const joinSep = pretty ? "\n" : " ";
843
+ const joiner = pretty ? "\n" : " ";
844
+ return [
845
+ "SELECT",
846
+ `${distinctStr}${rMap(node.columns, pretty, colSep)}`,
847
+ r(node.from),
848
+ rMap(node.joins, pretty, joinSep),
849
+ r(node.where),
850
+ r(node.groupBy),
851
+ r(node.having),
852
+ r(node.orderBy),
853
+ r(node.limit),
854
+ r(node.offset)
855
+ ].join(joiner);
76
856
  }
77
857
  function handleDistinct(_node) {
78
858
  return "DISTINCT";
79
859
  }
80
860
  function handleDistinctOn(node) {
81
- return `DISTINCT ON (${mapR(node.columns)})`;
861
+ return `DISTINCT ON (${rMap(node.columns)})`;
82
862
  }
83
863
  function handleSelectFrom(node) {
84
864
  return `FROM ${r(node.table)}`;
85
865
  }
86
- function handleJoin(node) {
866
+ function handleJoin(node, pretty) {
87
867
  const typeStr = {
88
868
  inner: "INNER JOIN",
89
869
  inner_outer: "INNER OUTER JOIN",
@@ -96,12 +876,13 @@ function handleJoin(node) {
96
876
  cross: "CROSS JOIN",
97
877
  natural: "NATURAL JOIN"
98
878
  };
99
- const cond = node.condition ? ` ${handleJoinCondition(node.condition)}` : "";
879
+ const cond = node.condition ? `${handleJoinCondition(node.condition, pretty)}` : "";
100
880
  return `${typeStr[node.joinType]} ${r(node.table)}${cond}`;
101
881
  }
102
- function handleJoinCondition(node) {
103
- if (node.type === "join_on") return `ON ${r(node.expr)}`;
104
- return `USING (${node.columns.join(", ")})`;
882
+ function handleJoinCondition(node, pretty) {
883
+ const prefix = pretty ? "\n " : " ";
884
+ if (node.type === "join_on") return `${prefix}ON ${r(node.expr)}`;
885
+ return `${prefix}USING (${node.columns.join(", ")})`;
105
886
  }
106
887
  function handleLimit(node) {
107
888
  return `LIMIT ${node.value}`;
@@ -110,13 +891,13 @@ function handleOffset(node) {
110
891
  return `OFFSET ${node.value}`;
111
892
  }
112
893
  function handleGroupBy(node) {
113
- return `GROUP BY ${mapR(node.items)}`;
894
+ return `GROUP BY ${rMap(node.items)}`;
114
895
  }
115
896
  function handleHaving(node) {
116
897
  return `HAVING ${r(node.expr)}`;
117
898
  }
118
899
  function handleOrderBy(node) {
119
- return `ORDER BY ${mapR(node.items)}`;
900
+ return `ORDER BY ${rMap(node.items)}`;
120
901
  }
121
902
  function handleOrderByItem(node) {
122
903
  const dir = node.direction ? ` ${node.direction.toUpperCase()}` : "";
@@ -151,7 +932,7 @@ function handleWhereBetween(node) {
151
932
  return `${r(node.expr)}${node.not ? " NOT" : ""} BETWEEN ${r(node.low)} AND ${r(node.high)}`;
152
933
  }
153
934
  function handleWhereIn(node) {
154
- return `${r(node.expr)}${node.not ? " NOT" : ""} IN (${mapR(node.list)})`;
935
+ return `${r(node.expr)}${node.not ? " NOT" : ""} IN (${rMap(node.list)})`;
155
936
  }
156
937
  function handleCaseExpr(node) {
157
938
  return `CASE${node.subject ? ` ${r(node.subject)}` : ""} ${node.whens.map((w) => `WHEN ${r(w.condition)} THEN ${r(w.result)}`).join(" ")}${node.else ? ` ELSE ${r(node.else)}` : ""} END`;
@@ -196,8 +977,8 @@ function handleWhereValue(node) {
196
977
  default: return unreachable(node);
197
978
  }
198
979
  }
199
- function handleColumn(node) {
200
- return `${r(node.expr)} ${r(node.alias)}`;
980
+ function handleColumn(node, pretty) {
981
+ return `${pretty ? " " : ""}${r(node.expr)} ${r(node.alias)}`;
201
982
  }
202
983
  function handleAlias(node) {
203
984
  return `AS ${node.name}`;
@@ -209,7 +990,7 @@ function handleColumnRef(node) {
209
990
  function handleFuncCall(node) {
210
991
  const { name, args } = node;
211
992
  if (args.kind === "wildcard") return `${name}(*)`;
212
- return `${name}(${args.distinct ? "DISTINCT " : ""}${mapR(args.args)})`;
993
+ return `${name}(${args.distinct ? "DISTINCT " : ""}${rMap(args.args)})`;
213
994
  }
214
995
  function handleColumnExpr(node) {
215
996
  switch (node.kind) {
@@ -220,34 +1001,9 @@ function handleColumnExpr(node) {
220
1001
  }
221
1002
  }
222
1003
  //#endregion
223
- //#region src/result.ts
224
- function Err(error) {
225
- return {
226
- ok: false,
227
- error,
228
- unwrap() {
229
- throw new Error(String(error));
230
- }
231
- };
232
- }
233
- function Ok(data) {
234
- return {
235
- ok: true,
236
- data,
237
- unwrap() {
238
- return data;
239
- }
240
- };
241
- }
242
- function returnOrThrow(result, throws) {
243
- if (!throws) return result;
244
- if (result.ok) return result.data;
245
- throw result.error;
246
- }
247
- //#endregion
248
1004
  //#region src/guard.ts
249
1005
  const DEFAULT_LIMIT = 1e4;
250
- function applyGuards(ast, guards, limit = DEFAULT_LIMIT) {
1006
+ function applyGuards(ast, guards, limit) {
251
1007
  const ast2 = addWhereGuard(ast, guards);
252
1008
  if (!ast2.ok) return ast2;
253
1009
  return addLimitGuard(ast2.data, limit);
@@ -6448,37 +7204,48 @@ function parseSql(expr) {
6448
7204
  }
6449
7205
  //#endregion
6450
7206
  //#region src/index.ts
6451
- function agentSql(sql, column, value, { schema, limit } = {}) {
7207
+ function agentSql(sql, column, value, { schema, limit = DEFAULT_LIMIT, pretty = false, db = DEFAULT_DB, allowExtraFunctions = [] } = {}) {
6452
7208
  return privateAgentSql(sql, {
6453
7209
  guards: { [column]: value },
6454
7210
  schema,
6455
7211
  limit,
7212
+ pretty,
7213
+ db,
7214
+ allowExtraFunctions,
6456
7215
  throws: true
6457
7216
  });
6458
7217
  }
6459
- function createAgentSql(schema, guards, { limit, throws = true } = {}) {
7218
+ function createAgentSql(schema, guards, { limit = DEFAULT_LIMIT, pretty = false, db = DEFAULT_DB, allowExtraFunctions = [], throws = true } = {}) {
6460
7219
  return (expr) => throws ? privateAgentSql(expr, {
6461
7220
  guards,
6462
7221
  schema,
6463
7222
  limit,
7223
+ pretty,
7224
+ db,
7225
+ allowExtraFunctions,
6464
7226
  throws
6465
7227
  }) : privateAgentSql(expr, {
6466
7228
  guards,
6467
7229
  schema,
6468
7230
  limit,
7231
+ pretty,
7232
+ db,
7233
+ allowExtraFunctions,
6469
7234
  throws
6470
7235
  });
6471
7236
  }
6472
- function privateAgentSql(sql, { guards: guardsRaw, schema, limit, throws }) {
7237
+ function privateAgentSql(sql, { guards: guardsRaw, schema, limit, pretty, db, allowExtraFunctions, throws }) {
6473
7238
  const guards = resolveGuards(guardsRaw);
6474
7239
  if (!guards.ok) throw guards.error;
6475
7240
  const ast = parseSql(sql);
6476
7241
  if (!ast.ok) return returnOrThrow(ast, throws);
6477
7242
  const ast2 = checkJoins(ast.data, schema);
6478
7243
  if (!ast2.ok) return returnOrThrow(ast2, throws);
6479
- const san = applyGuards(ast2.data, guards.data, limit);
7244
+ const ast3 = checkFunctions(ast2.data, db, allowExtraFunctions);
7245
+ if (!ast3.ok) return returnOrThrow(ast3, throws);
7246
+ const san = applyGuards(ast3.data, guards.data, limit);
6480
7247
  if (!san.ok) return returnOrThrow(san, throws);
6481
- const res = outputSql(san.data);
7248
+ const res = outputSql(san.data, pretty);
6482
7249
  if (throws) return res;
6483
7250
  return Ok(res);
6484
7251
  }
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "agent-sql",
3
- "version": "0.2.1",
3
+ "version": "0.2.3",
4
4
  "description": "A starter for creating a TypeScript package.",
5
5
  "keywords": [
6
6
  "agent",
@@ -69,6 +69,5 @@
69
69
  "vite": "npm:@voidzero-dev/vite-plus-core@latest",
70
70
  "vitest": "npm:@voidzero-dev/vite-plus-test@latest"
71
71
  }
72
- },
73
- "logo": "https://cdn.jsdelivr.net/gh/carderne/agent-sql@main/docs/logo.png"
72
+ }
74
73
  }