kumi 0.0.5 → 0.0.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. checksums.yaml +4 -4
  2. data/CLAUDE.md +51 -6
  3. data/README.md +173 -51
  4. data/{documents → docs}/AST.md +29 -29
  5. data/{documents → docs}/SYNTAX.md +93 -1
  6. data/docs/features/README.md +45 -0
  7. data/docs/features/analysis-cascade-mutual-exclusion.md +89 -0
  8. data/docs/features/analysis-type-inference.md +42 -0
  9. data/docs/features/analysis-unsat-detection.md +71 -0
  10. data/docs/features/array-broadcasting.md +170 -0
  11. data/docs/features/input-declaration-system.md +42 -0
  12. data/docs/features/performance.md +16 -0
  13. data/examples/federal_tax_calculator_2024.rb +11 -6
  14. data/lib/kumi/analyzer/constant_evaluator.rb +1 -1
  15. data/lib/kumi/analyzer/passes/broadcast_detector.rb +251 -0
  16. data/lib/kumi/analyzer/passes/{definition_validator.rb → declaration_validator.rb} +4 -4
  17. data/lib/kumi/analyzer/passes/dependency_resolver.rb +72 -32
  18. data/lib/kumi/analyzer/passes/input_collector.rb +90 -29
  19. data/lib/kumi/analyzer/passes/pass_base.rb +1 -1
  20. data/lib/kumi/analyzer/passes/semantic_constraint_validator.rb +9 -9
  21. data/lib/kumi/analyzer/passes/toposorter.rb +42 -6
  22. data/lib/kumi/analyzer/passes/type_checker.rb +32 -10
  23. data/lib/kumi/analyzer/passes/type_inferencer.rb +126 -17
  24. data/lib/kumi/analyzer/passes/unsat_detector.rb +133 -53
  25. data/lib/kumi/analyzer/passes/visitor_pass.rb +2 -2
  26. data/lib/kumi/analyzer.rb +11 -12
  27. data/lib/kumi/compiler.rb +194 -16
  28. data/lib/kumi/constraint_relationship_solver.rb +6 -6
  29. data/lib/kumi/domain/validator.rb +0 -4
  30. data/lib/kumi/explain.rb +20 -20
  31. data/lib/kumi/export/node_registry.rb +26 -12
  32. data/lib/kumi/export/node_serializers.rb +1 -1
  33. data/lib/kumi/function_registry/collection_functions.rb +14 -9
  34. data/lib/kumi/function_registry/function_builder.rb +4 -3
  35. data/lib/kumi/function_registry.rb +8 -2
  36. data/lib/kumi/input/type_matcher.rb +3 -0
  37. data/lib/kumi/input/validator.rb +0 -3
  38. data/lib/kumi/parser/declaration_reference_proxy.rb +36 -0
  39. data/lib/kumi/parser/dsl_cascade_builder.rb +3 -3
  40. data/lib/kumi/parser/expression_converter.rb +6 -6
  41. data/lib/kumi/parser/input_builder.rb +40 -9
  42. data/lib/kumi/parser/input_field_proxy.rb +46 -0
  43. data/lib/kumi/parser/input_proxy.rb +3 -3
  44. data/lib/kumi/parser/nested_input.rb +15 -0
  45. data/lib/kumi/parser/schema_builder.rb +10 -9
  46. data/lib/kumi/parser/sugar.rb +61 -9
  47. data/lib/kumi/syntax/array_expression.rb +15 -0
  48. data/lib/kumi/syntax/call_expression.rb +11 -0
  49. data/lib/kumi/syntax/cascade_expression.rb +11 -0
  50. data/lib/kumi/syntax/case_expression.rb +11 -0
  51. data/lib/kumi/syntax/declaration_reference.rb +11 -0
  52. data/lib/kumi/syntax/hash_expression.rb +11 -0
  53. data/lib/kumi/syntax/input_declaration.rb +12 -0
  54. data/lib/kumi/syntax/input_element_reference.rb +12 -0
  55. data/lib/kumi/syntax/input_reference.rb +12 -0
  56. data/lib/kumi/syntax/literal.rb +11 -0
  57. data/lib/kumi/syntax/trait_declaration.rb +11 -0
  58. data/lib/kumi/syntax/value_declaration.rb +11 -0
  59. data/lib/kumi/vectorization_metadata.rb +108 -0
  60. data/lib/kumi/version.rb +1 -1
  61. metadata +31 -14
  62. data/lib/kumi/domain.rb +0 -8
  63. data/lib/kumi/input.rb +0 -8
  64. data/lib/kumi/syntax/declarations.rb +0 -26
  65. data/lib/kumi/syntax/expressions.rb +0 -34
  66. data/lib/kumi/syntax/terminal_expressions.rb +0 -30
  67. data/lib/kumi/syntax.rb +0 -9
  68. /data/{documents → docs}/DSL.md +0 -0
  69. /data/{documents → docs}/FUNCTIONS.md +0 -0
@@ -3,8 +3,8 @@
3
3
  module Kumi
4
4
  module Analyzer
5
5
  module Passes
6
- # RESPONSIBILITY: Compute topological ordering of declarations from dependency graph
7
- # DEPENDENCIES: :dependency_graph from DependencyResolver, :definitions from NameIndexer
6
+ # RESPONSIBILITY: Compute topological ordering of declarations, allowing safe conditional cycles
7
+ # DEPENDENCIES: :dependency_graph from DependencyResolver, :definitions from NameIndexer, :cascade_metadata from UnsatDetector
8
8
  # PRODUCES: :topo_order - Array of declaration names in evaluation order
9
9
  # INTERFACE: new(schema, state).run(errors)
10
10
  class Toposorter < PassBase
@@ -22,17 +22,26 @@ module Kumi
22
22
  temp_marks = Set.new
23
23
  perm_marks = Set.new
24
24
  order = []
25
+ cascade_metadata = get_state(:cascade_metadata) || {}
25
26
 
26
- visit_node = lambda do |node|
27
+ visit_node = lambda do |node, path = []|
27
28
  return if perm_marks.include?(node)
28
29
 
29
30
  if temp_marks.include?(node)
30
- report_unexpected_cycle(temp_marks, node, errors)
31
- return
31
+ # Check if this is a safe conditional cycle
32
+ cycle_path = path + [node]
33
+ if safe_conditional_cycle?(cycle_path, graph, cascade_metadata)
34
+ # Allow this cycle - it's safe due to cascade mutual exclusion
35
+ return
36
+ else
37
+ report_unexpected_cycle(temp_marks, node, errors)
38
+ return
39
+ end
32
40
  end
33
41
 
34
42
  temp_marks << node
35
- Array(graph[node]).each { |edge| visit_node.call(edge.to) }
43
+ current_path = path + [node]
44
+ Array(graph[node]).each { |edge| visit_node.call(edge.to, current_path) }
36
45
  temp_marks.delete(node)
37
46
  perm_marks << node
38
47
 
@@ -50,6 +59,33 @@ module Kumi
50
59
  order.freeze
51
60
  end
52
61
 
62
+ def safe_conditional_cycle?(cycle_path, graph, cascade_metadata)
63
+ return false if cycle_path.nil? || cycle_path.size < 2
64
+
65
+ # Find where the cycle starts - look for the first occurrence of the repeated node
66
+ last_node = cycle_path.last
67
+ return false if last_node.nil?
68
+
69
+ cycle_start = cycle_path.index(last_node)
70
+ return false unless cycle_start && cycle_start < cycle_path.size - 1
71
+
72
+ cycle_nodes = cycle_path[cycle_start..-1]
73
+
74
+ # Check if all edges in the cycle are conditional
75
+ cycle_nodes.each_cons(2) do |from, to|
76
+ edges = graph[from] || []
77
+ edge = edges.find { |e| e.to == to }
78
+
79
+ return false unless edge&.conditional
80
+
81
+ # Check if the cascade has mutually exclusive conditions
82
+ cascade_meta = cascade_metadata[edge.cascade_owner]
83
+ return false unless cascade_meta&.dig(:all_mutually_exclusive)
84
+ end
85
+
86
+ true
87
+ end
88
+
53
89
  def report_unexpected_cycle(temp_marks, current_node, errors)
54
90
  cycle_path = temp_marks.to_a.join(" → ") + " → #{current_node}"
55
91
 
@@ -9,7 +9,7 @@ module Kumi
9
9
  # INTERFACE: new(schema, state).run(errors)
10
10
  class TypeChecker < VisitorPass
11
11
  def run(errors)
12
- visit_nodes_of_type(Expressions::CallExpression, errors: errors) do |node, _decl, errs|
12
+ visit_nodes_of_type(Kumi::Syntax::CallExpression, errors: errors) do |node, _decl, errs|
13
13
  validate_function_call(node, errs)
14
14
  end
15
15
  state
@@ -47,11 +47,33 @@ module Kumi
47
47
  types = signature[:param_types]
48
48
  return if types.nil? || (signature[:arity].negative? && node.args.empty?)
49
49
 
50
+ # Skip type checking for vectorized operations
51
+ broadcast_meta = get_state(:broadcast_metadata, required: false)
52
+ if broadcast_meta && is_part_of_vectorized_operation?(node, broadcast_meta)
53
+ return
54
+ end
55
+
50
56
  node.args.each_with_index do |arg, i|
51
57
  validate_argument_type(arg, i, types[i], node.fn_name, errors)
52
58
  end
53
59
  end
54
60
 
61
+ def is_part_of_vectorized_operation?(node, broadcast_meta)
62
+ # Check if this node is part of a vectorized or reduction operation
63
+ # This is a simplified check - in a real implementation we'd need to track context
64
+ node.args.any? do |arg|
65
+ case arg
66
+ when Kumi::Syntax::DeclarationReference
67
+ broadcast_meta[:vectorized_operations]&.key?(arg.name) ||
68
+ broadcast_meta[:reduction_operations]&.key?(arg.name)
69
+ when Kumi::Syntax::InputElementReference
70
+ broadcast_meta[:array_fields]&.key?(arg.path.first)
71
+ else
72
+ false
73
+ end
74
+ end
75
+ end
76
+
55
77
  def validate_argument_type(arg, index, expected_type, fn_name, errors)
56
78
  return if expected_type.nil? || expected_type == Kumi::Types::ANY
57
79
 
@@ -67,15 +89,15 @@ module Kumi
67
89
 
68
90
  def get_expression_type(expr)
69
91
  case expr
70
- when TerminalExpressions::Literal
92
+ when Kumi::Syntax::Literal
71
93
  # Inferred type from literal value
72
94
  Kumi::Types.infer_from_value(expr.value)
73
95
 
74
- when TerminalExpressions::FieldRef
96
+ when Kumi::Syntax::InputReference
75
97
  # Declared type from input block (user-specified)
76
98
  get_declared_field_type(expr.name)
77
99
 
78
- when TerminalExpressions::Binding
100
+ when Kumi::Syntax::DeclarationReference
79
101
  # Inferred type from type inference results
80
102
  get_inferred_declaration_type(expr.name)
81
103
 
@@ -101,10 +123,10 @@ module Kumi
101
123
 
102
124
  def describe_expression_type(expr, type)
103
125
  case expr
104
- when TerminalExpressions::Literal
126
+ when Kumi::Syntax::Literal
105
127
  "`#{expr.value}` of type #{type} (literal value)"
106
128
 
107
- when TerminalExpressions::FieldRef
129
+ when Kumi::Syntax::InputReference
108
130
  input_meta = get_state(:input_meta, required: false) || {}
109
131
  field_meta = input_meta[expr.name]
110
132
 
@@ -117,17 +139,17 @@ module Kumi
117
139
  "undeclared input field `#{expr.name}` (inferred as #{type})"
118
140
  end
119
141
 
120
- when TerminalExpressions::Binding
142
+ when Kumi::Syntax::DeclarationReference
121
143
  # This type was inferred from the declaration's expression
122
144
  "reference to declaration `#{expr.name}` of inferred type #{type}"
123
145
 
124
- when Expressions::CallExpression
146
+ when Kumi::Syntax::CallExpression
125
147
  "result of function `#{expr.fn_name}` returning #{type}"
126
148
 
127
- when Expressions::ListExpression
149
+ when Kumi::Syntax::ArrayExpression
128
150
  "list expression of type #{type}"
129
151
 
130
- when Expressions::CascadeExpression
152
+ when Kumi::Syntax::CascadeExpression
131
153
  "cascade expression of type #{type}"
132
154
 
133
155
  else
@@ -4,7 +4,7 @@ module Kumi
4
4
  module Analyzer
5
5
  module Passes
6
6
  # RESPONSIBILITY: Infer types for all declarations based on expression analysis
7
- # DEPENDENCIES: Toposorter (needs topo_order), DefinitionValidator (needs definitions)
7
+ # DEPENDENCIES: Toposorter (needs topo_order), DeclarationValidator (needs definitions)
8
8
  # PRODUCES: decl_types hash mapping declaration names to inferred types
9
9
  # INTERFACE: new(schema, state).run(errors)
10
10
  class TypeInferencer < PassBase
@@ -12,6 +12,9 @@ module Kumi
12
12
  types = {}
13
13
  topo_order = get_state(:topo_order)
14
14
  definitions = get_state(:definitions)
15
+
16
+ # Get broadcast metadata from broadcast detector
17
+ broadcast_meta = get_state(:broadcast_metadata, required: false) || {}
15
18
 
16
19
  # Process declarations in topological order to ensure dependencies are resolved
17
20
  topo_order.each do |name|
@@ -19,8 +22,16 @@ module Kumi
19
22
  next unless decl
20
23
 
21
24
  begin
22
- inferred_type = infer_expression_type(decl.expression, types)
23
- types[name] = inferred_type
25
+ # Check if this declaration is marked as vectorized
26
+ if broadcast_meta[:vectorized_operations]&.key?(name)
27
+ # Infer the element type and wrap in array
28
+ element_type = infer_vectorized_element_type(decl.expression, types, broadcast_meta)
29
+ types[name] = decl.is_a?(Kumi::Syntax::TraitDeclaration) ? { array: :boolean } : { array: element_type }
30
+ else
31
+ # Normal type inference
32
+ inferred_type = infer_expression_type(decl.expression, types, broadcast_meta, name)
33
+ types[name] = inferred_type
34
+ end
24
35
  rescue StandardError => e
25
36
  report_type_error(errors, "Type inference failed: #{e.message}", location: decl&.loc)
26
37
  end
@@ -31,32 +42,47 @@ module Kumi
31
42
 
32
43
  private
33
44
 
34
- def infer_expression_type(expr, type_context = {})
45
+ def infer_expression_type(expr, type_context = {}, broadcast_metadata = {}, current_decl_name = nil)
35
46
  case expr
36
47
  when Literal
37
48
  Types.infer_from_value(expr.value)
38
- when FieldRef
49
+ when InputReference
39
50
  # Look up type from field metadata
40
51
  input_meta = get_state(:input_meta, required: false) || {}
41
52
  meta = input_meta[expr.name]
42
53
  meta&.dig(:type) || :any
43
- when Binding
54
+ when DeclarationReference
44
55
  type_context[expr.name] || :any
45
56
  when CallExpression
46
- infer_call_type(expr, type_context)
47
- when ListExpression
48
- infer_list_type(expr, type_context)
57
+ infer_call_type(expr, type_context, broadcast_metadata, current_decl_name)
58
+ when ArrayExpression
59
+ infer_list_type(expr, type_context, broadcast_metadata, current_decl_name)
49
60
  when CascadeExpression
50
- infer_cascade_type(expr, type_context)
61
+ infer_cascade_type(expr, type_context, broadcast_metadata, current_decl_name)
62
+ when InputElementReference
63
+ # Element reference returns the field type
64
+ infer_element_reference_type(expr)
51
65
  else
52
66
  :any
53
67
  end
54
68
  end
55
69
 
56
- def infer_call_type(call_expr, type_context)
57
- fn_name = call_expr.fn_name
70
+ def infer_call_type(call_expr, type_context, broadcast_metadata = {}, current_decl_name = nil)
71
+ fn_name = call_expr.fn_name
58
72
  args = call_expr.args
59
73
 
74
+ # Check broadcast metadata first
75
+ if current_decl_name && broadcast_metadata[:vectorized_values]&.key?(current_decl_name)
76
+ # This declaration is marked as vectorized, so it produces an array
77
+ element_type = infer_vectorized_element_type(call_expr, type_context, broadcast_metadata)
78
+ return { array: element_type }
79
+ end
80
+
81
+ if current_decl_name && broadcast_metadata[:reducer_values]&.key?(current_decl_name)
82
+ # This declaration is marked as a reducer, get the result from the function
83
+ return infer_function_return_type(fn_name, args, type_context, broadcast_metadata)
84
+ end
85
+
60
86
  # Check if function exists in registry
61
87
  unless FunctionRegistry.supported?(fn_name)
62
88
  # Don't push error here - let existing TypeChecker handle it
@@ -72,7 +98,7 @@ module Kumi
72
98
  end
73
99
 
74
100
  # Infer argument types
75
- arg_types = args.map { |arg| infer_expression_type(arg, type_context) }
101
+ arg_types = args.map { |arg| infer_expression_type(arg, type_context, broadcast_metadata, current_decl_name) }
76
102
 
77
103
  # Validate parameter types (warn but don't fail)
78
104
  param_types = signature[:param_types] || []
@@ -90,10 +116,29 @@ module Kumi
90
116
  signature[:return_type] || :any
91
117
  end
92
118
 
93
- def infer_list_type(list_expr, type_context)
119
+ def infer_vectorized_element_type(call_expr, type_context, broadcast_metadata)
120
+ # For vectorized arithmetic operations, infer the element type
121
+ # For now, assume arithmetic operations on floats produce floats
122
+ case call_expr.fn_name
123
+ when :multiply, :add, :subtract, :divide
124
+ :float
125
+ else
126
+ :any
127
+ end
128
+ end
129
+
130
+ def infer_function_return_type(fn_name, args, type_context, broadcast_metadata)
131
+ # Get the function signature
132
+ return :any unless FunctionRegistry.supported?(fn_name)
133
+
134
+ signature = FunctionRegistry.signature(fn_name)
135
+ signature[:return_type] || :any
136
+ end
137
+
138
+ def infer_list_type(list_expr, type_context, broadcast_metadata = {}, current_decl_name = nil)
94
139
  return Types.array(:any) if list_expr.elements.empty?
95
140
 
96
- element_types = list_expr.elements.map { |elem| infer_expression_type(elem, type_context) }
141
+ element_types = list_expr.elements.map { |elem| infer_expression_type(elem, type_context, broadcast_metadata, current_decl_name) }
97
142
 
98
143
  # Try to unify all element types
99
144
  unified_type = element_types.reduce { |acc, type| Types.unify(acc, type) }
@@ -103,11 +148,75 @@ module Kumi
103
148
  Types.array(:any)
104
149
  end
105
150
 
106
- def infer_cascade_type(cascade_expr, type_context)
151
+ def infer_vectorized_element_type(expr, type_context, vectorization_meta)
152
+ # For vectorized operations, we need to infer the element type
153
+ case expr
154
+ when InputElementReference
155
+ # Get the field type from metadata
156
+ input_meta = get_state(:input_meta, required: false) || {}
157
+ array_name = expr.path.first
158
+ field_name = expr.path[1]
159
+
160
+ array_meta = input_meta[array_name]
161
+ return :any unless array_meta&.dig(:type) == :array
162
+
163
+ array_meta.dig(:children, field_name, :type) || :any
164
+
165
+ when CallExpression
166
+ # For arithmetic operations, infer from operands
167
+ if %i[add subtract multiply divide].include?(expr.fn_name)
168
+ # Get types of operands
169
+ arg_types = expr.args.map do |arg|
170
+ if arg.is_a?(InputElementReference)
171
+ infer_vectorized_element_type(arg, type_context, vectorization_meta)
172
+ elsif arg.is_a?(DeclarationReference)
173
+ # Get the element type if it's vectorized
174
+ ref_type = type_context[arg.name]
175
+ if ref_type.is_a?(Hash) && ref_type.key?(:array)
176
+ ref_type[:array]
177
+ else
178
+ ref_type || :any
179
+ end
180
+ else
181
+ infer_expression_type(arg, type_context, vectorization_meta)
182
+ end
183
+ end
184
+
185
+ # Unify types for arithmetic
186
+ Types.unify(*arg_types) || :float
187
+ else
188
+ :any
189
+ end
190
+
191
+ else
192
+ :any
193
+ end
194
+ end
195
+
196
+ def infer_element_reference_type(expr)
197
+ # Get array field metadata
198
+ input_meta = get_state(:input_meta, required: false) || {}
199
+
200
+ return :any unless expr.path.size >= 2
201
+
202
+ array_name = expr.path.first
203
+ field_name = expr.path[1]
204
+
205
+ array_meta = input_meta[array_name]
206
+ return :any unless array_meta&.dig(:type) == :array
207
+
208
+ # Get the field type from children metadata
209
+ field_type = array_meta.dig(:children, field_name, :type) || :any
210
+
211
+ # Return array of field type (vectorized)
212
+ { array: field_type }
213
+ end
214
+
215
+ def infer_cascade_type(cascade_expr, type_context, broadcast_metadata = {}, current_decl_name = nil)
107
216
  return :any if cascade_expr.cases.empty?
108
217
 
109
218
  result_types = cascade_expr.cases.map do |case_stmt|
110
- infer_expression_type(case_stmt.result, type_context)
219
+ infer_expression_type(case_stmt.result, type_context, broadcast_metadata, current_decl_name)
111
220
  end
112
221
 
113
222
  # Reduce all possible types into a single unified type
@@ -3,6 +3,10 @@
3
3
  module Kumi
4
4
  module Analyzer
5
5
  module Passes
6
+ # RESPONSIBILITY: Detect unsatisfiable constraints and analyze cascade mutual exclusion
7
+ # DEPENDENCIES: :definitions from NameIndexer, :input_meta from InputCollector
8
+ # PRODUCES: :cascade_metadata - Hash of cascade mutual exclusion analysis results
9
+ # INTERFACE: new(schema, state).run(errors)
6
10
  class UnsatDetector < VisitorPass
7
11
  include Syntax
8
12
 
@@ -15,35 +19,116 @@ module Kumi
15
19
  @definitions = definitions
16
20
  @evaluator = ConstantEvaluator.new(definitions)
17
21
 
22
+ # First pass: analyze cascade conditions for mutual exclusion
23
+ cascade_metadata = {}
24
+ each_decl do |decl|
25
+ cascade_metadata[decl.name] = analyze_cascade_mutual_exclusion(decl, definitions) if decl.expression.is_a?(CascadeExpression)
26
+ end
27
+
28
+ # Store cascade metadata for later passes
29
+
30
+ # Second pass: check for unsatisfiable constraints
18
31
  each_decl do |decl|
19
32
  if decl.expression.is_a?(CascadeExpression)
20
33
  # Special handling for cascade expressions
21
34
  check_cascade_expression(decl, definitions, errors)
22
- else
35
+ elsif decl.expression.is_a?(CallExpression) && decl.expression.fn_name == :or
23
36
  # Check for OR expressions which need special disjunctive handling
24
- if decl.expression.is_a?(CallExpression) && decl.expression.fn_name == :or
25
- impossible = check_or_expression(decl.expression, definitions, errors)
26
- report_error(errors, "conjunction `#{decl.name}` is impossible", location: decl.loc) if impossible
27
- else
28
- # Normal handling for non-cascade expressions
29
- atoms = gather_atoms(decl.expression, definitions, Set.new)
30
- next if atoms.empty?
31
-
32
- # Use enhanced solver that can detect cross-variable mathematical constraints
33
- impossible = if definitions && !definitions.empty?
34
- Kumi::ConstraintRelationshipSolver.unsat?(atoms, definitions, input_meta: @input_meta)
35
- else
36
- Kumi::AtomUnsatSolver.unsat?(atoms)
37
- end
37
+ impossible = check_or_expression(decl.expression, definitions, errors)
38
+ report_error(errors, "conjunction `#{decl.name}` is impossible", location: decl.loc) if impossible
39
+ else
40
+ # Normal handling for non-cascade expressions
41
+ atoms = gather_atoms(decl.expression, definitions, Set.new)
42
+ next if atoms.empty?
43
+
44
+ # Use enhanced solver that can detect cross-variable mathematical constraints
45
+ impossible = if definitions && !definitions.empty?
46
+ Kumi::ConstraintRelationshipSolver.unsat?(atoms, definitions, input_meta: @input_meta)
47
+ else
48
+ Kumi::AtomUnsatSolver.unsat?(atoms)
49
+ end
50
+
51
+ report_error(errors, "conjunction `#{decl.name}` is impossible", location: decl.loc) if impossible
52
+ end
53
+ end
54
+ state.with(:cascade_metadata, cascade_metadata)
55
+ end
38
56
 
39
- report_error(errors, "conjunction `#{decl.name}` is impossible", location: decl.loc) if impossible
57
+ private
58
+
59
+ def analyze_cascade_mutual_exclusion(decl, definitions)
60
+ conditions = []
61
+ condition_traits = []
62
+
63
+ # Extract all cascade conditions (except base case)
64
+ decl.expression.cases[0...-1].each do |when_case|
65
+ next unless when_case.condition
66
+
67
+ next unless when_case.condition.fn_name == :all?
68
+
69
+ when_case.condition.args.each do |arg|
70
+ next unless arg.is_a?(ArrayExpression)
71
+
72
+ arg.elements.each do |element|
73
+ next unless element.is_a?(DeclarationReference)
74
+
75
+ trait_name = element.name
76
+ trait = definitions[trait_name]
77
+ if trait
78
+ conditions << trait.expression
79
+ condition_traits << trait_name
80
+ end
40
81
  end
41
82
  end
83
+ # end
84
+ end
85
+
86
+ # Check mutual exclusion for all pairs
87
+ total_pairs = conditions.size * (conditions.size - 1) / 2
88
+ exclusive_pairs = 0
89
+
90
+ if conditions.size >= 2
91
+ conditions.combination(2).each do |cond1, cond2|
92
+ exclusive_pairs += 1 if conditions_mutually_exclusive?(cond1, cond2)
93
+ end
42
94
  end
43
- state
95
+
96
+ all_mutually_exclusive = (total_pairs > 0) && (exclusive_pairs == total_pairs)
97
+
98
+ {
99
+ condition_traits: condition_traits,
100
+ condition_count: conditions.size,
101
+ all_mutually_exclusive: all_mutually_exclusive,
102
+ exclusive_pairs: exclusive_pairs,
103
+ total_pairs: total_pairs
104
+ }
44
105
  end
45
106
 
46
- private
107
+ def conditions_mutually_exclusive?(cond1, cond2)
108
+ if cond1.is_a?(CallExpression) && cond1.fn_name == :== &&
109
+ cond2.is_a?(CallExpression) && cond2.fn_name == :==
110
+
111
+ c1_field, c1_value = cond1.args
112
+ c2_field, c2_value = cond2.args
113
+
114
+ # Same field, different values = mutually exclusive
115
+ return true if same_field?(c1_field, c2_field) && different_values?(c1_value, c2_value)
116
+ end
117
+
118
+ false
119
+ end
120
+
121
+ def same_field?(field1, field2)
122
+ return false unless field1.is_a?(InputReference) && field2.is_a?(InputReference)
123
+
124
+ field1.name == field2.name
125
+ end
126
+
127
+ def different_values?(val1, val2)
128
+ return false unless val1.is_a?(Literal) && val2.is_a?(Literal)
129
+
130
+ val1.value != val2.value
131
+ end
47
132
 
48
133
  def check_or_expression(or_expr, definitions, errors)
49
134
  # For OR expressions: A | B is impossible only if BOTH A AND B are impossible
@@ -52,26 +137,22 @@ module Kumi
52
137
 
53
138
  # Check if left side is impossible
54
139
  left_atoms = gather_atoms(left_side, definitions, Set.new)
55
- left_impossible = if !left_atoms.empty?
56
- if definitions && !definitions.empty?
57
- Kumi::ConstraintRelationshipSolver.unsat?(left_atoms, definitions, input_meta: @input_meta)
58
- else
59
- Kumi::AtomUnsatSolver.unsat?(left_atoms)
60
- end
61
- else
140
+ left_impossible = if left_atoms.empty?
62
141
  false
142
+ elsif definitions && !definitions.empty?
143
+ Kumi::ConstraintRelationshipSolver.unsat?(left_atoms, definitions, input_meta: @input_meta)
144
+ else
145
+ Kumi::AtomUnsatSolver.unsat?(left_atoms)
63
146
  end
64
147
 
65
148
  # Check if right side is impossible
66
149
  right_atoms = gather_atoms(right_side, definitions, Set.new)
67
- right_impossible = if !right_atoms.empty?
68
- if definitions && !definitions.empty?
69
- Kumi::ConstraintRelationshipSolver.unsat?(right_atoms, definitions, input_meta: @input_meta)
70
- else
71
- Kumi::AtomUnsatSolver.unsat?(right_atoms)
72
- end
73
- else
150
+ right_impossible = if right_atoms.empty?
74
151
  false
152
+ elsif definitions && !definitions.empty?
153
+ Kumi::ConstraintRelationshipSolver.unsat?(right_atoms, definitions, input_meta: @input_meta)
154
+ else
155
+ Kumi::AtomUnsatSolver.unsat?(right_atoms)
75
156
  end
76
157
 
77
158
  # OR is impossible only if BOTH sides are impossible
@@ -106,10 +187,10 @@ module Kumi
106
187
  elsif current.is_a?(CallExpression) && current.fn_name == :all?
107
188
  # For all? function, add all trait arguments to the stack
108
189
  current.args.each { |arg| stack << arg }
109
- elsif current.is_a?(ListExpression)
110
- # For ListExpression, add all elements to the stack
190
+ elsif current.is_a?(ArrayExpression)
191
+ # For ArrayExpression, add all elements to the stack
111
192
  current.elements.each { |elem| stack << elem }
112
- elsif current.is_a?(Binding)
193
+ elsif current.is_a?(DeclarationReference)
113
194
  name = current.name
114
195
  unless visited.include?(name)
115
196
  visited << name
@@ -141,8 +222,8 @@ module Kumi
141
222
 
142
223
  # Skip single-trait 'on' branches: trait-level unsat detection covers these
143
224
  if when_case.condition.is_a?(CallExpression) && when_case.condition.fn_name == :all?
144
- # Handle both ListExpression (old format) and multiple args (new format)
145
- if when_case.condition.args.size == 1 && when_case.condition.args.first.is_a?(ListExpression)
225
+ # Handle both ArrayExpression (old format) and multiple args (new format)
226
+ if when_case.condition.args.size == 1 && when_case.condition.args.first.is_a?(ArrayExpression)
146
227
  list = when_case.condition.args.first
147
228
  next if list.elements.size == 1
148
229
  elsif when_case.condition.args.size == 1
@@ -154,7 +235,6 @@ module Kumi
154
235
  condition_atoms = gather_atoms(when_case.condition, definitions, Set.new, [])
155
236
  # DEBUG
156
237
  # if when_case.condition.is_a?(CallExpression) && [:all?, :any?, :none?].include?(when_case.condition.fn_name)
157
- # puts "DEBUG: Processing #{when_case.condition.fn_name} condition"
158
238
  # puts " Args: #{when_case.condition.args.inspect}"
159
239
  # puts " Atoms found: #{condition_atoms.inspect}"
160
240
  # end
@@ -174,14 +254,14 @@ module Kumi
174
254
 
175
255
  # For multi-trait on-clauses, report the trait names rather than the value name
176
256
  if when_case.condition.is_a?(CallExpression) && when_case.condition.fn_name == :all?
177
- # Handle both ListExpression (old format) and multiple args (new format)
178
- trait_bindings = if when_case.condition.args.size == 1 && when_case.condition.args.first.is_a?(ListExpression)
257
+ # Handle both ArrayExpression (old format) and multiple args (new format)
258
+ trait_bindings = if when_case.condition.args.size == 1 && when_case.condition.args.first.is_a?(ArrayExpression)
179
259
  when_case.condition.args.first.elements
180
260
  else
181
261
  when_case.condition.args
182
262
  end
183
263
 
184
- if trait_bindings.all?(Binding)
264
+ if trait_bindings.all?(DeclarationReference)
185
265
  traits = trait_bindings.map(&:name).join(" AND ")
186
266
  report_error(errors, "conjunction `#{traits}` is impossible", location: decl.loc)
187
267
  next
@@ -193,7 +273,7 @@ module Kumi
193
273
 
194
274
  def term(node, _defs)
195
275
  case node
196
- when FieldRef, Binding
276
+ when InputReference, DeclarationReference
197
277
  val = @evaluator.evaluate(node)
198
278
  val == :unknown ? node.name : val
199
279
  when Literal
@@ -205,14 +285,14 @@ module Kumi
205
285
 
206
286
  def check_domain_constraints(node, definitions, errors)
207
287
  case node
208
- when FieldRef
209
- # Check if FieldRef points to a field with domain constraints
288
+ when InputReference
289
+ # Check if InputReference points to a field with domain constraints
210
290
  field_meta = @input_meta[node.name]
211
291
  nil unless field_meta&.dig(:domain)
212
292
 
213
- # For FieldRef, the constraint comes from trait conditions
214
- # We don't flag here since the FieldRef itself is valid
215
- when Binding
293
+ # For InputReference, the constraint comes from trait conditions
294
+ # We don't flag here since the InputReference itself is valid
295
+ when DeclarationReference
216
296
  # Check if this binding evaluates to a value that violates domain constraints
217
297
  definition = definitions[node.name]
218
298
  return unless definition
@@ -254,18 +334,18 @@ module Kumi
254
334
  end
255
335
 
256
336
  def impossible_constraint?(lhs, rhs, operator)
257
- # Case 1: FieldRef compared against value outside its domain
258
- if lhs.is_a?(FieldRef) && rhs.is_a?(Literal)
337
+ # Case 1: InputReference compared against value outside its domain
338
+ if lhs.is_a?(InputReference) && rhs.is_a?(Literal)
259
339
  return field_literal_impossible?(lhs, rhs, operator)
260
- elsif rhs.is_a?(FieldRef) && lhs.is_a?(Literal)
340
+ elsif rhs.is_a?(InputReference) && lhs.is_a?(Literal)
261
341
  # Reverse case: literal compared to field
262
342
  return field_literal_impossible?(rhs, lhs, flip_operator(operator))
263
343
  end
264
344
 
265
- # Case 2: Binding that evaluates to literal compared against impossible value
266
- if lhs.is_a?(Binding) && rhs.is_a?(Literal)
345
+ # Case 2: DeclarationReference that evaluates to literal compared against impossible value
346
+ if lhs.is_a?(DeclarationReference) && rhs.is_a?(Literal)
267
347
  return binding_literal_impossible?(lhs, rhs, operator)
268
- elsif rhs.is_a?(Binding) && lhs.is_a?(Literal)
348
+ elsif rhs.is_a?(DeclarationReference) && lhs.is_a?(Literal)
269
349
  return binding_literal_impossible?(rhs, lhs, flip_operator(operator))
270
350
  end
271
351