agent-sql 0.2.1 → 0.2.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -12,6 +12,8 @@ It ensures that the needed tenant table is somewhere in the query,
12
12
  and adds a `WHERE` clause ensuring that only values from the supplied ID are returned.
13
13
  Then it checks that the tables and `JOIN`s follow the schema, preventing sneaky joins.
14
14
 
15
+ Function calls are also checked against a per-dialect allowlist: choose the dialect with `db` and add names with `allowExtraFunctions`.
16
+
15
17
  Finally, we throw in a `LIMIT` clause (configurable) to prevent accidental LLM denial-of-service.
16
18
 
17
19
  Apparently this is how [Trigger.dev does it](https://x.com/mattaitken/status/2033928542975639785).
package/dist/index.d.mts CHANGED
@@ -1,5 +1,8 @@
1
1
  import { i as SelectStatement, n as defineSchema, r as Result, t as Schema } from "./joins-Cu_0yAgN.mjs";
2
2
 
3
//#region src/functions.d.ts
/**
 * SQL dialect whose function allowlist is enforced.
 * "pglite" shares the Postgres allowlist; "sqlite" has its own.
 */
type DbType = "sqlite" | "postgres" | "pglite";
//#endregion
3
6
  //#region src/guard.d.ts
4
7
  type GuardVal = string | number;
5
8
  interface GuardCol {
@@ -23,18 +26,26 @@ declare function parseSql(expr: string): Result<SelectStatement>;
23
26
  //#region src/index.d.ts
/**
 * Sanitise `sql` for a single guard `column = value` and return the
 * resulting SQL string; throws on failure.
 *
 * @param options.schema - schema used to validate tables and joins.
 * @param options.limit - LIMIT applied to the query.
 * @param options.db - dialect whose function allowlist is enforced (implementation default: "postgres").
 * @param options.allowExtraFunctions - extra function names to allow (matched case-insensitively).
 */
declare function agentSql<S extends string>(sql: string, column: S & OneOrTwoDots<S>, value: GuardVal, options?: {
  schema?: Schema;
  limit?: number;
  db?: DbType;
  allowExtraFunctions?: string[];
}): string;
/** Options shared by both `createAgentSql` overloads. */
interface CreateAgentSqlOptions {
  limit?: number;
  db?: DbType;
  allowExtraFunctions?: string[];
}
/** Non-throwing variant: the returned function yields a `Result<string>`. */
declare function createAgentSql<T extends Schema, S extends SchemaGuardKeys<T>>(schema: T, guards: Record<S, GuardVal>, opts: CreateAgentSqlOptions & {
  throws: false;
}): (expr: string) => Result<string>;
/** Throwing variant (the default): the returned function yields the SQL string or throws. */
declare function createAgentSql<T extends Schema, S extends SchemaGuardKeys<T>>(schema: T, guards: Record<S, GuardVal>, opts?: CreateAgentSqlOptions & {
  throws?: true;
}): (expr: string) => string;
//#endregion
export type { DbType };
export { agentSql, createAgentSql, defineSchema, outputSql, parseSql, applyGuards as sanitiseSql };
package/dist/index.mjs CHANGED
@@ -10,6 +10,769 @@ var AgentSqlError = class extends Error {
10
10
  type = "agent_sql_error";
11
11
  };
12
12
  //#endregion
13
+ //#region src/result.ts
14
+ function Err(error) {
15
+ return {
16
+ ok: false,
17
+ error,
18
+ unwrap() {
19
+ throw new Error(String(error));
20
+ }
21
+ };
22
+ }
23
+ function Ok(data) {
24
+ return {
25
+ ok: true,
26
+ data,
27
+ unwrap() {
28
+ return data;
29
+ }
30
+ };
31
+ }
32
+ function returnOrThrow(result, throws) {
33
+ if (!throws) return result;
34
+ if (result.ok) return result.data;
35
+ throw result.error;
36
+ }
37
+ //#endregion
38
//#region src/functions.ts
/**
 * Function names shared by both dialect allowlists (Postgres/PGlite and SQLite).
 * Names are lower-cased when the lookup sets are built, and calls are
 * lower-cased before lookup, so matching is case-insensitive.
 */
const COMMON_FUNCTIONS = [
  // aggregates
  "count", "sum", "avg", "min", "max",
  // null handling / comparison
  "coalesce", "nullif", "greatest", "least",
  // strings
  "lower", "upper", "trim", "ltrim", "rtrim", "length", "replace", "substr", "substring",
  "concat", "reverse", "repeat", "position", "left", "right", "lpad", "rpad", "translate",
  "char_length", "character_length", "octet_length", "overlay", "ascii", "chr", "starts_with",
  // numeric
  "abs", "ceil", "ceiling", "floor", "round", "trunc", "truncate", "mod", "power", "sqrt",
  "cbrt", "log", "ln", "exp", "sign", "random", "pi", "degrees", "radians", "div", "gcd", "lcm",
  // type conversion
  "cast"
];
/** pgvector extension functions; only merged into the Postgres/PGlite allowlist. */
const PGVECTOR_FUNCTIONS = [
  // distance metrics
  "l2_distance", "inner_product", "cosine_distance", "l1_distance",
  // vector utilities
  "vector_dims", "vector_norm", "l2_normalize", "binary_quantize", "subvector",
  // aggregate
  "vector_avg"
];
/**
 * PostGIS (plus a few raster) functions; only merged into the Postgres/PGlite allowlist.
 *
 * Only read-only / computational functions belong here. The PostGIS
 * management functions AddGeometryColumn, DropGeometryColumn,
 * UpdateGeometrySRID and Populate_Geometry_Columns alter table definitions.
 * They are deliberately excluded, so an LLM-written
 * `SELECT dropgeometrycolumn(...)` cannot pass the allowlist. Callers that
 * really need them can opt in through `allowExtraFunctions`.
 * Each name appears once.
 */
const POSTGIS_FUNCTIONS = [
  // constructors
  "st_geomfromtext", "st_geomfromwkb", "st_geomfromewkt", "st_geomfromewkb", "st_geomfromgeojson", "st_geogfromtext", "st_geogfromwkb",
  "st_makepoint", "st_makepointm", "st_makeenvelope", "st_makeline", "st_makepolygon", "st_makebox2d",
  "st_point", "st_pointz", "st_pointm", "st_pointzm", "st_polygon", "st_linefrommultipoint",
  "st_tileenvelope", "st_hexagongrid", "st_squaregrid", "st_letters",
  "st_collect", "st_linemerge", "st_buildarea", "st_polygonize", "st_unaryunion",
  // output formats
  "st_astext", "st_asewkt", "st_asbinary", "st_asewkb", "st_asgeojson", "st_asgml", "st_askml",
  "st_assvg", "st_astwkb", "st_asmvtgeom", "st_asmvt", "st_asencodedpolyline", "st_ashexewkb", "st_aslatlontext",
  // accessors
  "st_x", "st_y", "st_z", "st_m", "st_geometrytype", "st_srid", "st_dimension", "st_coorddim",
  "st_numgeometries", "st_numpoints", "st_npoints", "st_nrings", "st_numinteriorrings", "st_numinteriorring",
  "st_exteriorring", "st_interiorringn", "st_geometryn", "st_pointn", "st_startpoint", "st_endpoint",
  "st_envelope", "st_boundingdiagonal", "st_xmin", "st_xmax", "st_ymin", "st_ymax", "st_zmin", "st_zmax",
  // predicates
  "st_isempty", "st_isclosed", "st_isring", "st_issimple", "st_isvalid", "st_isvalidreason", "st_isvaliddetail",
  "st_hasm", "st_hasz", "st_ismeasured",
  // spatial relationships
  "st_intersects", "st_disjoint", "st_contains", "st_within", "st_covers", "st_coveredby", "st_crosses", "st_overlaps",
  "st_touches", "st_equals", "st_relate", "st_containsproperly", "st_dwithin", "st_3dintersects", "st_3ddwithin", "st_orderingequals",
  // measurement
  "st_distance", "st_3ddistance", "st_maxdistance", "st_area", "st_length", "st_length2d", "st_3dlength",
  "st_perimeter", "st_azimuth", "st_angle", "st_hausdorffdistance", "st_frechetdistance", "st_longestline", "st_shortestline",
  // geometry editors (return new values; do not modify stored data)
  "st_transform", "st_setsrid", "st_force2d", "st_force3d", "st_force3dz", "st_force3dm", "st_force4d",
  "st_forcecollection", "st_forcepolygoncw", "st_forcepolygonccw", "st_forcecurve", "st_forcesfs", "st_multi",
  "st_normalize", "st_flipcoordinates", "st_translate", "st_scale", "st_rotate", "st_rotatex", "st_rotatey", "st_rotatez",
  "st_affine", "st_transscale", "st_snap", "st_snaptogrid", "st_segmentize", "st_simplify", "st_simplifypreservetopology",
  "st_simplifyvw", "st_chaikinsmoothing", "st_seteffectivearea", "st_filterbym", "st_locatebetween", "st_locatebetweenelevations", "st_offsetcurve",
  // overlay / processing
  "st_intersection", "st_union", "st_difference", "st_symdifference", "st_buffer", "st_convexhull", "st_concavehull",
  "st_minimumboundingcircle", "st_minimumboundingradius", "st_orientedenvelope", "st_centroid", "st_pointonsurface", "st_geometricmedian",
  "st_voronoipolygons", "st_voronoilines", "st_delaunaytriangles", "st_subdivide", "st_split", "st_sharedpaths", "st_node",
  // clustering
  "st_clusterdbscan", "st_clusterkmeans", "st_clusterintersecting", "st_clusterwithin",
  // validity / linear referencing
  "st_makevalid", "st_lineinterpolatepoint", "st_lineinterpolatepoints", "st_linelocatepoint", "st_linesubstring",
  "st_addmeasure", "st_closestpoint", "st_linefromencodedpolyline",
  // bounding boxes / misc
  "box2d", "box3d", "st_expand", "st_estimatedextent", "st_extent", "st_3dextent", "st_memsize",
  "st_distancesphere", "st_distancespheroid", "st_project", "st_memunion",
  // raster / dump
  "st_nband", "st_numbands", "st_summary", "st_dump", "st_dumppoints", "st_dumprings",
  // version / metadata (read-only)
  "postgis_version", "postgis_full_version", "postgis_geos_version", "postgis_proj_version",
  "postgis_lib_version", "postgis_scripts_installed", "postgis_scripts_released", "postgis_type_name",
  "find_srid",
  // type casts written as calls
  "geography", "geometry"
];
/**
 * Full Postgres/PGlite allowlist: COMMON_FUNCTIONS, Postgres built-ins,
 * pgvector and PostGIS. A few names repeat entries from COMMON_FUNCTIONS
 * (e.g. "lower"/"upper" in the range group). That is harmless because
 * buildSet() collapses the list into a Set.
 */
const POSTGRES_FUNCTIONS = [
  ...COMMON_FUNCTIONS,
  // aggregates
  "array_agg", "string_agg", "json_agg", "jsonb_agg", "json_object_agg", "jsonb_object_agg",
  "bool_and", "bool_or", "every", "bit_and", "bit_or", "bit_xor",
  // statistical aggregates
  "corr", "covar_pop", "covar_samp", "regr_avgx", "regr_avgy", "regr_count", "regr_intercept",
  "regr_r2", "regr_slope", "regr_sxx", "regr_sxy", "regr_syy",
  "stddev", "stddev_pop", "stddev_samp", "variance", "var_pop", "var_samp",
  // ordered-set aggregates
  "percentile_cont", "percentile_disc", "mode",
  // window functions
  "rank", "dense_rank", "percent_rank", "cume_dist", "ntile", "lag", "lead",
  "first_value", "last_value", "nth_value", "row_number",
  // strings / encoding / hashing
  "initcap", "strpos", "encode", "decode", "md5", "sha256", "sha224", "sha384", "sha512",
  "format", "concat_ws", "regexp_replace", "regexp_match", "regexp_matches",
  "regexp_split_to_array", "regexp_split_to_table", "split_part", "btrim", "bit_length",
  "quote_ident", "quote_literal", "quote_nullable", "to_hex", "convert", "convert_from", "convert_to",
  "string_to_array", "array_to_string",
  // date / time
  "now", "current_timestamp", "current_date", "current_time", "localtime", "localtimestamp",
  "clock_timestamp", "statement_timestamp", "transaction_timestamp", "timeofday",
  "date_trunc", "date_part", "extract", "age", "to_char", "to_date", "to_timestamp", "to_number",
  "make_date", "make_time", "make_timestamp", "make_timestamptz", "make_interval",
  "justify_days", "justify_hours", "justify_interval", "isfinite",
  // json / jsonb
  "json_extract_path", "json_extract_path_text", "jsonb_extract_path", "jsonb_extract_path_text",
  "json_array_length", "jsonb_array_length", "json_typeof", "jsonb_typeof",
  "json_build_object", "jsonb_build_object", "json_build_array", "jsonb_build_array",
  "to_json", "to_jsonb", "row_to_json",
  "json_each", "json_each_text", "jsonb_each", "jsonb_each_text", "json_object_keys", "jsonb_object_keys",
  "json_populate_record", "jsonb_populate_record", "json_populate_recordset", "jsonb_populate_recordset",
  "json_to_record", "jsonb_to_record", "json_to_recordset", "jsonb_to_recordset",
  "json_array_elements", "jsonb_array_elements", "json_array_elements_text", "jsonb_array_elements_text",
  "jsonb_set", "jsonb_set_lax", "jsonb_insert",
  "jsonb_path_query", "jsonb_path_query_array", "jsonb_path_query_first", "jsonb_path_exists", "jsonb_path_match",
  "jsonb_strip_nulls", "jsonb_pretty", "json_strip_nulls",
  // full-text search
  "to_tsvector", "to_tsquery", "plainto_tsquery", "phraseto_tsquery", "websearch_to_tsquery",
  "ts_rank", "ts_rank_cd", "ts_headline", "tsvector_to_array", "array_to_tsvector",
  "numnode", "querytree", "ts_rewrite", "setweight", "strip",
  "ts_debug", "ts_lexize", "ts_parse", "ts_token_type", "get_current_ts_config",
  // arrays
  "array_append", "array_cat", "array_dims", "array_fill", "array_length", "array_lower", "array_ndims",
  "array_position", "array_positions", "array_prepend", "array_remove", "array_replace", "array_upper",
  "cardinality", "unnest", "generate_subscripts",
  // ranges
  "lower", "upper", "isempty", "lower_inc", "lower_inf", "upper_inc", "upper_inf", "range_merge",
  // set-returning
  "generate_series",
  // session / catalog information
  // NOTE(review): these reveal server and session metadata (settings, client
  // address, privileges) to whoever writes the query. Per the Postgres docs,
  // txid_current() also assigns a transaction ID if none exists yet. Confirm
  // they belong in an allowlist for LLM-written SQL.
  "pg_typeof", "current_setting", "current_database", "current_schema", "current_schemas",
  "current_user", "session_user", "inet_client_addr", "inet_client_port", "version",
  "obj_description", "col_description", "shobj_description",
  "has_table_privilege", "has_column_privilege", "has_schema_privilege",
  "txid_current", "txid_current_snapshot",
  // geometric types
  "area", "center", "diameter", "height", "width", "isclosed", "isopen", "npoints", "pclose", "popen", "radius",
  // network address types
  "abbrev", "broadcast", "family", "host", "hostmask", "masklen", "netmask", "network",
  "set_masklen", "inet_merge", "inet_same_family",
  // uuid
  "gen_random_uuid", "uuid_generate_v1", "uuid_generate_v4",
  ...PGVECTOR_FUNCTIONS,
  ...POSTGIS_FUNCTIONS
];
/**
 * SQLite allowlist: COMMON_FUNCTIONS plus SQLite built-ins. Some names
 * ("ltrim", "rtrim", "trim", "max", "min", "nullif") repeat entries from
 * COMMON_FUNCTIONS. buildSet() de-duplicates them.
 */
const SQLITE_FUNCTIONS = [
  ...COMMON_FUNCTIONS,
  // aggregates
  "group_concat", "total",
  // strings / blobs
  "char", "format", "glob", "hex", "unhex", "instr", "like", "ltrim", "rtrim", "trim",
  "printf", "quote", "soundex", "unicode", "zeroblob",
  // math
  "acos", "acosh", "asin", "asinh", "atan", "atan2", "atanh",
  "cos", "cosh", "sin", "sinh", "tan", "tanh",
  // date / time
  "date", "time", "datetime", "julianday", "unixepoch", "strftime", "timediff",
  // type / connection info
  "typeof", "type", "last_insert_rowid", "changes", "total_changes", "sqlite_version",
  // json
  "json", "json_array", "json_array_length", "json_extract", "json_insert", "json_object",
  "json_patch", "json_remove", "json_replace", "json_set", "json_type", "json_valid", "json_quote",
  "json_group_array", "json_group_object", "json_each", "json_tree",
  // conditionals / misc
  "iif", "ifnull", "likely", "unlikely", "max", "min", "nullif", "randomblob",
  // window functions
  "row_number", "rank", "dense_rank", "percent_rank", "cume_dist", "ntile",
  "lag", "lead", "first_value", "last_value", "nth_value"
];
+ function checkFunctions(ast, db, allowExtraFunctions) {
642
+ const bad = findDisallowedFunction(ast, buildAllowedSet(db, allowExtraFunctions));
643
+ if (bad !== null) return Err(new SanitiseError(`Function '${bad}' is not allowed`));
644
+ return Ok(ast);
645
+ }
646
/** Lower-case every name in `list` into a Set, for case-insensitive lookup. */
function buildSet(list) {
  const names = new Set();
  for (const name of list) names.add(name.toLowerCase());
  return names;
}
// Shared lookup sets, built once at module load.
const PG_SET = buildSet(POSTGRES_FUNCTIONS);
const SQLITE_SET = buildSet(SQLITE_FUNCTIONS);
/**
 * Base allowlist for a dialect. "sqlite" gets the SQLite set; every other
 * value ("postgres", "pglite", or anything else) gets the Postgres set.
 */
function getAllowedFunctions(db) {
  if (db === "sqlite") return SQLITE_SET;
  return PG_SET;
}
+ function buildAllowedSet(db, extra) {
655
+ const base = getAllowedFunctions(db);
656
+ if (extra.length === 0) return base;
657
+ const merged = new Set(base);
658
+ for (const f of extra) merged.add(f.toLowerCase());
659
+ return merged;
660
+ }
661
+ function findDisallowedFunction(ast, allowed) {
662
+ for (const col of ast.columns) if (col.expr.kind === "expr" && col.expr.expr) {
663
+ const bad = checkWhereValue(col.expr.expr, allowed);
664
+ if (bad) return bad;
665
+ }
666
+ if (ast.distinct && ast.distinct.type === "distinct_on") for (const val of ast.distinct.columns) {
667
+ const bad = checkWhereValue(val, allowed);
668
+ if (bad) return bad;
669
+ }
670
+ for (const join of ast.joins) if (join.condition && join.condition.type === "join_on") {
671
+ const bad = checkWhereExpr(join.condition.expr, allowed);
672
+ if (bad) return bad;
673
+ }
674
+ if (ast.where) {
675
+ const bad = checkWhereExpr(ast.where.inner, allowed);
676
+ if (bad) return bad;
677
+ }
678
+ if (ast.groupBy) for (const item of ast.groupBy.items) {
679
+ const bad = checkWhereValue(item, allowed);
680
+ if (bad) return bad;
681
+ }
682
+ if (ast.having) {
683
+ const bad = checkWhereExpr(ast.having.expr, allowed);
684
+ if (bad) return bad;
685
+ }
686
+ if (ast.orderBy) for (const item of ast.orderBy.items) {
687
+ const bad = checkWhereValue(item.expr, allowed);
688
+ if (bad) return bad;
689
+ }
690
+ return null;
691
+ }
692
+ function checkFuncCall(func, allowed) {
693
+ if (!allowed.has(func.name.toLowerCase())) return func.name;
694
+ if (func.args.kind === "args") for (const arg of func.args.args) {
695
+ const bad = checkWhereValue(arg, allowed);
696
+ if (bad) return bad;
697
+ }
698
+ return null;
699
+ }
700
+ function checkWhereValue(val, allowed) {
701
+ switch (val.type) {
702
+ case "where_value":
703
+ if (val.kind === "func_call") return checkFuncCall(val.func, allowed);
704
+ return null;
705
+ case "where_arith":
706
+ case "where_jsonb_op":
707
+ case "where_pgvector_op": {
708
+ const l = checkWhereValue(val.left, allowed);
709
+ if (l) return l;
710
+ return checkWhereValue(val.right, allowed);
711
+ }
712
+ case "where_unary_minus": return checkWhereValue(val.expr, allowed);
713
+ case "case_expr":
714
+ if (val.subject) {
715
+ const s = checkWhereValue(val.subject, allowed);
716
+ if (s) return s;
717
+ }
718
+ for (const w of val.whens) {
719
+ const c = checkWhereValue(w.condition, allowed);
720
+ if (c) return c;
721
+ const r = checkWhereValue(w.result, allowed);
722
+ if (r) return r;
723
+ }
724
+ if (val.else) return checkWhereValue(val.else, allowed);
725
+ return null;
726
+ case "cast_expr": return checkWhereValue(val.expr, allowed);
727
+ default: return null;
728
+ }
729
+ }
730
+ function checkWhereExpr(expr, allowed) {
731
+ switch (expr.type) {
732
+ case "where_and":
733
+ case "where_or": {
734
+ const l = checkWhereExpr(expr.left, allowed);
735
+ if (l) return l;
736
+ return checkWhereExpr(expr.right, allowed);
737
+ }
738
+ case "where_not": return checkWhereExpr(expr.expr, allowed);
739
+ case "where_comparison": {
740
+ const l = checkWhereValue(expr.left, allowed);
741
+ if (l) return l;
742
+ return checkWhereValue(expr.right, allowed);
743
+ }
744
+ case "where_is_null": return checkWhereValue(expr.expr, allowed);
745
+ case "where_is_bool": return checkWhereValue(expr.expr, allowed);
746
+ case "where_between": {
747
+ const e = checkWhereValue(expr.expr, allowed);
748
+ if (e) return e;
749
+ const lo = checkWhereValue(expr.low, allowed);
750
+ if (lo) return lo;
751
+ return checkWhereValue(expr.high, allowed);
752
+ }
753
+ case "where_in": {
754
+ const e = checkWhereValue(expr.expr, allowed);
755
+ if (e) return e;
756
+ for (const item of expr.list) {
757
+ const bad = checkWhereValue(item, allowed);
758
+ if (bad) return bad;
759
+ }
760
+ return null;
761
+ }
762
+ case "where_like": {
763
+ const e = checkWhereValue(expr.expr, allowed);
764
+ if (e) return e;
765
+ return checkWhereValue(expr.pattern, allowed);
766
+ }
767
+ case "where_ts_match": {
768
+ const l = checkWhereValue(expr.left, allowed);
769
+ if (l) return l;
770
+ return checkWhereValue(expr.right, allowed);
771
+ }
772
+ default: return null;
773
+ }
774
+ }
775
+ //#endregion
13
776
  //#region src/utils.ts
14
777
  function unreachable(x) {
15
778
  throw new Error(`Unhandled variant: ${JSON.stringify(x)}`);
@@ -220,31 +983,6 @@ function handleColumnExpr(node) {
220
983
  }
221
984
  }
222
985
  //#endregion
223
- //#region src/result.ts
224
- function Err(error) {
225
- return {
226
- ok: false,
227
- error,
228
- unwrap() {
229
- throw new Error(String(error));
230
- }
231
- };
232
- }
233
- function Ok(data) {
234
- return {
235
- ok: true,
236
- data,
237
- unwrap() {
238
- return data;
239
- }
240
- };
241
- }
242
- function returnOrThrow(result, throws) {
243
- if (!throws) return result;
244
- if (result.ok) return result.data;
245
- throw result.error;
246
- }
247
- //#endregion
248
986
  //#region src/guard.ts
249
987
  const DEFAULT_LIMIT = 1e4;
250
988
  function applyGuards(ast, guards, limit = DEFAULT_LIMIT) {
@@ -6448,35 +7186,43 @@ function parseSql(expr) {
6448
7186
  }
6449
7187
  //#endregion
6450
7188
  //#region src/index.ts
6451
- function agentSql(sql, column, value, { schema, limit } = {}) {
7189
+ function agentSql(sql, column, value, { schema, limit, db = "postgres", allowExtraFunctions = [] } = {}) {
6452
7190
  return privateAgentSql(sql, {
6453
7191
  guards: { [column]: value },
6454
7192
  schema,
6455
7193
  limit,
7194
+ db,
7195
+ allowExtraFunctions,
6456
7196
  throws: true
6457
7197
  });
6458
7198
  }
6459
- function createAgentSql(schema, guards, { limit, throws = true } = {}) {
7199
+ function createAgentSql(schema, guards, { limit, throws = true, db = "postgres", allowExtraFunctions = [] } = {}) {
6460
7200
  return (expr) => throws ? privateAgentSql(expr, {
6461
7201
  guards,
6462
7202
  schema,
6463
7203
  limit,
7204
+ db,
7205
+ allowExtraFunctions,
6464
7206
  throws
6465
7207
  }) : privateAgentSql(expr, {
6466
7208
  guards,
6467
7209
  schema,
6468
7210
  limit,
7211
+ db,
7212
+ allowExtraFunctions,
6469
7213
  throws
6470
7214
  });
6471
7215
  }
6472
- function privateAgentSql(sql, { guards: guardsRaw, schema, limit, throws }) {
7216
+ function privateAgentSql(sql, { guards: guardsRaw, schema, limit, db, allowExtraFunctions, throws }) {
6473
7217
  const guards = resolveGuards(guardsRaw);
6474
7218
  if (!guards.ok) throw guards.error;
6475
7219
  const ast = parseSql(sql);
6476
7220
  if (!ast.ok) return returnOrThrow(ast, throws);
6477
7221
  const ast2 = checkJoins(ast.data, schema);
6478
7222
  if (!ast2.ok) return returnOrThrow(ast2, throws);
6479
- const san = applyGuards(ast2.data, guards.data, limit);
7223
+ const ast3 = checkFunctions(ast2.data, db, allowExtraFunctions);
7224
+ if (!ast3.ok) return returnOrThrow(ast3, throws);
7225
+ const san = applyGuards(ast3.data, guards.data, limit);
6480
7226
  if (!san.ok) return returnOrThrow(san, throws);
6481
7227
  const res = outputSql(san.data);
6482
7228
  if (throws) return res;
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "agent-sql",
3
- "version": "0.2.1",
3
+ "version": "0.2.2",
4
4
  "description": "Guard LLM-generated SQL with tenant filters, schema-checked joins, function allowlists and row limits.",
5
5
  "keywords": [
6
6
  "agent",
@@ -69,6 +69,5 @@
69
69
  "vite": "npm:@voidzero-dev/vite-plus-core@latest",
70
70
  "vitest": "npm:@voidzero-dev/vite-plus-test@latest"
71
71
  }
72
- },
73
- "logo": "https://cdn.jsdelivr.net/gh/carderne/agent-sql@main/docs/logo.png"
72
+ }
74
73
  }