kumi 0.0.20 → 0.0.22

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,234 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- module Kumi
4
- module Core
5
- module Analyzer
6
- module Passes
7
# RESPONSIBILITY: Infer types for all declarations based on expression analysis
# DEPENDENCIES: Toposorter (needs evaluation_order), DeclarationValidator (needs declarations)
# PRODUCES: inferred_types hash mapping declaration names to inferred types
# INTERFACE: new(schema, state).run(errors)
class TypeInferencerPass < PassBase
  # Arithmetic functions whose vectorized element type is unified from their operands.
  VECTORIZED_ARITHMETIC_FNS = %i[add subtract multiply divide].freeze

  # Infers a type for every declaration, visiting them in evaluation order.
  #
  # @param errors [Array] error sink (not written to by this pass at present)
  # @return [Object] the result of +state.with(:inferred_types, types)+
  def run(errors)
    types = {}
    topo_order = get_state(:evaluation_order)
    definitions = get_state(:declarations)

    # Broadcast metadata is optional; default to "nothing is vectorized".
    broadcast_meta = get_state(:broadcasts, required: false) || {}

    # Topological order ensures every referenced declaration already has an entry in `types`.
    topo_order.each do |name|
      decl = definitions[name]
      next unless decl

      types[name] =
        if broadcast_meta[:vectorized_operations]&.key?(name)
          # Vectorized declarations yield arrays; a vectorized trait is always an array of booleans.
          element_type = infer_vectorized_element_type(decl.expression, types, broadcast_meta)
          decl.is_a?(Kumi::Syntax::TraitDeclaration) ? { array: :boolean } : { array: element_type }
        else
          infer_expression_type(decl.expression, types, broadcast_meta, name)
        end
    end

    state.with(:inferred_types, types)
  end

  private

  # Dispatches on the syntax node class and returns its inferred type (:any when unknown).
  def infer_expression_type(expr, type_context = {}, broadcast_metadata = {}, current_decl_name = nil)
    case expr
    when Literal
      Types.infer_from_value(expr.value)
    when InputReference
      # Declared input types are read from the :input_metadata analyzer state.
      input_meta = get_state(:input_metadata, required: false) || {}
      input_meta[expr.name]&.dig(:type) || :any
    when DeclarationReference
      type_context[expr.name] || :any
    when CallExpression
      infer_call_type(expr, type_context, broadcast_metadata, current_decl_name)
    when ArrayExpression
      infer_list_type(expr, type_context, broadcast_metadata, current_decl_name)
    when CascadeExpression
      infer_cascade_type(expr, type_context, broadcast_metadata, current_decl_name)
    when InputElementReference
      infer_element_reference_type(expr)
    else
      :any
    end
  end

  # Infers the result type of a function call.
  #
  # Declarations flagged in broadcast metadata short-circuit. Vectorized values become
  # { array: element_type }. Reducers use the function's declared return type.
  # Otherwise the registry signature's :return_type is used. Unknown functions and arity
  # mismatches yield :any; reporting them is left to the TypeChecker.
  def infer_call_type(call_expr, type_context, broadcast_metadata = {}, current_decl_name = nil)
    fn_name = call_expr.fn_name
    args = call_expr.args

    if current_decl_name && broadcast_metadata[:vectorized_values]&.key?(current_decl_name)
      element_type = infer_vectorized_element_type(call_expr, type_context, broadcast_metadata)
      return { array: element_type }
    end

    if current_decl_name && broadcast_metadata[:reducer_values]&.key?(current_decl_name)
      return infer_function_return_type(fn_name, args, type_context, broadcast_metadata)
    end

    # Don't push an error here - the TypeChecker reports unknown functions.
    return :any unless Kumi::Registry.supported?(fn_name)

    signature = Kumi::Registry.signature(fn_name)

    # Only fixed (non-negative) arities are checked; mismatches are also left to the TypeChecker.
    return :any if signature[:arity] >= 0 && args.size != signature[:arity]

    # Argument types are inferred in preparation for parameter validation; mismatches are
    # intentionally not reported yet, so this section does not affect the result.
    arg_types = args.map { |arg| infer_expression_type(arg, type_context, broadcast_metadata, current_decl_name) }
    param_types = signature[:param_types] || []
    if signature[:arity] >= 0 && param_types.size.positive?
      arg_types.each_with_index do |_arg_type, i|
        expected_type = param_types[i] || param_types.last
        next if expected_type.nil?

        # TODO: warn when !Types.compatible?(arg_type, expected_type)
      end
    end

    signature[:return_type] || :any
  end

  # Infers the per-element type of a vectorized expression.
  #
  # Element references resolve to the field type declared under their array input.
  # Arithmetic calls unify their operands' element types, falling back to :float.
  # Anything else is :any.
  def infer_vectorized_element_type(expr, type_context, vectorization_meta)
    case expr
    when InputElementReference
      array_element_field_type(expr.path) || :any
    when CallExpression
      return :any unless VECTORIZED_ARITHMETIC_FNS.include?(expr.fn_name)

      operand_types = expr.args.map { |arg| vectorized_operand_type(arg, type_context, vectorization_meta) }
      # Unify pairwise, the same way infer_list_type / infer_cascade_type call Types.unify.
      # This also copes with 0, 1 or 3+ operands, which a splat into Types.unify did not.
      operand_types.reduce { |acc, type| Types.unify(acc, type) } || :float
    else
      :any
    end
  end

  # Element type contributed by one operand of a vectorized arithmetic call.
  def vectorized_operand_type(arg, type_context, vectorization_meta)
    case arg
    when InputElementReference
      infer_vectorized_element_type(arg, type_context, vectorization_meta)
    when DeclarationReference
      # Unwrap { array: T } so vectorized declarations contribute their element type.
      ref_type = type_context[arg.name]
      if ref_type.is_a?(Hash) && ref_type.key?(:array)
        ref_type[:array]
      else
        ref_type || :any
      end
    else
      infer_expression_type(arg, type_context, vectorization_meta)
    end
  end

  # Declared return type of a registered function, or :any when the function is unknown.
  def infer_function_return_type(fn_name, _args, _type_context, _broadcast_metadata)
    return :any unless Kumi::Registry.supported?(fn_name)

    signature = Kumi::Registry.signature(fn_name)
    signature[:return_type] || :any
  end

  # Array literal type: the unification of its element types, or array(:any) when empty
  # or when unification raises.
  def infer_list_type(list_expr, type_context, broadcast_metadata = {}, current_decl_name = nil)
    return Types.array(:any) if list_expr.elements.empty?

    element_types = list_expr.elements.map do |elem|
      infer_expression_type(elem, type_context, broadcast_metadata, current_decl_name)
    end

    unified_type = element_types.reduce { |acc, type| Types.unify(acc, type) }
    Types.array(unified_type)
  rescue StandardError
    # If unification fails, fall back to a generic array.
    Types.array(:any)
  end

  # Type of an element reference: { array: field_type } when path[0] names an array input,
  # otherwise :any.
  def infer_element_reference_type(expr)
    return :any unless expr.path.size >= 2

    field_type = array_element_field_type(expr.path)
    field_type ? { array: field_type } : :any
  end

  # Declared type of field path[1] inside array input path[0].
  # Returns nil when path[0] is not an array input, and :any when the field has no declared type.
  def array_element_field_type(path)
    input_meta = get_state(:input_metadata, required: false) || {}
    array_meta = input_meta[path.first]
    return nil unless array_meta&.dig(:type) == :array

    array_meta.dig(:children, path[1], :type) || :any
  end

  # Cascade type: the unification of all case result types (:any when empty or on failure).
  def infer_cascade_type(cascade_expr, type_context, broadcast_metadata = {}, current_decl_name = nil)
    return :any if cascade_expr.cases.empty?

    result_types = cascade_expr.cases.map do |case_stmt|
      infer_expression_type(case_stmt.result, type_context, broadcast_metadata, current_decl_name)
    end

    result_types.reduce { |unified, type| Types.unify(unified, type) } || :any
  rescue StandardError
    # TODO: decide whether a unification failure should fall back to :any or raise.
    :any
  end
end
231
- end
232
- end
233
- end
234
- end
@@ -1,137 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- module Kumi
4
- module Core
5
# Base compiler class with shared compilation logic between Ruby and JS compilers.
#
# Holds the syntax tree and analyzer result, and dispatches syntax nodes to the
# compile_* hooks that subclasses implement. It also exposes lookups into the
# analyzer's :broadcasts metadata for the declaration currently being compiled.
# NOTE(review): @current_declaration is read here but never assigned in this class;
# presumably subclasses set it before compiling each declaration - confirm.
class CompilerBase
  # Map node classes to compiler methods
  DISPATCH = {
    Kumi::Syntax::Literal => :compile_literal,
    Kumi::Syntax::InputReference => :compile_field_node,
    Kumi::Syntax::InputElementReference => :compile_element_field_reference,
    Kumi::Syntax::DeclarationReference => :compile_binding_node,
    Kumi::Syntax::ArrayExpression => :compile_list,
    Kumi::Syntax::CallExpression => :compile_call,
    Kumi::Syntax::CascadeExpression => :compile_cascade
  }.freeze

  # @param syntax_tree [Object] schema root; #build_index calls #values and #traits on it
  # @param analyzer_result [Object] must respond to #state
  def initialize(syntax_tree, analyzer_result)
    @schema = syntax_tree
    @analysis = analyzer_result
  end

  # Builds @index, mapping each value and trait declaration name to its node.
  def build_index
    @index = {}
    @schema.values.each { |a| @index[a.name] = a }
    @schema.traits.each { |t| @index[t.name] = t }
  end

  # Pre-computed operation mode for the current declaration; defaults to :broadcast.
  def determine_operation_mode_for_path(_path)
    broadcast_compilation_meta&.dig(:operation_mode) || :broadcast
  end

  # Whether the call +expr+ should be compiled as a vectorized operation, based on the
  # analyzer's pre-computed metadata for the current declaration.
  def vectorized_operation?(expr)
    compilation_meta = broadcast_compilation_meta
    return false unless compilation_meta

    if compilation_meta[:is_vectorized]
      # compilation_meta came from :broadcasts, so that hash is known to be present here.
      vectorized_ops = @analysis.state[:broadcasts][:vectorized_operations] || {}
      current_decl_info = vectorized_ops[@current_declaration]

      # For cascade declarations, check individual operations within them
      return true if current_decl_info && current_decl_info[:operation] == expr.fn_name

      # For cascade_with_vectorized_conditions_or_results, allow nested operations
      return true if current_decl_info && current_decl_info[:source] == :cascade_with_vectorized_conditions_or_results

      # Check if this is a direct vectorized operation
      return true if current_decl_info && current_decl_info[:operation]
    end

    # Reduction functions consume arrays, so they are never vectorized operations.
    return false if Kumi::Registry.reducer?(expr.fn_name)

    compilation_meta.dig(:vectorization_context, :needs_broadcasting) || false
  end

  # Truthy when the current declaration is a cascade with vectorized conditions or results.
  def is_cascade_vectorized?(_expr)
    broadcast_meta = @analysis.state[:broadcasts]
    cascade_info = @current_declaration && broadcast_meta&.dig(:vectorized_operations, @current_declaration)
    cascade_info && cascade_info[:source] == :cascade_with_vectorized_conditions_or_results
  end

  # @return [Array(Object, Hash)] the compilation metadata (may be nil) and its :cascade_info ({} if absent)
  def get_cascade_compilation_metadata
    compilation_meta = broadcast_compilation_meta
    cascade_info = compilation_meta&.dig(:cascade_info) || {}
    [compilation_meta, cascade_info]
  end

  # Cascade strategy for the current declaration, or nil when no strategy was recorded.
  # Uses &.dig like the other accessors instead of raising NoMethodError on missing metadata.
  def get_cascade_strategy
    @analysis.state[:broadcasts]&.dig(:cascade_strategies, @current_declaration)
  end

  # Function-call strategy hash for the current declaration ({} when absent).
  def get_function_call_strategy
    broadcast_compilation_meta&.dig(:function_call_strategy) || {}
  end

  # Truthy when the function-call strategy requires flattening.
  def needs_flattening?
    function_strategy = get_function_call_strategy
    function_strategy[:flattening_required]
  end

  # Flattening info for the current declaration, or nil when none was recorded.
  # Uses &.dig like the other accessors instead of raising NoMethodError on missing metadata.
  def get_flattening_info
    @analysis.state[:broadcasts]&.dig(:flattening_declarations, @current_declaration)
  end

  # Indices of call arguments that must be flattened ([] when absent).
  def get_flatten_argument_indices
    broadcast_compilation_meta&.dig(:function_call_strategy, :flatten_argument_indices) || []
  end

  # Dispatch to the appropriate compile_* method.
  # @raise [KeyError] when the node class has no DISPATCH entry
  def compile_expr(expr)
    method = DISPATCH.fetch(expr.class)
    send(method, expr)
  end

  # Abstract methods to be implemented by subclasses
  def compile_literal(expr)
    raise NotImplementedError, "Subclasses must implement compile_literal"
  end

  def compile_field_node(expr)
    raise NotImplementedError, "Subclasses must implement compile_field_node"
  end

  def compile_element_field_reference(expr)
    raise NotImplementedError, "Subclasses must implement compile_element_field_reference"
  end

  def compile_binding_node(expr)
    raise NotImplementedError, "Subclasses must implement compile_binding_node"
  end

  def compile_list(expr)
    raise NotImplementedError, "Subclasses must implement compile_list"
  end

  def compile_call(expr)
    raise NotImplementedError, "Subclasses must implement compile_call"
  end

  def compile_cascade(expr)
    raise NotImplementedError, "Subclasses must implement compile_cascade"
  end

  private

  # Compilation metadata recorded for @current_declaration, or nil when absent.
  def broadcast_compilation_meta
    @analysis.state[:broadcasts]&.dig(:compilation_metadata, @current_declaration)
  end
end
136
- end
137
- end
@@ -1,254 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- module Kumi
4
- module Core
5
- module Explain
6
# Renders a human-readable explanation of one declaration. The output shows its
# expression, the evaluated values of its non-literal operands, and the declaration's
# final value.
class ExplanationGenerator
  # Operators rendered infix (or prefix, for :not) instead of as fn(args).
  PRETTY_OPS = %i[add subtract multiply divide == != > < >= <= and or not].freeze

  # Operators whose nested applications may be flattened into one chain "x op y op z".
  # Only associative operators qualify. Flattening subtract/divide would print
  # a - (b - c) as a - b - c, with the wrong intermediate values.
  CHAINABLE_OPS = %i[add multiply and or].freeze

  # Infix symbols for flattened chains; kept consistent with #display_fmt.
  OP_SYMBOLS = { add: "+", subtract: "-", multiply: "×", divide: "÷", and: "&&", or: "||" }.freeze

  # @param syntax_tree [Object] the schema's syntax tree (stored; not otherwise used in this class)
  # @param analysis_state [Object] analyzer state; its :declarations entry maps names to declaration nodes
  # @param inputs [Hash] input values; symbol or string keys are both accepted
  # @param registry [#entry] function registry used to evaluate call sub-expressions
  def initialize(syntax_tree, analysis_state, inputs, registry: Kumi::Registry)
    @syntax_tree = syntax_tree
    @state = analysis_state
    @inputs = inputs
    @definitions = analysis_state[:declarations] || {}
    @registry = registry

    # NOTE(review): the executable is built with registry: nil rather than @registry - confirm intended.
    @program = Kumi::Runtime::Executable.from_analysis(@state, registry: nil)
    @session = @program.read(@inputs, mode: :ruby)
  end

  # Explains +target_name+ as "name = <expression with evaluated operands> => <value>".
  #
  # @raise [ArgumentError] when +target_name+ is not a known declaration
  def explain(target_name)
    decl = @definitions[target_name] or raise ArgumentError, "Unknown declaration: #{target_name}"
    expr = decl.expression
    value = @session.get(target_name)

    prefix = "#{target_name} = "
    # Multi-line call arguments are aligned relative to the end of the prefix.
    expr_str = format_expression(expr, indent_context: prefix.length)

    "#{prefix}#{expr_str} => #{format_value(value)}"
  end

  private

  # ---------- formatting ----------

  # Source-like rendering of +expr+. +nested+ suppresses the "= evaluated" suffix for sub-expressions.
  def format_expression(expr, indent_context: 0, nested: false)
    case expr
    when Kumi::Syntax::InputReference
      "input.#{expr.name}"
    when Kumi::Syntax::InputElementReference
      "input.#{expr.path.join('.')}"
    when Kumi::Syntax::DeclarationReference
      expr.name.to_s
    when Kumi::Syntax::Literal
      format_value(expr.value)
    when Kumi::Syntax::ArrayExpression
      "[#{expr.elements.map { |e| format_expression(e, indent_context:, nested:) }.join(', ')}]"
    when Kumi::Syntax::CascadeExpression
      format_cascade(expr, indent_context:)
    when Kumi::Syntax::CallExpression
      format_call(expr, indent_context:, nested:)
    else
      expr.class.name.split("::").last
    end
  end

  def format_call(expr, indent_context:, nested:)
    fn = expr.fn_name
    if pretty_print?(fn)
      format_pretty(expr, fn, indent_context:, nested:)
    else
      format_generic(expr, indent_context:)
    end
  end

  def pretty_print?(fn)
    PRETTY_OPS.include?(fn)
  end

  # Operator rendering. At the top level, when any operand needs evaluation, this appends
  # "= <operands replaced by their values>".
  def format_pretty(expr, fn, indent_context:, nested:)
    if needs_eval?(expr.args) && !nested
      if chain_of_same_op?(expr, fn)
        ops = flatten_chain(expr, fn)
        sym = op_symbol(fn)
        sym_args = ops.map { |a| format_expression(a, indent_context:, nested: true) }
        eval_args = ops.map { |a| eval_arg_for_display(a) }
        "#{sym_args.join(" #{sym} ")} = #{eval_args.join(" #{sym} ")}"
      else
        sym_args = expr.args.map { |a| format_expression(a, indent_context:, nested: true) }
        eval_args = expr.args.map { |a| eval_arg_for_display(a) }
        display_fmt(fn, sym_args) + " = " + display_fmt(fn, eval_args)
      end
    else
      display_fmt(fn, expr.args.map { |a| format_expression(a, indent_context:, nested: true) })
    end
  end

  # fn(arg = value, ...) rendering for non-operator calls; multiple args go on aligned lines.
  def format_generic(expr, indent_context:)
    parts = expr.args.map do |a|
      desc = format_expression(a, indent_context:)
      if literalish?(a)
        desc
      else
        val = evaluate(a)
        "#{desc} = #{format_value(val)}"
      end
    end
    if parts.length > 1
      indent = " " * (indent_context + expr.fn_name.to_s.length + 1)
      "#{expr.fn_name}(#{parts.join(",\n#{indent}")})"
    else
      "#{expr.fn_name}(#{parts.join(', ')})"
    end
  end

  # Lists cascade cases up to and including the first one whose condition holds (✓).
  def format_cascade(expr, indent_context:)
    lines = []
    expr.cases.each do |c|
      cond_val = evaluate(c.condition)
      cond_desc = format_expression(c.condition, indent_context:)
      res_desc = format_expression(c.result, indent_context:)
      lines << " #{cond_val ? '✓' : '✗'} on #{cond_desc}, #{res_desc}"
      break if cond_val
    end
    "\n" + lines.join("\n")
  end

  def literalish?(expr)
    expr.is_a?(Kumi::Syntax::Literal) ||
      (expr.is_a?(Kumi::Syntax::ArrayExpression) && expr.elements.all?(Kumi::Syntax::Literal))
  end

  def needs_eval?(args)
    args.any? { |a| !literalish?(a) }
  end

  # True only for associative operators that directly nest another application of themselves.
  def chain_of_same_op?(expr, fn)
    CHAINABLE_OPS.include?(fn) &&
      expr.args.any? { |a| a.is_a?(Kumi::Syntax::CallExpression) && a.fn_name == fn }
  end

  def flatten_chain(expr, fn)
    expr.args.flat_map do |a|
      a.is_a?(Kumi::Syntax::CallExpression) && a.fn_name == fn ? flatten_chain(a, fn) : [a]
    end
  end

  def op_symbol(fn)
    OP_SYMBOLS.fetch(fn) { fn.to_s }
  end

  def display_fmt(fn, args)
    case fn
    when :add then args.join(" + ")
    when :subtract then args.join(" - ")
    when :multiply then args.join(" × ")
    when :divide then args.join(" ÷ ")
    when :== then "#{args[0]} == #{args[1]}"
    when :!= then "#{args[0]} != #{args[1]}"
    when :> then "#{args[0]} > #{args[1]}"
    when :< then "#{args[0]} < #{args[1]}"
    when :>= then "#{args[0]} >= #{args[1]}"
    when :<= then "#{args[0]} <= #{args[1]}"
    when :and then args.join(" && ")
    when :or then args.join(" || ")
    when :not then "!#{args[0]}"
    else "#{fn}(#{args.join(', ')})"
    end
  end

  # Literals render as themselves; declaration references as "(name = value)"; others as their value.
  def eval_arg_for_display(arg)
    return format_expression(arg, indent_context: 0, nested: true) if literalish?(arg)

    val = evaluate(arg)
    if arg.is_a?(Kumi::Syntax::DeclarationReference)
      "(#{format_expression(arg, indent_context: 0, nested: true)} = #{format_value(val)})"
    else
      format_value(val)
    end
  end

  # Display form of a runtime value; arrays longer than 4 elements are truncated with "…".
  def format_value(v)
    case v
    when Float, Integer
      format_number(v)
    when String
      "\"#{v}\""
    when Array
      shown = v.take(4).map { |x| format_value(x) }
      shown << "…" if v.length > 4
      "[#{shown.join(', ')}]"
    else
      v.to_s
    end
  end

  # Whole numbers with |n| >= 1000 get space-separated thousands groups ("1 234 567").
  def format_number(n)
    return n.to_s unless n.is_a?(Numeric)
    # Infinity/NaN have no integer form (Float#to_i raises FloatDomainError), so print them as-is.
    return n.to_s if n.respond_to?(:finite?) && !n.finite?

    i = n.is_a?(Integer) || n == n.to_i ? n.to_i : nil
    return n.to_s unless i

    i.abs >= 1000 ? i.to_s.reverse.gsub(/(\d{3})(?=\d)/, '\\1 ').reverse : i.to_s
  end

  # ---------- evaluation (Program + Registry) ----------

  # Evaluates +expr+ against the session (declarations), raw inputs, and registry functions.
  # @raise [RuntimeError] for unsupported node classes or unknown functions
  def evaluate(expr)
    case expr
    when Kumi::Syntax::DeclarationReference
      @session.get(expr.name)
    when Kumi::Syntax::InputReference
      fetch_indifferent(@inputs, expr.name)
    when Kumi::Syntax::InputElementReference
      dig_path(@inputs, expr.path)
    when Kumi::Syntax::Literal
      expr.value
    when Kumi::Syntax::ArrayExpression
      expr.elements.map { |e| evaluate(e) }
    when Kumi::Syntax::CascadeExpression
      evaluate_cascade(expr)
    when Kumi::Syntax::CallExpression
      eval_call(expr)
    else
      raise "Unsupported expression: #{expr.class}"
    end
  end

  def eval_call(expr)
    entry = @registry.entry(expr.fn_name) or raise "Unknown function: #{expr.fn_name}"
    fn = entry.fn
    args = expr.args.map { |a| evaluate(a) }
    fn.call(*args)
  end

  # Result of the first case whose condition is truthy; nil when none match.
  def evaluate_cascade(expr)
    expr.cases.each do |c|
      return evaluate(c.result) if evaluate(c.condition)
    end
    nil
  end

  # Looks up +k+ as given, then as a String, then as a Symbol.
  # Presence is tested with key?, so stored false/nil values are returned as-is
  # instead of falling through to the next spelling.
  def fetch_indifferent(h, k)
    candidates = [k, k.to_s]
    # Integers (e.g. array-index segments) have no to_sym.
    candidates << k.to_sym if k.respond_to?(:to_sym)

    # Containers without key?: fall back to truthiness-based lookup.
    return candidates.reduce(nil) { |found, c| found || h[c] } unless h.respond_to?(:key?)

    candidates.each { |c| return h[c] if h.key?(c) }
    nil
  end

  # Follows +path+ through nested inputs. Hash levels use indifferent key lookup.
  # Other indexable levels accept only Integer segments. Any other step (including a
  # nil intermediate) yields nil instead of raising NoMethodError.
  def dig_path(h, path)
    path.reduce(h) do |node, seg|
      if node.is_a?(Hash)
        fetch_indifferent(node, seg)
      elsif seg.is_a?(Integer) && node.respond_to?(:[])
        node[seg]
      end
    end
  end
end
242
-
243
module_function

# Explains how +target_name+ is computed for +inputs+ in a compiled schema class.
#
# The schema class must carry the @__syntax_tree__ and @__analyzer_result__ instance
# variables. The latter must respond to #state.
#
# @return [String] the explanation produced by ExplanationGenerator#explain
# @raise [ArgumentError] when either the syntax tree or the analysis state is missing
def call(schema_class, target_name, inputs:)
  tree = schema_class.instance_variable_get(:@__syntax_tree__)
  state = schema_class.instance_variable_get(:@__analyzer_result__)&.state

  if tree && state
    ExplanationGenerator.new(tree, state, inputs).explain(target_name)
  else
    raise ArgumentError, "Schema not found or not compiled"
  end
end
252
- end
253
- end
254
- end