arel_extensions 2.1.5 → 2.1.7

Sign up to get free protection for your applications and to get access to all the features.
@@ -49,6 +49,82 @@ module ArelExtensions
49
49
  'YYYY-MM-DDTHH:MM:SS:MMM' => 126
50
50
  }.freeze
51
51
 
52
# Quoting in JRuby + AR < 5 requires special handling for MSSQL.
#
# The adapter's quoting relied on @connection.quote, which in turn relied on
# the column type. Here we rely only on the Ruby type of the value.
#
# It also mishandled numbers: `quote(1, nil)` produced `N'1'`, which we
# don't want.
#
# Adapted from activerecord/lib/active_record/connection_adapters/abstract/quoting.rb,
# with two SQL Server-specific corrections:
# - SQL Server has no TRUE/FALSE literals; booleans are written as the bit
#   constants 1 and 0.
# - Backslash is not an escape character in SQL Server string literals, so
#   only embedded single quotes are doubled (doubling backslashes would
#   corrupt values such as Windows paths).
if RUBY_PLATFORM == 'java' && ActiveRecord::VERSION::MAJOR < 5
  # Escapes +s+ for use inside a single-quoted SQL Server literal by doubling
  # embedded single quotes.
  def quote_string(s)
    s.gsub("'", "''")
  end

  # Quotes a binary value as a plain string literal (kept as in the abstract
  # adapter).
  def quoted_binary(value) # :nodoc:
    "'#{quote_string(value.to_s)}'"
  end

  # Formats a Date/Time for SQL, without surrounding quotes.
  #
  # Time-like values are first moved to UTC or local time depending on
  # ActiveRecord::Base.default_timezone. A non-zero sub-second part is
  # appended as six-digit microseconds.
  def quoted_date(value)
    if value.acts_like?(:time)
      if ActiveRecord::Base.default_timezone == :utc
        value = value.getutc if value.respond_to?(:getutc) && !value.utc?
      elsif value.respond_to?(:getlocal)
        value = value.getlocal
      end
    end

    result = value.to_s(:db)
    return result unless value.respond_to?(:usec) && value.usec > 0

    "#{result}.#{sprintf('%06d', value.usec)}"
  end

  # SQL Server bit constant for true.
  def quoted_true
    '1'
  end

  # SQL Server bit constant for false.
  def quoted_false
    '0'
  end

  # Formats only the time-of-day part of +value+. The date is pinned to
  # 2000-01-01 and then stripped from the formatted string.
  def quoted_time(value) # :nodoc:
    value = value.change(year: 2000, month: 1, day: 1)
    quoted_date(value).sub(/\A\d{4}-\d{2}-\d{2} /, "")
  end

  # Quotes +value+ for inline SQL based solely on its Ruby class.
  #
  # +column+ is accepted for interface compatibility and ignored.
  # Raises TypeError for unsupported classes.
  def quote value, column = nil
    case value
    when Arel::Nodes::SqlLiteral
      # Must precede String: SqlLiteral is a String that is already SQL.
      value
    when String, Symbol, ActiveSupport::Multibyte::Chars
      "'#{quote_string(value.to_s)}'"
    when true
      quoted_true
    when false
      quoted_false
    when nil
      'NULL'
    # BigDecimals need to be put in a non-normalized form and quoted.
    when BigDecimal
      value.to_s('F')
    when Numeric, ActiveSupport::Duration
      value.to_s
    when Date, Time
      "'#{quoted_date(value)}'"
    when Class
      "'#{value}'"
    else
      raise TypeError, "can't quote #{value.class.name}"
    end
  end
end
127
+
52
128
  # Math Functions
53
129
  def visit_ArelExtensions_Nodes_Ceil o, collector
54
130
  collector << 'CEILING('
@@ -279,6 +355,10 @@ module ArelExtensions
279
355
  end
280
356
 
281
357
  def visit_ArelExtensions_Nodes_Format o, collector
358
+ visit_ArelExtensions_Nodes_FormattedDate o, collector
359
+ end
360
+
361
+ def visit_ArelExtensions_Nodes_FormattedDate o, collector
282
362
  f = ArelExtensions::Visitors::strftime_to_format(o.iso_format, LOADED_VISITOR::DATE_FORMAT_DIRECTIVES)
283
363
  if fmt = LOADED_VISITOR::DATE_CONVERT_FORMATS[f]
284
364
  collector << "CONVERT(VARCHAR(#{f.length})"
@@ -487,6 +567,11 @@ module ArelExtensions
487
567
  collector
488
568
  end
489
569
 
570
# Renders `ROLLUP` followed by the grouping element(s), e.g. `ROLLUP( a, b )`.
def visit_Arel_Nodes_RollUp(o, collector)
  grouping_array_or_grouping_element(o, collector << 'ROLLUP')
end
574
+
490
575
  # TODO;
491
576
  def visit_ArelExtensions_Nodes_GroupConcat o, collector
492
577
  collector << '(STRING_AGG('
@@ -647,6 +732,18 @@ module ArelExtensions
647
732
  collector << ')'
648
733
  collector
649
734
  end
735
+
736
# Shared by grouping-element visitors (RollUp here; the original note says
# GroupingSet and Cube too).
#
# An Array expression is visited inside `( ... )`. Any other expression is
# visited as-is. Returns the collector.
def grouping_array_or_grouping_element(o, collector)
  return visit(o.expr, collector) unless o.expr.is_a?(Array)

  collector << "( "
  visit(o.expr, collector)
  collector << " )"
end
650
747
  end
651
748
  end
652
749
  end
@@ -14,6 +14,52 @@ module ArelExtensions
14
14
  '%M' => '%i', '%S' => '%S', '%L' => '', '%N' => '%f', '%z' => ''
15
15
  }.freeze
16
16
 
17
# Arel only provides `collect_nodes_for` from Rails 5.2 onwards; backfill it
# for older versions so visit_Arel_Nodes_SelectCore can rely on it.
#
# `method_defined?` only reports public and protected methods. Private
# definitions are therefore checked too; Arel's ToSql declares its visitor
# helpers private, so the old check never detected the built-in helper.
unless Arel::Visitors::MySQL.method_defined?(:collect_nodes_for) ||
       Arel::Visitors::MySQL.private_method_defined?(:collect_nodes_for)
  # Appends +spacer+ followed by +nodes+ joined with +connector+.
  # Emits nothing (and returns nil) when +nodes+ is nil or empty.
  def collect_nodes_for(nodes, collector, spacer, connector = ", ")
    return unless nodes&.any?

    collector << spacer
    inject_join nodes, collector, connector
  end
end
26
+
27
# The whole purpose of this override is to fix the behavior of RollUp.
#
# All other databases accept ROLLUP anywhere in GROUP BY, except MySQL, whose
# `WITH ROLLUP` modifier must come last. Every RollUp grouping is therefore
# pulled out and merged into a single trailing RollUp node. The rest mirrors
# Arel's own SelectCore rendering.
def visit_Arel_Nodes_SelectCore(o, collector)
  collector << "SELECT"

  # Optimizer hints only exist in recent Arel versions. Arel's ToSql
  # declares the helper private, so private methods must be included in the
  # check (respond_to?'s second argument); otherwise hints are dropped.
  if respond_to?(:collect_optimizer_hints, true)
    collector = collect_optimizer_hints(o, collector)
  end
  collector = maybe_visit o.set_quantifier, collector

  collect_nodes_for o.projections, collector, " "

  if o.source && !o.source.empty?
    collector << " FROM "
    collector = visit o.source, collector
  end

  # Non-rollup groupings keep their order. A single RollUp holding the
  # values of every RollUp found is appended last. `partition` returns fresh
  # arrays, so o.groups itself is left untouched.
  rollups, groups = o.groups.partition { |g| g.expr.is_a?(Arel::Nodes::RollUp) }
  groups << Arel::Nodes::RollUp.new(rollups.map { |r| r.expr.value }) unless rollups.empty?

  collect_nodes_for o.wheres, collector, " WHERE ", " AND "
  collect_nodes_for groups, collector, " GROUP BY "
  collect_nodes_for o.havings, collector, " HAVING ", " AND "
  collect_nodes_for o.windows, collector, " WINDOW "

  # SelectCore comments are only available in recent Arel versions.
  if o.respond_to?(:comment)
    maybe_visit o.comment, collector
  else
    collector
  end
end
17
63
 
18
64
  # Math functions
19
65
  def visit_ArelExtensions_Nodes_Log10 o, collector
@@ -138,6 +184,11 @@ module ArelExtensions
138
184
  collector
139
185
  end
140
186
 
187
# MySQL spells rollup as a trailing modifier: `<columns> WITH ROLLUP`.
def visit_Arel_Nodes_RollUp(o, collector)
  visit(o.expr, collector)
  collector << ' WITH ROLLUP'
end
191
+
141
192
  def visit_ArelExtensions_Nodes_GroupConcat o, collector
142
193
  collector << 'GROUP_CONCAT('
143
194
  collector = visit o.left, collector
@@ -201,28 +252,22 @@ module ArelExtensions
201
252
  end
202
253
 
203
254
  def visit_ArelExtensions_Nodes_Format o, collector
204
- case o.col_type
255
+ # One use case we met is
256
+ # `case…when…then(valid_date).else(Arel.null).format(…)`.
257
+ #
258
+ # In this case, `o.col_type` is `nil` but we have a legitimate type in
259
+ # the expression to be formatted. The following is a best effort to
260
+ # infer the proper type.
261
+ first = o.expressions[0]
262
+ type =
263
+ o.col_type.nil? \
264
+ && (first.respond_to?(:return_type) && !first&.return_type.nil?) \
265
+ ? first&.return_type \
266
+ : o.col_type
267
+
268
+ case type
205
269
  when :date, :datetime, :time
206
- fmt = ArelExtensions::Visitors::strftime_to_format(o.iso_format, DATE_FORMAT_DIRECTIVES)
207
- collector << 'DATE_FORMAT('
208
- collector << 'CONVERT_TZ(' if o.time_zone
209
- collector = visit o.left, collector
210
- case o.time_zone
211
- when Hash
212
- src_tz, dst_tz = o.time_zone.first
213
- collector << COMMA
214
- collector = visit Arel.quoted(src_tz), collector
215
- collector << COMMA
216
- collector = visit Arel.quoted(dst_tz), collector
217
- collector << ')'
218
- when String
219
- collector << COMMA << "'UTC'" << COMMA
220
- collector = visit Arel.quoted(o.time_zone), collector
221
- collector << ')'
222
- end
223
- collector << COMMA
224
- collector = visit Arel.quoted(fmt), collector
225
- collector << ')'
270
+ visit_ArelExtensions_Nodes_FormattedDate o, collector
226
271
  when :integer, :float, :decimal
227
272
  collector << 'FORMAT('
228
273
  collector = visit o.left, collector
@@ -237,6 +282,29 @@ module ArelExtensions
237
282
  collector
238
283
  end
239
284
 
285
# Renders a date/time expression through MySQL's DATE_FORMAT, optionally
# wrapping it in CONVERT_TZ first.
#
# The strftime-style o.iso_format is translated with DATE_FORMAT_DIRECTIVES.
# o.time_zone may be:
# - nil: no conversion;
# - a Hash: its first pair is read as source zone => destination zone;
# - a String: the value is converted from 'UTC' to that zone.
def visit_ArelExtensions_Nodes_FormattedDate o, collector
  mysql_fmt = ArelExtensions::Visitors::strftime_to_format(o.iso_format, DATE_FORMAT_DIRECTIVES)
  tz = o.time_zone

  collector << 'DATE_FORMAT('
  collector << 'CONVERT_TZ(' if tz
  collector = visit o.left, collector

  if tz.is_a?(Hash)
    from_tz, to_tz = tz.first
    collector << COMMA
    collector = visit Arel.quoted(from_tz), collector
    collector << COMMA
    collector = visit Arel.quoted(to_tz), collector
    collector << ')'
  elsif tz.is_a?(String)
    collector << COMMA << "'UTC'" << COMMA
    collector = visit Arel.quoted(tz), collector
    collector << ')'
  end

  collector << COMMA
  collector = visit Arel.quoted(mysql_fmt), collector
  collector << ')'
end
307
+
240
308
  def visit_ArelExtensions_Nodes_DateDiff o, collector
241
309
  case o.right_node_type
242
310
  when :ruby_date, :ruby_time, :date, :datetime, :time
@@ -126,6 +126,11 @@ module ArelExtensions
126
126
  collector
127
127
  end
128
128
 
129
# Renders `ROLLUP` followed by the grouping element(s), e.g. `ROLLUP( a, b )`.
def visit_Arel_Nodes_RollUp(o, collector)
  grouping_array_or_grouping_element(o, collector << 'ROLLUP')
end
133
+
129
134
  def visit_ArelExtensions_Nodes_GroupConcat o, collector
130
135
  collector << '(LISTAGG('
131
136
  collector = visit o.left, collector
@@ -446,6 +451,10 @@ module ArelExtensions
446
451
  end
447
452
 
448
453
  def visit_ArelExtensions_Nodes_Format o, collector
454
+ visit_ArelExtensions_Nodes_FormattedDate o, collector
455
+ end
456
+
457
+ def visit_ArelExtensions_Nodes_FormattedDate o, collector
449
458
  fmt = ArelExtensions::Visitors::strftime_to_format(o.iso_format, DATE_FORMAT_DIRECTIVES)
450
459
  collector << 'TO_CHAR('
451
460
  collector << 'CAST(' if o.time_zone
@@ -705,6 +714,18 @@ module ArelExtensions
705
714
  collector << ')'
706
715
  collector
707
716
  end
717
+
718
# Shared by grouping-element visitors (RollUp here; the original note says
# GroupingSet and Cube too).
#
# An Array expression is visited inside `( ... )`. Any other expression is
# visited as-is. Returns the collector.
def grouping_array_or_grouping_element(o, collector)
  return visit(o.expr, collector) unless o.expr.is_a?(Array)

  collector << "( "
  visit(o.expr, collector)
  collector << " )"
end
708
729
  end
709
730
  end
710
731
  end
@@ -174,6 +174,10 @@ module ArelExtensions
174
174
  end
175
175
 
176
176
  def visit_ArelExtensions_Nodes_Format o, collector
177
+ visit_ArelExtensions_Nodes_FormattedDate o, collector
178
+ end
179
+
180
+ def visit_ArelExtensions_Nodes_FormattedDate o, collector
177
181
  fmt = ArelExtensions::Visitors::strftime_to_format(o.iso_format, DATE_FORMAT_DIRECTIVES)
178
182
  collector << 'TO_CHAR('
179
183
  collector << '(' if o.time_zone
@@ -289,6 +289,26 @@ module ArelExtensions
289
289
  collector
290
290
  end
291
291
 
292
# Renders a formatted value for SQLite, choosing the function from o.col_type:
# - date/datetime/time: STRFTIME(<format>, <value>);
# - integer/float/decimal: FORMAT(<value>, <format>);
# - anything else: the bare value, with the format ignored.
#
# NOTE(review): SQLite's built-in format()/printf() takes the format string
# first, but the numeric branch emits the value first. Confirm this matches
# the function actually available at runtime.
def visit_ArelExtensions_Nodes_FormattedDate o, collector
  type = o.col_type

  if [:date, :datetime, :time].include?(type)
    collector << 'STRFTIME('
    collector = visit o.right, collector
    collector << COMMA
    collector = visit o.left, collector
    collector << ')'
  elsif [:integer, :float, :decimal].include?(type)
    collector << 'FORMAT('
    collector = visit o.left, collector
    collector << COMMA
    collector = visit o.right, collector
    collector << ')'
  else
    collector = visit o.left, collector
  end

  collector
end
311
+
292
312
  # comparators
293
313
 
294
314
  def visit_ArelExtensions_Nodes_Cast o, collector
@@ -19,6 +19,14 @@ if defined?(Arel::Visitors::SQLServer)
19
19
  end
20
20
  end
21
21
 
22
# When the loaded Arel ships a DepthFirst visitor, let it walk into a
# SelectManager by visiting that manager's AST.
if defined?(Arel::Visitors::DepthFirst)
  Arel::Visitors::DepthFirst.class_eval do
    def visit_Arel_SelectManager o
      visit o.ast
    end
  end
end
29
+
22
30
  if defined?(Arel::Visitors::MSSQL)
23
31
  class Arel::Visitors::MSSQL
24
32
  include ArelExtensions::Visitors::MSSQL
@@ -78,6 +78,8 @@ require 'arel_extensions/nodes/case'
78
78
  require 'arel_extensions/nodes/soundex'
79
79
  require 'arel_extensions/nodes/cast'
80
80
  require 'arel_extensions/nodes/json'
81
+ require 'arel_extensions/nodes/rollup'
82
+ require 'arel_extensions/nodes/select'
81
83
 
82
84
  # It seems like the code in lib/arel_extensions/visitors.rb that is supposed
83
85
  # to inject ArelExtension is not enough. Different versions of the sqlserver
@@ -132,6 +134,10 @@ module Arel
132
134
  ArelExtensions::Nodes::Rand.new
133
135
  end
134
136
 
137
# Builds a RollUp grouping element over all given expressions.
def self.rollup(*exprs)
  Arel::Nodes::RollUp.new(exprs)
end
140
+
135
141
  def self.shorten s
136
142
  Base64.urlsafe_encode64(Digest::MD5.new.digest(s)).tr('=', '').tr('-', '_')
137
143
  end
@@ -277,4 +283,14 @@ class Arel::Attributes::Attribute
277
283
  collector = engine.connection.visitor.accept self, collector
278
284
  collector.value
279
285
  end
286
+
287
# Builds a RollUp grouping element containing only this attribute.
def rollup
  grouping = [self]
  Arel::Nodes::RollUp.new(grouping)
end
290
+ end
291
+
292
class Arel::Nodes::Node
  # Builds a RollUp grouping element containing only this node.
  def rollup
    grouping = [self]
    Arel::Nodes::RollUp.new(grouping)
  end
end