semdiff 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,91 @@
1
+ # frozen_string_literal: true
2
+
3
module Semdiff
  # ConstantsCompiler is a compiler that folds constant literal
  # expressions into single constants.
  #
  # @example
  #   ```ruby
  #   # Expression          Canonical Form
  #   (2 + 3)               5
  #   (4 * 2 - 1)           7
  #   (8 / 4) + 3           5
  #   (1 + 2**3)            9
  #   ```
  #
  # A call is folded only when its receiver and its single argument are
  # (possibly already folded) integer literals, the operator is listed in
  # SUPPORTED_OPS, and evaluating it yields an Integer of bounded size.
  class ConstantsCompiler < ::Prism::MutationCompiler
    include Prism::DSL
    include CompilerUtils

    # Binary Integer operators evaluated at compile time.
    SUPPORTED_OPS = %i[+ - * / % ** & | ^ << >>].freeze

    # Upper bound, in bits, on the estimated size of a folded `**`, `<<` or
    # `>>` result. Without it, source such as `1 << 10**12` would make the
    # compiler try to allocate a gigantic Integer; the resulting
    # NoMemoryError is not a StandardError and would escape the rescue in
    # #fold_constants.
    MAX_RESULT_BITS = 1 << 20

    # Visits a call node bottom-up and replaces it with an IntegerNode when
    # both operands are integer literals after their own folding.
    #
    # @param node [Prism::CallNode]
    # @return [Prism::Node] the folded IntegerNode, or a copy of +node+ with
    #   visited children
    def visit_call_node(node)
      receiver = visit(node.receiver)
      arguments = visit(node.arguments)
      result = fold_constants(node, receiver, arguments) if constant_foldable?(node, receiver, arguments)

      if result
        CompilerUtils.inherit_newline(node, result)
      else
        # finalized transformations
        node.copy(receiver: receiver, arguments: arguments, block: visit(node.block))
      end
    end

    # Visits a parentheses node and drops the parentheses when they wrap a
    # single expression that folded into an integer literal: `(2 + 3)` => `5`.
    #
    # @param node [Prism::ParenthesesNode]
    # @return [Prism::Node]
    def visit_parentheses_node(node)
      body = node.body
      # Empty `()`, multi-statement `(a; b)` and non-statement bodies are
      # visited generically and keep their parentheses.
      return super unless single_statement?(body)

      visited = visit(body.body.first)
      # The literal takes the parentheses' place, so it takes their newline
      # flag too (mirrors visit_call_node).
      return CompilerUtils.inherit_newline(node, visited) if visited.is_a?(Prism::IntegerNode)

      node.copy(body: body.copy(body: [visited]))
    end

    private

    # @return [Boolean] whether +body+ is a StatementsNode with exactly one child
    def single_statement?(body)
      body.is_a?(Prism::StatementsNode) && body.body.size == 1
    end

    # @return [Boolean] whether +node+ is `int op int` with a supported op
    def constant_foldable?(node, receiver, arguments)
      receiver.is_a?(Prism::IntegerNode) &&
        arguments&.arguments&.size == 1 &&
        arguments.arguments.first.is_a?(Prism::IntegerNode) &&
        SUPPORTED_OPS.include?(node.name)
    end

    # Evaluates the operation and wraps an Integer result in an IntegerNode
    # that reuses +node+'s id and location.
    #
    # @return [Prism::IntegerNode, nil] nil when the operation is not safely
    #   evaluable or does not produce an Integer
    def fold_constants(node, receiver, arguments)
      lhs = receiver.value
      rhs = arguments.arguments.first.value
      op = node.name
      return nil unless evaluable?(op, lhs, rhs)

      result = begin
        lhs.public_send(op, rhs)
      rescue StandardError
        nil
      end
      # Only Integer results can be represented as an IntegerNode.
      return nil unless result.is_a?(Integer)

      integer_node(
        node_id: node.node_id,
        source: node.send(:source),
        location: node.location,
        flags: integer_base_flag(:decimal) | Prism::NodeFlags::STATIC_LITERAL,
        value: result
      )
    end

    # @return [Boolean] whether `lhs op rhs` is well defined and its result
    #   stays within MAX_RESULT_BITS
    def evaluable?(op, lhs, rhs)
      case op
      when :/, :%
        !rhs.zero?
      when :**
        # A negative exponent yields a Rational; |lhs| <= 1 never grows.
        !rhs.negative? && (lhs.abs <= 1 || lhs.abs.bit_length * rhs <= MAX_RESULT_BITS)
      when :<<
        rhs <= MAX_RESULT_BITS
      when :>>
        # A negative right shift is a left shift.
        rhs >= -MAX_RESULT_BITS
      else
        true
      end
    end
  end
end
@@ -0,0 +1,136 @@
1
+ # frozen_string_literal: true
2
+
3
module Semdiff
  # IdentityCompiler is a compiler that simplifies arithmetic identities
  # for variables based on type information.
  #
  # @example
  #   ```ruby
  #   # Expression      Canonical Form
  #   x + 0             x        # zero addition
  #   x - 0             x        # zero subtraction
  #   0 - x             -x       # unary negation
  #   x * 1             x        # multiplicative identity
  #   0 * x             0        # zero multiplication
  #   x / 1             x        # division identity
  #   x ** 1            x        # exponent identity
  #   x ** 0            1        # zero exponent
  #   +x                x        # unary plus is noop
  #   -(-x)             x        # double negation
  #   ---x              -x       # multiple negation
  #   ```
  #
  # A rewrite fires only when the operand is known to be numeric, via the
  # NodeTypeFlags bitmask in +node_types+ (or, for negations, because it is a
  # numeric literal). `0 * x` and `x ** 0` additionally require an Integer.
  class IdentityCompiler < ::Prism::MutationCompiler
    include Prism::DSL
    include CompilerUtils

    # @param node_types [Hash{Integer => Integer}] node_id => NodeTypeFlags
    #   bitmask, as produced by TypeVisitor
    def initialize(node_types)
      @node_types = node_types
      super()
    end

    # Visits a call node bottom-up and replaces it with a simpler equivalent
    # when one of the identities above applies.
    #
    # @param node [Prism::CallNode]
    # @return [Prism::Node]
    def visit_call_node(node)
      lhs = visit(node.receiver)
      arguments = visit(node.arguments)
      rhs = arguments&.arguments&.first
      lhs_types = lhs && @node_types[lhs.node_id]
      rhs_types = rhs && @node_types[rhs.node_id]
      lhs_numeric = lhs_types&.anybits?(NodeTypeFlags::NUMERIC)
      rhs_numeric = rhs_types&.anybits?(NodeTypeFlags::NUMERIC)
      lhs_integer = lhs_numeric && lhs_types.anybits?(NodeTypeFlags::INTEGER)
      rhs_integer = rhs_numeric && rhs_types.anybits?(NodeTypeFlags::INTEGER)
      result = nil

      case node.name
      when :+
        if lhs_numeric && zero_literal?(rhs)
          # x + 0 => x
          result = lhs
        elsif rhs_numeric && zero_literal?(lhs)
          # 0 + x => x
          result = rhs
        end
      when :+@
        # + x => x
        result = lhs if lhs_numeric
      when :-
        if lhs_numeric && zero_literal?(rhs)
          # x - 0 => x
          result = lhs
        elsif rhs_numeric && zero_literal?(lhs)
          # 0 - x => -x
          result = node.copy(
            receiver: rhs,
            name: :-@,
            arguments: nil
          )
        end
      when :-@
        # -(-x) => x. Children are visited first, so ---x has already been
        # reduced to -x by the time the outermost negation is seen.
        # Numerics only: String#-@, for instance, returns a frozen string.
        inner, neg_count = strip_negations(lhs)
        result = inner if neg_count.even? && numeric_operand?(inner)
      when :*
        if lhs_numeric && one_literal?(rhs)
          # x * 1 => x
          result = lhs
        elsif rhs_numeric && one_literal?(lhs)
          # 1 * x => x
          result = rhs
        elsif (lhs_integer && zero_literal?(rhs)) ||
              (rhs_integer && zero_literal?(lhs))
          # 0 * x or x * 0 => 0
          # This works for integers, but not for Float
          # because of INFINITY or NaN instead of 0.
          result = zero_literal?(rhs) ? rhs : lhs
        end
      when :/
        # x / 1 => x
        result = lhs if lhs_numeric && one_literal?(rhs)
      when :**
        if lhs_numeric && one_literal?(rhs)
          # x ** 1 => x
          result = lhs
        elsif lhs_integer && zero_literal?(rhs)
          # x ** 0 => 1
          result = integer_node(
            node_id: node.node_id,
            source: node.send(:source),
            location: node.location,
            flags: integer_base_flag(:decimal) | Prism::NodeFlags::STATIC_LITERAL,
            value: 1
          )
        end
      end

      if result
        CompilerUtils.inherit_newline(node, result)
      else
        # finalized transformations
        node.copy(receiver: lhs, arguments: arguments, block: visit(node.block))
      end
    end

    private

    # Peels single-expression parentheses and nested `-@` calls off +node+.
    #
    # @param node [Prism::Node, nil] receiver of the outer `-@` call
    # @return [Array(Prism::Node, Integer)] the innermost operand and the
    #   total number of negations, counting the outer call
    def strip_negations(node)
      count = 1
      loop do
        if node.is_a?(Prism::ParenthesesNode) && single_statement?(node.body)
          node = node.body.body.first
        elsif node.is_a?(Prism::CallNode) && node.name == :-@ && node.receiver
          node = node.receiver
          count += 1
        else
          return [node, count]
        end
      end
    end

    # @return [Boolean] whether +body+ is a StatementsNode with exactly one child
    def single_statement?(body)
      body.is_a?(Prism::StatementsNode) && body.body.size == 1
    end

    # @return [Boolean] whether +node+ is a numeric literal or is typed NUMERIC
    def numeric_operand?(node)
      case node
      when Prism::IntegerNode, Prism::FloatNode, Prism::RationalNode, Prism::ImaginaryNode
        true
      when nil
        false
      else
        @node_types[node.node_id]&.anybits?(NodeTypeFlags::NUMERIC) || false
      end
    end

    # @return [Boolean] whether +node+ is the integer literal 0 (nil-safe)
    def zero_literal?(node)
      node.is_a?(Prism::IntegerNode) && node.value.zero?
    end

    # @return [Boolean] whether +node+ is the integer literal 1 (nil-safe)
    def one_literal?(node)
      node.is_a?(Prism::IntegerNode) && node.value == 1
    end
  end
end
@@ -0,0 +1,194 @@
1
+ # frozen_string_literal: true
2
+
3
module Semdiff
  # StructuresCompiler is a compiler that normalizes common
  # data structure operations.
  #
  # @example
  #   ```ruby
  #   # Expression                Canonical Form
  #   Array.new                   []
  #   Array.new(0)                []
  #   Hash.new                    {}
  #   Hash[]                      {}
  #   array.push(e)               array << e
  #   "a".concat("b")             "a" << "b"
  #   hash.store(key, value)      hash[key] = value
  #   ''                          ""
  #   ```
  #
  # The method-to-operator rewrites (`push`, `concat`, `store`) are only
  # applied to calls with an explicit receiver, no safe navigation (`&.`), no
  # block and plain positional arguments, because operator syntax cannot
  # express any of those.
  class StructuresCompiler < ::Prism::MutationCompiler
    include Prism::DSL
    include CompilerUtils

    # Visits a call node bottom-up and rewrites it into its canonical form
    # when one of the patterns above matches.
    #
    # @param node [Prism::CallNode]
    # @return [Prism::Node]
    def visit_call_node(node)
      receiver = visit(node.receiver)
      arguments = visit(node.arguments)
      block = visit(node.block)

      result = case node.name
               when :new
                 transform_constructor(node, receiver, arguments, block)
               when :[]
                 transform_class_bracket_method(node, receiver, arguments)
               when :push
                 transform_push_to_shovel(node, receiver, arguments)
               when :concat
                 transform_concat_to_shovel(node, receiver, arguments)
               when :store
                 transform_store_to_bracket_assign(node, receiver, arguments)
               end

      if result
        CompilerUtils.inherit_newline(node, result)
      else
        node.copy(receiver: receiver, arguments: arguments, block: block)
      end
    end

    private

    # `Array.new` / `Hash.new` without a block => literal, see below.
    def transform_constructor(node, receiver, arguments, block)
      return nil unless receiver.is_a?(Prism::ConstantReadNode)
      return nil if block

      case receiver.name
      when :Array
        transform_array_new(node, arguments)
      when :Hash
        transform_hash_new(node, arguments)
      end
    end

    def transform_array_new(node, arguments)
      # Array.new -> []
      if arguments.nil? || arguments.arguments.empty?
        return array_node(
          node_id: node.node_id,
          source: node.send(:source),
          location: node.location,
          flags: Prism::NodeFlags::STATIC_LITERAL,
          elements: [],
          opening_loc: node.receiver.location,
          closing_loc: node.message_loc
        )
      end

      # Array.new(0) -> []
      if arguments.arguments.size == 1 &&
         arguments.arguments.first.is_a?(Prism::IntegerNode) &&
         arguments.arguments.first.value.zero?
        return array_node(
          node_id: node.node_id,
          source: node.send(:source),
          location: node.location,
          flags: Prism::NodeFlags::STATIC_LITERAL,
          elements: [],
          opening_loc: node.receiver.location,
          closing_loc: node.closing_loc
        )
      end

      nil
    end

    def transform_hash_new(node, arguments)
      # Hash.new -> {} (only without arguments)
      if arguments.nil? || arguments.arguments.empty?
        return hash_node(
          node_id: node.node_id,
          source: node.send(:source),
          location: node.location,
          flags: Prism::NodeFlags::STATIC_LITERAL,
          elements: [],
          opening_loc: node.receiver.location,
          closing_loc: node.message_loc
        )
      end

      nil
    end

    def transform_class_bracket_method(node, receiver, arguments)
      return nil unless receiver.is_a?(Prism::ConstantReadNode)

      # Hash[] -> {}
      if receiver.name == :Hash && (arguments.nil? || arguments.arguments.empty?)
        return hash_node(
          node_id: node.node_id,
          source: node.send(:source),
          location: node.location,
          flags: Prism::NodeFlags::STATIC_LITERAL,
          elements: [],
          opening_loc: node.opening_loc,
          closing_loc: node.closing_loc
        )
      end

      nil
    end

    def transform_push_to_shovel(node, receiver, arguments)
      # array.push(e) -> array << e
      return nil unless operator_rewritable?(node, receiver, arguments, 1)

      shovel_call(node, receiver, arguments)
    end

    def transform_concat_to_shovel(node, receiver, arguments)
      # "a".concat("b") -> "a" << "b"
      return nil unless receiver.is_a?(Prism::StringNode)
      return nil unless operator_rewritable?(node, receiver, arguments, 1)

      shovel_call(node, receiver, arguments)
    end

    def transform_store_to_bracket_assign(node, receiver, arguments)
      # hash.store(key, value) -> hash[key] = value
      return nil unless operator_rewritable?(node, receiver, arguments, 2)

      # NOTE: there seems to be a bug here where flags for the
      # hash[key] = value callnode has 256 (1 << 8) set, we
      # copy it here to pass === check
      call_node(
        node_id: node.node_id,
        source: node.send(:source),
        location: node.location,
        flags: call_node_flag(:attribute_write) | (1 << 8),
        receiver: receiver,
        call_operator_loc: nil,
        name: :[]=,
        message_loc: node.message_loc,
        opening_loc: node.opening_loc,
        arguments: arguments,
        closing_loc: node.closing_loc,
        block: nil
      )
    end

    # Builds `receiver << argument` at +node+'s position.
    def shovel_call(node, receiver, arguments)
      call_node(
        node_id: node.node_id,
        source: node.send(:source),
        location: node.location,
        flags: 0,
        receiver: receiver,
        call_operator_loc: nil,
        name: :<<,
        message_loc: node.message_loc,
        opening_loc: nil,
        arguments: arguments,
        closing_loc: nil,
        block: nil
      )
    end

    # @return [Boolean] whether +node+ can be re-expressed with operator
    #   syntax: explicit receiver, no `&.`, no block, and exactly +arity+
    #   plain positional arguments
    def operator_rewritable?(node, receiver, arguments, arity)
      return false if receiver.nil? || node.safe_navigation? || node.block
      return false unless arguments&.arguments&.size == arity

      arguments.arguments.all? { |arg| positional_argument?(arg) }
    end

    # Splats, keyword hashes and `...` may expand to any number of arguments.
    def positional_argument?(arg)
      case arg
      when Prism::SplatNode, Prism::KeywordHashNode, Prism::ForwardingArgumentsNode
        false
      else
        true
      end
    end
  end
end
@@ -0,0 +1,207 @@
1
+ # frozen_string_literal: true
2
+
3
+ # require 'prism'
4
+
5
module Semdiff
  # This visitor is responsible for mapping AST nodes to type information.
  # Returns a Hash of `node_id => flags` : `{10 => 7, 11 => 7}`
  # Nodes with no (or empty) type information get no entry.
  class TypeVisitor < ::Prism::Visitor
    include ::Typeguard::TypeModel::Definitions

    UNSAFE_NUMERIC_CLASSES = %i[Date::Infinity].freeze
    NUMERIC_CLASSES = %i[Numeric Integer Float Rational Complex].freeze
    ENUMERABLE_CLASSES = %i[
      Enumerable ARGF Array Dir Enumerator ENV Hash
      IO Range Struct CSV CSV::Table CSV::Row Set
    ].freeze

    # @param types [Object, nil] root type-model definitions; entries that
    #   respond to `vars` are indexed by name (see #map_mods)
    def initialize(types)
      @types = types
      @mods = map_mods(@types)
      @scope = []
      @mod = nil
      @method = nil
      @node_types = {}
      super()
    end

    # @return [Hash{Integer => Integer}] node_id => NodeTypeFlags bitmask
    def visit_program_node(node)
      super
      @node_types
    end

    def visit_module_node(node)
      process_mod_node(node)
    end

    def visit_class_node(node)
      process_mod_node(node)
    end

    def visit_def_node(node)
      # Either look in the current mod or root. This is nil inside a
      # module/class the type model does not know about, or when no root
      # types were given.
      haystack = @mod&.children || (@types if @scope.empty?)
      # NOTE: If an explicit receiver is provided, we look
      #       for singleton methods on the current mod,
      #       assuming that the receiver is self. Could be
      #       expanded by resolving the receiver if not self.
      scope = node.receiver ? :class : :instance
      needle = node.name
      outer_method = @method
      @method = haystack&.find do |e|
        e.is_a?(MethodDefinition) && e.name == needle && e.scope == scope
      end
      super
      # Restore (rather than clear) so a nested def does not erase the
      # enclosing method's context.
      @method = outer_method
    end

    def visit_class_variable_read_node(node)
      store_node_types(node, find_var_types(:class, node.name))
      super
    end

    def visit_constant_read_node(node)
      store_node_types(node, find_var_types(:constant, node.name))
      super
    end

    def visit_instance_variable_read_node(node)
      store_node_types(node, find_var_types(:instance, node.name))
      super
    end

    def visit_local_variable_read_node(node)
      unless @method.nil? || @method.parameters.empty?
        param = @method.parameters.find { |p| p.name == node.name }
        store_node_types(node, param&.types)
      end
      super
    end

    def visit_call_node(node)
      if node.variable_call? || node.receiver&.type == :self_node
        # In the case where a singleton method is defined on the
        # current mod: `class << self; attr_reader :a; end` we
        # might have type information (a var) on the mod.
        store_node_types(node, find_var_types(:self, node.name))
      end
      super
    end

    private

    # Breadth-first index of every definition that responds to `vars`,
    # keyed by its name.
    def map_mods(types)
      mods = {}
      return mods if types.nil?

      queue = Array(types).select { |o| o.respond_to?(:vars) }
      visited = Set.new
      while (mod = queue.shift)
        next unless mod && visited.add?(mod.object_id)

        mods[mod.name] = mod
        queue.concat(mod.children.select { |c| c.respond_to?(:vars) })
      end
      mods
    end

    def compute_fqn(path_node, scope)
      fqn = scope.join('::')

      case path_node.type
      when :constant_read_node
        name_s = path_node.name.to_s
        fqn.empty? ? name_s : "#{fqn}::#{name_s}"
      when :constant_path_node
        path = path_node.full_name
        return path[2..] if path.start_with?('::')

        candidate = fqn.empty? ? path : "#{fqn}::#{path}"
        if @mods.include?(candidate) || !@mods.include?(path)
          candidate
        else
          path
        end
      else
        raise "Unsupported path node type for #{path_node.type}"
      end
    end

    def process_mod_node(node)
      fqn = compute_fqn(node.constant_path, @scope)
      prev_scope = @scope
      prev_mod = @mod
      @scope = fqn.split('::')
      @mod = @mods[fqn]
      visit_child_nodes(node)
      @scope = prev_scope
      @mod = prev_mod
    end

    def find_var_types(scope, name)
      return nil if @mod.nil? || @mod.vars.empty?

      var = @mod.vars.find { |v| v.scope == scope && v.name == name }
      var&.types
    end

    def store_node_types(node, types)
      # Without type information nothing may be assumed; `all?` on an empty
      # list is true and would otherwise flag NUMERIC and ENUMERABLE.
      return if types.nil? || types.empty?

      flags = 0
      flags |= single_type_flags(types.first) if types.one?
      flags |= NodeTypeFlags::NUMERIC if types.all? { |e| NUMERIC_CLASSES.include?(e.kind) }
      flags |= NodeTypeFlags::ENUMERABLE if types.all? { |e| ENUMERABLE_CLASSES.include?(e.kind) }
      @node_types[node.node_id] = flags
    end

    # Flags describing a node whose type list has exactly one entry.
    def single_type_flags(type)
      flags = NodeTypeFlags::ONE
      return flags unless type.shape == :basic

      flags | NodeTypeFlags::BASIC | basic_kind_flag(type.kind)
    end

    # Flag for a basic-shaped type's kind, 0 when the kind is not tracked.
    def basic_kind_flag(kind)
      case kind
      when :Integer then NodeTypeFlags::INTEGER
      when :Float then NodeTypeFlags::FLOAT
      when :boolean then NodeTypeFlags::BOOLEAN
      when :String then NodeTypeFlags::STRING
      else 0
      end
    end
  end

  module NodeTypeFlags
    # a node with a single type [a] instead of a union [a, b]
    ONE = 1 << 0

    # a node with a basic shape, single constant, no children
    BASIC = 1 << 1

    # an Integer node
    INTEGER = 1 << 2

    # a Float node
    FLOAT = 1 << 3

    # a Boolean node
    BOOLEAN = 1 << 4

    # a String node
    STRING = 1 << 5

    # a Numeric subclass node
    NUMERIC = 1 << 6

    # a node that includes (or extends) Enumerable
    ENUMERABLE = 1 << 7
  end
end
@@ -0,0 +1,22 @@
1
+ # frozen_string_literal: true
2
+
3
module Semdiff
  # Helpers shared by the Semdiff compilers.
  module CompilerUtils
    # Makes +node+'s NEWLINE flag agree with +parent+'s, so a replacement
    # node carries the same newline marker as the node it replaces.
    #
    # NOTE: `flags` is protected on Prism nodes; reaching it through `send`
    #       seems preferable to reopening or subclassing the node classes.
    #
    # @param parent [Prism::Node] the node being replaced
    # @param node [Prism::Node] the replacement
    # @return [Prism::Node] +node+ itself when the flags already agree,
    #   otherwise a copy with NEWLINE set or cleared
    def self.inherit_newline(parent, node)
      wants_newline = parent.newline?
      return node if node.newline? == wants_newline

      current = node.send(:flags)
      newline_bit = Prism::NodeFlags::NEWLINE
      adjusted = wants_newline ? (current | newline_bit) : (current & ~newline_bit)
      node.copy(flags: adjusted)
    end
  end
end