pg_query 2.0.3 → 2.1.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +77 -0
  3. data/README.md +12 -0
  4. data/Rakefile +5 -19
  5. data/ext/pg_query/extconf.rb +3 -1
  6. data/ext/pg_query/include/c.h +12 -0
  7. data/ext/pg_query/include/executor/executor.h +6 -0
  8. data/ext/pg_query/include/nodes/execnodes.h +9 -6
  9. data/ext/pg_query/include/nodes/pathnodes.h +1 -1
  10. data/ext/pg_query/include/optimizer/paths.h +8 -0
  11. data/ext/pg_query/include/pg_config.h +9 -6
  12. data/ext/pg_query/include/pg_config_manual.h +7 -0
  13. data/ext/pg_query/include/pg_query.h +2 -2
  14. data/ext/pg_query/include/pg_query_outfuncs_defs.c +1 -0
  15. data/ext/pg_query/include/pg_query_readfuncs_defs.c +1 -0
  16. data/ext/pg_query/include/protobuf/pg_query.pb-c.h +472 -467
  17. data/ext/pg_query/include/protobuf-c/protobuf-c.h +7 -3
  18. data/ext/pg_query/include/protobuf-c.h +7 -3
  19. data/ext/pg_query/include/utils/array.h +1 -0
  20. data/ext/pg_query/include/utils/lsyscache.h +1 -0
  21. data/ext/pg_query/include/utils/probes.h +57 -57
  22. data/ext/pg_query/pg_query.pb-c.c +502 -487
  23. data/ext/pg_query/pg_query_deparse.c +6 -0
  24. data/ext/pg_query/pg_query_fingerprint.c +119 -32
  25. data/ext/pg_query/pg_query_fingerprint.h +3 -1
  26. data/ext/pg_query/pg_query_normalize.c +222 -63
  27. data/ext/pg_query/pg_query_parse_plpgsql.c +21 -1
  28. data/ext/pg_query/pg_query_ruby.c +1 -1
  29. data/ext/pg_query/pg_query_ruby.sym +1 -0
  30. data/ext/pg_query/protobuf-c.c +34 -27
  31. data/ext/pg_query/src_backend_utils_mmgr_mcxt.c +36 -0
  32. data/ext/pg_query/src_common_hashfn.c +420 -0
  33. data/ext/pg_query/src_pl_plpgsql_src_pl_gram.c +1 -1
  34. data/lib/pg_query/filter_columns.rb +4 -4
  35. data/lib/pg_query/fingerprint.rb +1 -3
  36. data/lib/pg_query/parse.rb +111 -45
  37. data/lib/pg_query/pg_query_pb.rb +1385 -1383
  38. data/lib/pg_query/version.rb +1 -1
  39. data/lib/pg_query.rb +0 -1
  40. metadata +8 -8
  41. data/lib/pg_query/json_field_names.rb +0 -1402
@@ -2247,6 +2247,12 @@ static void deparseRangeVar(StringInfo str, RangeVar *range_var, DeparseNodeCont
2247
2247
  if (!range_var->inh && context != DEPARSE_NODE_CONTEXT_CREATE_TYPE && context != DEPARSE_NODE_CONTEXT_ALTER_TYPE)
2248
2248
  appendStringInfoString(str, "ONLY ");
2249
2249
 
2250
+ if (range_var->catalogname != NULL)
2251
+ {
2252
+ appendStringInfoString(str, quote_identifier(range_var->catalogname));
2253
+ appendStringInfoChar(str, '.');
2254
+ }
2255
+
2250
2256
  if (range_var->schemaname != NULL)
2251
2257
  {
2252
2258
  appendStringInfoString(str, quote_identifier(range_var->schemaname));
@@ -17,6 +17,8 @@
17
17
  #include "nodes/parsenodes.h"
18
18
  #include "nodes/value.h"
19
19
 
20
+ #include "common/hashfn.h"
21
+
20
22
  #include <unistd.h>
21
23
  #include <fcntl.h>
22
24
 
@@ -26,15 +28,41 @@ typedef struct FingerprintContext
26
28
  {
27
29
  XXH3_state_t *xxh_state;
28
30
 
31
+ struct listsort_cache_hash *listsort_cache;
32
+
29
33
  bool write_tokens;
30
34
  dlist_head tokens;
31
35
  } FingerprintContext;
32
36
 
33
- typedef struct FingerprintListContext
37
+ typedef struct FingerprintListsortItem
34
38
  {
35
39
  XXH64_hash_t hash;
36
40
  size_t list_pos;
37
- } FingerprintListContext;
41
+ } FingerprintListsortItem;
42
+
43
+ typedef struct FingerprintListsortItemCacheEntry
44
+ {
45
+ /* List node this cache entry is for */
46
+ uintptr_t node;
47
+
48
+ /* Hashes of all list items -- this is expensive to calculate */
49
+ FingerprintListsortItem **listsort_items;
50
+ size_t listsort_items_size;
51
+
52
+ /* hash entry status */
53
+ char status;
54
+ } FingerprintListsortItemCacheEntry;
55
+
56
+ #define SH_PREFIX listsort_cache
57
+ #define SH_ELEMENT_TYPE FingerprintListsortItemCacheEntry
58
+ #define SH_KEY_TYPE uintptr_t
59
+ #define SH_KEY node
60
+ #define SH_HASH_KEY(tb, key) hash_bytes((const unsigned char *) &key, sizeof(uintptr_t))
61
+ #define SH_EQUAL(tb, a, b) a == b
62
+ #define SH_SCOPE static inline
63
+ #define SH_DEFINE
64
+ #define SH_DECLARE
65
+ #include "lib/simplehash.h"
38
66
 
39
67
  typedef struct FingerprintToken
40
68
  {
@@ -43,7 +71,7 @@ typedef struct FingerprintToken
43
71
  } FingerprintToken;
44
72
 
45
73
  static void _fingerprintNode(FingerprintContext *ctx, const void *obj, const void *parent, char *parent_field_name, unsigned int depth);
46
- static void _fingerprintInitContext(FingerprintContext *ctx, bool write_tokens);
74
+ static void _fingerprintInitContext(FingerprintContext *ctx, FingerprintContext *parent, bool write_tokens);
47
75
  static void _fingerprintFreeContext(FingerprintContext *ctx);
48
76
 
49
77
  #define PG_QUERY_FINGERPRINT_VERSION 3
@@ -96,10 +124,10 @@ _fingerprintBitString(FingerprintContext *ctx, const Value *node)
96
124
  }
97
125
  }
98
126
 
99
- static int compareFingerprintListContext(const void *a, const void *b)
127
+ static int compareFingerprintListsortItem(const void *a, const void *b)
100
128
  {
101
- FingerprintListContext *ca = *(FingerprintListContext**) a;
102
- FingerprintListContext *cb = *(FingerprintListContext**) b;
129
+ FingerprintListsortItem *ca = *(FingerprintListsortItem**) a;
130
+ FingerprintListsortItem *cb = *(FingerprintListsortItem**) b;
103
131
  if (ca->hash > cb->hash)
104
132
  return 1;
105
133
  else if (ca->hash < cb->hash)
@@ -111,38 +139,69 @@ static void
111
139
  _fingerprintList(FingerprintContext *ctx, const List *node, const void *parent, char *field_name, unsigned int depth)
112
140
  {
113
141
  if (field_name != NULL && (strcmp(field_name, "fromClause") == 0 || strcmp(field_name, "targetList") == 0 ||
114
- strcmp(field_name, "cols") == 0 || strcmp(field_name, "rexpr") == 0 || strcmp(field_name, "valuesLists") == 0 ||
115
- strcmp(field_name, "args") == 0)) {
142
+ strcmp(field_name, "cols") == 0 || strcmp(field_name, "rexpr") == 0 || strcmp(field_name, "valuesLists") == 0 ||
143
+ strcmp(field_name, "args") == 0))
144
+ {
145
+ /*
146
+ * Check for cached values for the hashes of subnodes
147
+ *
148
+ * Note this cache is important so we avoid exponential runtime behavior,
149
+ * which would be the case if we fingerprinted each node twice, which
150
+ * then would also again have to fingerprint each of its subnodes twice,
151
+ * etc., leading to deep nodes to be fingerprinted many many times over.
152
+ *
153
+ * We have seen real-world problems with this logic here without
154
+ * a cache in place.
155
+ */
156
+ FingerprintListsortItem** listsort_items = NULL;
157
+ size_t listsort_items_size = 0;
158
+ FingerprintListsortItemCacheEntry *entry = listsort_cache_lookup(ctx->listsort_cache, (uintptr_t) node);
159
+ if (entry != NULL)
160
+ {
161
+ listsort_items = entry->listsort_items;
162
+ listsort_items_size = entry->listsort_items_size;
163
+ }
164
+ else
165
+ {
166
+ listsort_items = palloc0(node->length * sizeof(FingerprintListsortItem*));
167
+ listsort_items_size = 0;
168
+ ListCell *lc;
169
+ bool found;
116
170
 
117
- FingerprintListContext** listCtxArr = palloc0(node->length * sizeof(FingerprintListContext*));
118
- size_t listCtxCount = 0;
119
- const ListCell *lc;
171
+ foreach(lc, node)
172
+ {
173
+ FingerprintContext fctx;
174
+ FingerprintListsortItem* lctx = palloc0(sizeof(FingerprintListsortItem));
120
175
 
121
- foreach(lc, node)
122
- {
123
- FingerprintContext subCtx;
124
- FingerprintListContext* listCtx = palloc0(sizeof(FingerprintListContext));
176
+ _fingerprintInitContext(&fctx, ctx, false);
177
+ _fingerprintNode(&fctx, lfirst(lc), parent, field_name, depth + 1);
178
+ lctx->hash = XXH3_64bits_digest(fctx.xxh_state);
179
+ lctx->list_pos = listsort_items_size;
180
+ _fingerprintFreeContext(&fctx);
125
181
 
126
- _fingerprintInitContext(&subCtx, false);
127
- _fingerprintNode(&subCtx, lfirst(lc), parent, field_name, depth + 1);
128
- listCtx->hash = XXH3_64bits_digest(subCtx.xxh_state);
129
- listCtx->list_pos = listCtxCount;
130
- _fingerprintFreeContext(&subCtx);
182
+ listsort_items[listsort_items_size] = lctx;
183
+ listsort_items_size += 1;
184
+ }
131
185
 
132
- listCtxArr[listCtxCount] = listCtx;
133
- listCtxCount += 1;
134
- }
186
+ pg_qsort(listsort_items, listsort_items_size, sizeof(FingerprintListsortItem*), compareFingerprintListsortItem);
135
187
 
136
- pg_qsort(listCtxArr, listCtxCount, sizeof(FingerprintListContext*), compareFingerprintListContext);
188
+ FingerprintListsortItemCacheEntry *entry = listsort_cache_insert(ctx->listsort_cache, (uintptr_t) node, &found);
189
+ Assert(!found);
137
190
 
138
- for (size_t i = 0; i < listCtxCount; i++)
191
+ entry->listsort_items = listsort_items;
192
+ entry->listsort_items_size = listsort_items_size;
193
+ }
194
+
195
+ for (size_t i = 0; i < listsort_items_size; i++)
139
196
  {
140
- if (i > 0 && listCtxArr[i - 1]->hash == listCtxArr[i]->hash)
197
+ if (i > 0 && listsort_items[i - 1]->hash == listsort_items[i]->hash)
141
198
  continue; // Ignore duplicates
142
199
 
143
- _fingerprintNode(ctx, lfirst(list_nth_cell(node, listCtxArr[i]->list_pos)), parent, field_name, depth + 1);
200
+ _fingerprintNode(ctx, lfirst(list_nth_cell(node, listsort_items[i]->list_pos)), parent, field_name, depth + 1);
144
201
  }
145
- } else {
202
+ }
203
+ else
204
+ {
146
205
  const ListCell *lc;
147
206
 
148
207
  foreach(lc, node)
@@ -155,15 +214,28 @@ _fingerprintList(FingerprintContext *ctx, const List *node, const void *parent,
155
214
  }
156
215
 
157
216
  static void
158
- _fingerprintInitContext(FingerprintContext *ctx, bool write_tokens) {
217
+ _fingerprintInitContext(FingerprintContext *ctx, FingerprintContext *parent, bool write_tokens)
218
+ {
159
219
  ctx->xxh_state = XXH3_createState();
160
220
  if (ctx->xxh_state == NULL) abort();
161
221
  if (XXH3_64bits_reset_withSeed(ctx->xxh_state, PG_QUERY_FINGERPRINT_VERSION) == XXH_ERROR) abort();
162
222
 
163
- if (write_tokens) {
223
+ if (parent != NULL)
224
+ {
225
+ ctx->listsort_cache = parent->listsort_cache;
226
+ }
227
+ else
228
+ {
229
+ ctx->listsort_cache = listsort_cache_create(CurrentMemoryContext, 128, NULL);
230
+ }
231
+
232
+ if (write_tokens)
233
+ {
164
234
  ctx->write_tokens = true;
165
235
  dlist_init(&ctx->tokens);
166
- } else {
236
+ }
237
+ else
238
+ {
167
239
  ctx->write_tokens = false;
168
240
  }
169
241
  }
@@ -219,6 +291,21 @@ _fingerprintNode(FingerprintContext *ctx, const void *obj, const void *parent, c
219
291
  }
220
292
  }
221
293
 
294
+ uint64_t pg_query_fingerprint_node(const void *node)
295
+ {
296
+ FingerprintContext ctx;
297
+ uint64 result;
298
+
299
+ _fingerprintInitContext(&ctx, NULL, false);
300
+ _fingerprintNode(&ctx, node, NULL, NULL, 0);
301
+
302
+ result = XXH3_64bits_digest(ctx.xxh_state);
303
+
304
+ _fingerprintFreeContext(&ctx);
305
+
306
+ return result;
307
+ }
308
+
222
309
  PgQueryFingerprintResult pg_query_fingerprint_with_opts(const char* input, bool printTokens)
223
310
  {
224
311
  MemoryContext ctx = NULL;
@@ -237,7 +324,7 @@ PgQueryFingerprintResult pg_query_fingerprint_with_opts(const char* input, bool
237
324
  FingerprintContext ctx;
238
325
  XXH64_canonical_t chash;
239
326
 
240
- _fingerprintInitContext(&ctx, printTokens);
327
+ _fingerprintInitContext(&ctx, NULL, printTokens);
241
328
 
242
329
  if (parsetree_and_error.tree != NULL) {
243
330
  _fingerprintNode(&ctx, parsetree_and_error.tree, NULL, NULL, 0);
@@ -3,6 +3,8 @@
3
3
 
4
4
  #include <stdbool.h>
5
5
 
6
- PgQueryFingerprintResult pg_query_fingerprint_with_opts(const char* input, bool printTokens);
6
+ extern PgQueryFingerprintResult pg_query_fingerprint_with_opts(const char* input, bool printTokens);
7
+
8
+ extern uint64_t pg_query_fingerprint_node(const void * node);
7
9
 
8
10
  #endif
@@ -1,5 +1,6 @@
1
1
  #include "pg_query.h"
2
2
  #include "pg_query_internal.h"
3
+ #include "pg_query_fingerprint.h"
3
4
 
4
5
  #include "parser/parser.h"
5
6
  #include "parser/scanner.h"
@@ -14,6 +15,7 @@ typedef struct pgssLocationLen
14
15
  {
15
16
  int location; /* start offset in query text */
16
17
  int length; /* length in bytes, or -1 to ignore */
18
+ int param_id; /* Param id to use - if negative prefix, need to abs(..) and add highest_extern_param_id */
17
19
  } pgssLocationLen;
18
20
 
19
21
  /*
@@ -30,14 +32,32 @@ typedef struct pgssConstLocations
30
32
  /* Current number of valid entries in clocations array */
31
33
  int clocations_count;
32
34
 
35
+ /* highest Param id we have assigned, not yet taking into account external param refs */
36
+ int highest_normalize_param_id;
37
+
33
38
  /* highest Param id we've seen, in order to start normalization correctly */
34
39
  int highest_extern_param_id;
35
40
 
36
41
  /* query text */
37
42
  const char * query;
38
43
  int query_len;
44
+
45
+ /* optional recording of assigned or discovered param refs, only active if param_refs is not NULL */
46
+ int *param_refs;
47
+ int param_refs_buf_size;
48
+ int param_refs_count;
39
49
  } pgssConstLocations;
40
50
 
51
+ /*
52
+ * Intermediate working state struct to remember param refs for individual target list elements
53
+ */
54
+ typedef struct FpAndParamRefs
55
+ {
56
+ uint64_t fp;
57
+ int* param_refs;
58
+ int param_refs_count;
59
+ } FpAndParamRefs;
60
+
41
61
  /*
42
62
  * comp_location: comparator for qsorting pgssLocationLen structs by location
43
63
  */
@@ -230,7 +250,8 @@ generate_normalized_query(pgssConstLocations *jstate, int query_loc, int* query_
230
250
  for (i = 0; i < jstate->clocations_count; i++)
231
251
  {
232
252
  int off, /* Offset from start for cur tok */
233
- tok_len; /* Length (in bytes) of that tok */
253
+ tok_len, /* Length (in bytes) of that tok */
254
+ param_id; /* Param ID to be assigned */
234
255
 
235
256
  off = jstate->clocations[i].location;
236
257
  /* Adjust recorded location if we're dealing with partial string */
@@ -250,8 +271,10 @@ generate_normalized_query(pgssConstLocations *jstate, int query_loc, int* query_
250
271
  n_quer_loc += len_to_wrt;
251
272
 
252
273
  /* And insert a param symbol in place of the constant token */
253
- n_quer_loc += sprintf(norm_query + n_quer_loc, "$%d",
254
- i + 1 + jstate->highest_extern_param_id);
274
+ param_id = (jstate->clocations[i].param_id < 0) ?
275
+ jstate->highest_extern_param_id + abs(jstate->clocations[i].param_id) :
276
+ jstate->clocations[i].param_id;
277
+ n_quer_loc += sprintf(norm_query + n_quer_loc, "$%d", param_id);
255
278
 
256
279
  quer_loc = off + tok_len;
257
280
  last_off = off;
@@ -292,6 +315,18 @@ static void RecordConstLocation(pgssConstLocations *jstate, int location)
292
315
  jstate->clocations[jstate->clocations_count].location = location;
293
316
  /* initialize lengths to -1 to simplify fill_in_constant_lengths */
294
317
  jstate->clocations[jstate->clocations_count].length = -1;
318
+ /* by default we assume that we need a new param ref */
319
+ jstate->clocations[jstate->clocations_count].param_id = - jstate->highest_normalize_param_id;
320
+ jstate->highest_normalize_param_id++;
321
+ /* record param ref number if requested */
322
+ if (jstate->param_refs != NULL) {
323
+ jstate->param_refs[jstate->param_refs_count] = jstate->clocations[jstate->clocations_count].param_id;
324
+ jstate->param_refs_count++;
325
+ if (jstate->param_refs_count >= jstate->param_refs_buf_size) {
326
+ jstate->param_refs_buf_size *= 2;
327
+ jstate->param_refs = (int *) repalloc(jstate->param_refs, jstate->param_refs_buf_size * sizeof(int));
328
+ }
329
+ }
295
330
  jstate->clocations_count++;
296
331
  }
297
332
  }
@@ -303,71 +338,188 @@ static bool const_record_walker(Node *node, pgssConstLocations *jstate)
303
338
 
304
339
  if (node == NULL) return false;
305
340
 
306
- if (IsA(node, A_Const))
341
+ switch (nodeTag(node))
307
342
  {
308
- RecordConstLocation(jstate, castNode(A_Const, node)->location);
309
- }
310
- else if (IsA(node, ParamRef))
311
- {
312
- /* Track the highest ParamRef number */
313
- if (((ParamRef *) node)->number > jstate->highest_extern_param_id)
314
- jstate->highest_extern_param_id = castNode(ParamRef, node)->number;
315
- }
316
- else if (IsA(node, DefElem))
317
- {
318
- DefElem * defElem = (DefElem *) node;
319
- if (defElem->arg != NULL && IsA(defElem->arg, String)) {
320
- for (int i = defElem->location; i < jstate->query_len; i++) {
321
- if (jstate->query[i] == '\'') {
322
- RecordConstLocation(jstate, i);
323
- break;
324
- }
343
+ case T_A_Const:
344
+ RecordConstLocation(jstate, castNode(A_Const, node)->location);
345
+ break;
346
+ case T_ParamRef:
347
+ {
348
+ /* Track the highest ParamRef number */
349
+ if (((ParamRef *) node)->number > jstate->highest_extern_param_id)
350
+ jstate->highest_extern_param_id = castNode(ParamRef, node)->number;
351
+
352
+ if (jstate->param_refs != NULL) {
353
+ jstate->param_refs[jstate->param_refs_count] = ((ParamRef *) node)->number;
354
+ jstate->param_refs_count++;
355
+ if (jstate->param_refs_count >= jstate->param_refs_buf_size) {
356
+ jstate->param_refs_buf_size *= 2;
357
+ jstate->param_refs = (int *) repalloc(jstate->param_refs, jstate->param_refs_buf_size * sizeof(int));
358
+ }
359
+ }
325
360
  }
326
- }
327
- return const_record_walker((Node *) ((DefElem *) node)->arg, jstate);
328
- }
329
- else if (IsA(node, RawStmt))
330
- {
331
- return const_record_walker((Node *) ((RawStmt *) node)->stmt, jstate);
332
- }
333
- else if (IsA(node, VariableSetStmt))
334
- {
335
- return const_record_walker((Node *) ((VariableSetStmt *) node)->args, jstate);
336
- }
337
- else if (IsA(node, CopyStmt))
338
- {
339
- return const_record_walker((Node *) ((CopyStmt *) node)->query, jstate);
340
- }
341
- else if (IsA(node, ExplainStmt))
342
- {
343
- return const_record_walker((Node *) ((ExplainStmt *) node)->query, jstate);
344
- }
345
- else if (IsA(node, CreateRoleStmt))
346
- {
347
- return const_record_walker((Node *) ((CreateRoleStmt *) node)->options, jstate);
348
- }
349
- else if (IsA(node, AlterRoleStmt))
350
- {
351
- return const_record_walker((Node *) ((AlterRoleStmt *) node)->options, jstate);
352
- }
353
- else if (IsA(node, DeclareCursorStmt))
354
- {
355
- return const_record_walker((Node *) ((DeclareCursorStmt *) node)->query, jstate);
356
- }
361
+ break;
362
+ case T_DefElem:
363
+ {
364
+ DefElem * defElem = (DefElem *) node;
365
+ if (defElem->arg != NULL && IsA(defElem->arg, String)) {
366
+ for (int i = defElem->location; i < jstate->query_len; i++) {
367
+ if (jstate->query[i] == '\'') {
368
+ RecordConstLocation(jstate, i);
369
+ break;
370
+ }
371
+ }
372
+ }
373
+ return const_record_walker((Node *) ((DefElem *) node)->arg, jstate);
374
+ }
375
+ break;
376
+ case T_RawStmt:
377
+ return const_record_walker((Node *) ((RawStmt *) node)->stmt, jstate);
378
+ case T_VariableSetStmt:
379
+ return const_record_walker((Node *) ((VariableSetStmt *) node)->args, jstate);
380
+ case T_CopyStmt:
381
+ return const_record_walker((Node *) ((CopyStmt *) node)->query, jstate);
382
+ case T_ExplainStmt:
383
+ return const_record_walker((Node *) ((ExplainStmt *) node)->query, jstate);
384
+ case T_CreateRoleStmt:
385
+ return const_record_walker((Node *) ((CreateRoleStmt *) node)->options, jstate);
386
+ case T_AlterRoleStmt:
387
+ return const_record_walker((Node *) ((AlterRoleStmt *) node)->options, jstate);
388
+ case T_DeclareCursorStmt:
389
+ return const_record_walker((Node *) ((DeclareCursorStmt *) node)->query, jstate);
390
+ case T_TypeName:
391
+ /* Don't normalize constants in typmods or arrayBounds */
392
+ return false;
393
+ case T_SelectStmt:
394
+ {
395
+ SelectStmt *stmt = (SelectStmt *) node;
396
+ ListCell *lc;
397
+ List *fp_and_param_refs_list = NIL;
398
+
399
+ if (const_record_walker((Node *) stmt->distinctClause, jstate))
400
+ return true;
401
+ if (const_record_walker((Node *) stmt->intoClause, jstate))
402
+ return true;
403
+ foreach(lc, stmt->targetList)
404
+ {
405
+ ResTarget *res_target = lfirst_node(ResTarget, lc);
406
+ FpAndParamRefs *fp_and_param_refs = palloc0(sizeof(FpAndParamRefs));
407
+
408
+ /* Save all param refs we encounter or assign */
409
+ jstate->param_refs = palloc0(1 * sizeof(int));
410
+ jstate->param_refs_buf_size = 1;
411
+ jstate->param_refs_count = 0;
412
+
413
+ /* Walk the element */
414
+ if (const_record_walker((Node *) res_target, jstate))
415
+ return true;
416
+
417
+ /* Remember fingerprint and param refs for later */
418
+ fp_and_param_refs->fp = pg_query_fingerprint_node(res_target->val);
419
+ fp_and_param_refs->param_refs = jstate->param_refs;
420
+ fp_and_param_refs->param_refs_count = jstate->param_refs_count;
421
+ fp_and_param_refs_list = lappend(fp_and_param_refs_list, fp_and_param_refs);
422
+
423
+ /* Reset for next element, or stop recording if this is the last element */
424
+ jstate->param_refs = NULL;
425
+ jstate->param_refs_buf_size = 0;
426
+ jstate->param_refs_count = 0;
427
+ }
428
+ if (const_record_walker((Node *) stmt->fromClause, jstate))
429
+ return true;
430
+ if (const_record_walker((Node *) stmt->whereClause, jstate))
431
+ return true;
357
432
 
358
- PG_TRY();
359
- {
360
- result = raw_expression_tree_walker(node, const_record_walker, (void*) jstate);
361
- }
362
- PG_CATCH();
363
- {
364
- MemoryContextSwitchTo(normalize_context);
365
- result = false;
366
- FlushErrorState();
433
+ /*
434
+ * Instead of walking all of groupClause (like raw_expression_tree_walker does),
435
+ * only walk certain items.
436
+ */
437
+ foreach(lc, stmt->groupClause)
438
+ {
439
+ /*
440
+ * Do not walk A_Const values that are simple integers, this avoids
441
+ * turning "GROUP BY 1" into "GROUP BY $n", which obscures an important
442
+ * semantic meaning. This matches how pg_stat_statements handles the
443
+ * GROUP BY clause (i.e. it doesn't touch these constants)
444
+ */
445
+ if (IsA(lfirst(lc), A_Const) && IsA(&castNode(A_Const, lfirst(lc))->val, Integer))
446
+ continue;
447
+
448
+ /*
449
+ * Match up GROUP BY clauses against the target list, to assign the same
450
+ * param refs as used in the target list - this ensures the query is valid,
451
+ * instead of throwing a bogus "columns ... must appear in the GROUP BY
452
+ * clause or be used in an aggregate function" error
453
+ */
454
+ uint64_t fp = pg_query_fingerprint_node(lfirst(lc));
455
+ FpAndParamRefs *fppr = NULL;
456
+ ListCell *lc2;
457
+ foreach(lc2, fp_and_param_refs_list) {
458
+ if (fp == ((FpAndParamRefs *) lfirst(lc2))->fp) {
459
+ fppr = (FpAndParamRefs *) lfirst(lc2);
460
+ foreach_delete_current(fp_and_param_refs_list, lc2);
461
+ break;
462
+ }
463
+ }
464
+
465
+ int prev_cloc_count = jstate->clocations_count;
466
+ if (const_record_walker((Node *) lfirst(lc), jstate))
467
+ return true;
468
+
469
+ if (fppr != NULL && fppr->param_refs_count == jstate->clocations_count - prev_cloc_count) {
470
+ for (int i = prev_cloc_count; i < jstate->clocations_count; i++) {
471
+ jstate->clocations[i].param_id = fppr->param_refs[i - prev_cloc_count];
472
+ }
473
+ jstate->highest_normalize_param_id -= fppr->param_refs_count;
474
+ }
475
+ }
476
+ foreach(lc, stmt->sortClause)
477
+ {
478
+ /* Similarly, don't turn "ORDER BY 1" into "ORDER BY $n" */
479
+ if (IsA(lfirst(lc), SortBy) && IsA(castNode(SortBy, lfirst(lc))->node, A_Const) &&
480
+ IsA(&castNode(A_Const, castNode(SortBy, lfirst(lc))->node)->val, Integer))
481
+ continue;
482
+
483
+ if (const_record_walker((Node *) lfirst(lc), jstate))
484
+ return true;
485
+ }
486
+ if (const_record_walker((Node *) stmt->havingClause, jstate))
487
+ return true;
488
+ if (const_record_walker((Node *) stmt->windowClause, jstate))
489
+ return true;
490
+ if (const_record_walker((Node *) stmt->valuesLists, jstate))
491
+ return true;
492
+ if (const_record_walker((Node *) stmt->limitOffset, jstate))
493
+ return true;
494
+ if (const_record_walker((Node *) stmt->limitCount, jstate))
495
+ return true;
496
+ if (const_record_walker((Node *) stmt->lockingClause, jstate))
497
+ return true;
498
+ if (const_record_walker((Node *) stmt->withClause, jstate))
499
+ return true;
500
+ if (const_record_walker((Node *) stmt->larg, jstate))
501
+ return true;
502
+ if (const_record_walker((Node *) stmt->rarg, jstate))
503
+ return true;
504
+
505
+ return false;
506
+ }
507
+ default:
508
+ {
509
+ PG_TRY();
510
+ {
511
+ return raw_expression_tree_walker(node, const_record_walker, (void*) jstate);
512
+ }
513
+ PG_CATCH();
514
+ {
515
+ MemoryContextSwitchTo(normalize_context);
516
+ FlushErrorState();
517
+ }
518
+ PG_END_TRY();
519
+ }
367
520
  }
368
- PG_END_TRY();
369
521
 
370
- return result;
522
+ return false;
371
523
  }
372
524
 
373
525
  PgQueryNormalizeResult pg_query_normalize(const char* input)
@@ -393,9 +545,13 @@ PgQueryNormalizeResult pg_query_normalize(const char* input)
393
545
  jstate.clocations = (pgssLocationLen *)
394
546
  palloc(jstate.clocations_buf_size * sizeof(pgssLocationLen));
395
547
  jstate.clocations_count = 0;
548
+ jstate.highest_normalize_param_id = 1;
396
549
  jstate.highest_extern_param_id = 0;
397
550
  jstate.query = input;
398
551
  jstate.query_len = query_len;
552
+ jstate.param_refs = NULL;
553
+ jstate.param_refs_buf_size = 0;
554
+ jstate.param_refs_count = 0;
399
555
 
400
556
  /* Walk tree and record const locations */
401
557
  const_record_walker((Node *) tree, &jstate);
@@ -414,6 +570,8 @@ PgQueryNormalizeResult pg_query_normalize(const char* input)
414
570
  error = malloc(sizeof(PgQueryError));
415
571
  error->message = strdup(error_data->message);
416
572
  error->filename = strdup(error_data->filename);
573
+ error->funcname = strdup(error_data->funcname);
574
+ error->context = NULL;
417
575
  error->lineno = error_data->lineno;
418
576
  error->cursorpos = error_data->cursorpos;
419
577
 
@@ -432,6 +590,7 @@ void pg_query_free_normalize_result(PgQueryNormalizeResult result)
432
590
  if (result.error) {
433
591
  free(result.error->message);
434
592
  free(result.error->filename);
593
+ free(result.error->funcname);
435
594
  free(result.error);
436
595
  }
437
596
 
@@ -108,7 +108,7 @@ static PLpgSQL_function *compile_create_function_stmt(CreateFunctionStmt* stmt)
108
108
  }
109
109
  }
110
110
 
111
- assert(proc_source);
111
+ assert(proc_source != NULL);
112
112
 
113
113
  if (stmt->returnType != NULL) {
114
114
  foreach(lc3, stmt->returnType->names)
@@ -179,6 +179,26 @@ static PLpgSQL_function *compile_create_function_stmt(CreateFunctionStmt* stmt)
179
179
  plpgsql_DumpExecTree = false;
180
180
  plpgsql_start_datums();
181
181
 
182
+ /* Setup parameter names */
183
+ foreach(lc, stmt->parameters)
184
+ {
185
+ FunctionParameter *param = lfirst_node(FunctionParameter, lc);
186
+ if (param->name != NULL)
187
+ {
188
+ char buf[32];
189
+ PLpgSQL_type *argdtype;
190
+ PLpgSQL_variable *argvariable;
191
+ PLpgSQL_nsitem_type argitemtype;
192
+ snprintf(buf, sizeof(buf), "$%d", foreach_current_index(lc) + 1);
193
+ argdtype = plpgsql_build_datatype(UNKNOWNOID, -1, InvalidOid, NULL);
194
+ argvariable = plpgsql_build_variable(param->name ? param->name : buf, 0, argdtype, false);
195
+ argitemtype = argvariable->dtype == PLPGSQL_DTYPE_VAR ? PLPGSQL_NSTYPE_VAR : PLPGSQL_NSTYPE_REC;
196
+ plpgsql_ns_additem(argitemtype, argvariable->dno, buf);
197
+ if (param->name != NULL)
198
+ plpgsql_ns_additem(argitemtype, argvariable->dno, param->name);
199
+ }
200
+ }
201
+
182
202
  /* Set up as though in a function returning VOID */
183
203
  function->fn_rettype = VOIDOID;
184
204
  function->fn_retset = is_setof;
@@ -14,7 +14,7 @@ VALUE pg_query_ruby_fingerprint(VALUE self, VALUE input);
14
14
  VALUE pg_query_ruby_scan(VALUE self, VALUE input);
15
15
  VALUE pg_query_ruby_hash_xxh3_64(VALUE self, VALUE input, VALUE seed);
16
16
 
17
- void Init_pg_query(void)
17
+ __attribute__((visibility ("default"))) void Init_pg_query(void)
18
18
  {
19
19
  VALUE cPgQuery;
20
20
 
@@ -1 +1,2 @@
1
1
  _Init_pg_query
2
+ Init_pg_query