rlsl 0.1.1 → 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +10 -1
  3. data/README.md +2 -0
  4. data/lib/rlsl/base_translator/call_parser.rb +68 -0
  5. data/lib/rlsl/base_translator/code_rewriter.rb +107 -0
  6. data/lib/rlsl/base_translator/code_scanner.rb +102 -0
  7. data/lib/rlsl/base_translator.rb +159 -63
  8. data/lib/rlsl/code_generator/math_prelude.rb +79 -0
  9. data/lib/rlsl/code_generator/ruby_wrapper_generator.rb +97 -0
  10. data/lib/rlsl/code_generator/shader_function_generator.rb +19 -0
  11. data/lib/rlsl/code_generator/template_context.rb +41 -0
  12. data/lib/rlsl/code_generator/uniform_struct_generator.rb +29 -0
  13. data/lib/rlsl/code_generator.rb +26 -201
  14. data/lib/rlsl/compiled_shader.rb +4 -12
  15. data/lib/rlsl/function_context.rb +31 -14
  16. data/lib/rlsl/glsl/translator.rb +19 -38
  17. data/lib/rlsl/msl/shader.rb +8 -38
  18. data/lib/rlsl/msl/translator.rb +19 -36
  19. data/lib/rlsl/msl/uniform_buffer_packer.rb +68 -0
  20. data/lib/rlsl/prism/ast_visitor/control_flow_visiting.rb +96 -0
  21. data/lib/rlsl/prism/ast_visitor/definition_visiting.rb +79 -0
  22. data/lib/rlsl/prism/ast_visitor/expression_visiting.rb +168 -0
  23. data/lib/rlsl/prism/ast_visitor/scope_context.rb +44 -0
  24. data/lib/rlsl/prism/ast_visitor/visitor_registry.rb +21 -0
  25. data/lib/rlsl/prism/ast_visitor.rb +54 -298
  26. data/lib/rlsl/prism/builtins/function_registry.rb +114 -0
  27. data/lib/rlsl/prism/builtins/operator_rules.rb +99 -0
  28. data/lib/rlsl/prism/builtins/swizzle_rules.rb +38 -0
  29. data/lib/rlsl/prism/builtins.rb +31 -148
  30. data/lib/rlsl/prism/compilation_unit.rb +7 -0
  31. data/lib/rlsl/prism/emitters/base_emitter/control_flow_emission.rb +83 -0
  32. data/lib/rlsl/prism/emitters/base_emitter/definition_emission.rb +104 -0
  33. data/lib/rlsl/prism/emitters/base_emitter/expression_emission.rb +92 -0
  34. data/lib/rlsl/prism/emitters/base_emitter/statement_emission.rb +89 -0
  35. data/lib/rlsl/prism/emitters/base_emitter.rb +68 -423
  36. data/lib/rlsl/prism/emitters/c_emitter.rb +78 -121
  37. data/lib/rlsl/prism/emitters/glsl_emitter.rb +30 -65
  38. data/lib/rlsl/prism/emitters/msl_emitter.rb +35 -62
  39. data/lib/rlsl/prism/emitters/target_emitter.rb +77 -0
  40. data/lib/rlsl/prism/emitters/target_profile.rb +34 -0
  41. data/lib/rlsl/prism/emitters/wgsl_emitter.rb +65 -61
  42. data/lib/rlsl/prism/ir/control_flow.rb +83 -0
  43. data/lib/rlsl/prism/ir/definitions.rb +67 -0
  44. data/lib/rlsl/prism/ir/expressions.rb +197 -0
  45. data/lib/rlsl/prism/ir/node.rb +21 -0
  46. data/lib/rlsl/prism/ir/nodes.rb +4 -371
  47. data/lib/rlsl/prism/ir/traversal.rb +62 -0
  48. data/lib/rlsl/prism/source_extractor/block_locator.rb +52 -0
  49. data/lib/rlsl/prism/source_extractor.rb +14 -133
  50. data/lib/rlsl/prism/source_unit/parser.rb +78 -0
  51. data/lib/rlsl/prism/source_unit.rb +42 -0
  52. data/lib/rlsl/prism/target_capability_validator.rb +85 -0
  53. data/lib/rlsl/prism/transpiler.rb +63 -60
  54. data/lib/rlsl/prism/type_inference/call_type_resolver.rb +35 -0
  55. data/lib/rlsl/prism/type_inference/call_validator.rb +71 -0
  56. data/lib/rlsl/prism/type_inference/collection_type_resolver.rb +80 -0
  57. data/lib/rlsl/prism/type_inference/control_flow_inferer.rb +66 -0
  58. data/lib/rlsl/prism/type_inference/definition_inferer.rb +41 -0
  59. data/lib/rlsl/prism/type_inference/expression_inferer.rb +92 -0
  60. data/lib/rlsl/prism/type_inference/field_type_resolver.rb +17 -0
  61. data/lib/rlsl/prism/type_inference/inferer_registry.rb +38 -0
  62. data/lib/rlsl/prism/type_inference/scope_stack.rb +47 -0
  63. data/lib/rlsl/prism/type_inference/type_environment.rb +112 -0
  64. data/lib/rlsl/prism/type_inference/type_shapes.rb +33 -0
  65. data/lib/rlsl/prism/type_inference.rb +114 -248
  66. data/lib/rlsl/runtime_shader.rb +47 -0
  67. data/lib/rlsl/shader_builder/build_service.rb +77 -0
  68. data/lib/rlsl/shader_builder/native_extension_compiler.rb +71 -0
  69. data/lib/rlsl/shader_builder/shader_definition.rb +68 -0
  70. data/lib/rlsl/shader_builder/source_resolver.rb +91 -0
  71. data/lib/rlsl/shader_builder.rb +22 -113
  72. data/lib/rlsl/types/catalog.rb +47 -0
  73. data/lib/rlsl/types/target_resolver.rb +15 -0
  74. data/lib/rlsl/types/type_spec.rb +167 -0
  75. data/lib/rlsl/types/value_normalizer.rb +81 -0
  76. data/lib/rlsl/types.rb +11 -29
  77. data/lib/rlsl/uniform_context.rb +6 -12
  78. data/lib/rlsl/version.rb +1 -1
  79. data/lib/rlsl/wgsl/translator.rb +23 -39
  80. data/lib/rlsl.rb +14 -12
  81. metadata +55 -2
@@ -0,0 +1,78 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "prism"
4
+
module RLSL
  module Prism
    # Parses raw shader source text into a SourceUnit.
    #
    # Accepted shape: an optional leading block-parameter list "|a, b|" followed by
    # body code. The body may start on the same line as the parameters
    # ("|uv| uv * 2.0") or on the following lines ("|uv|\nuv * 2.0").
    class SourceUnitParser
      # Leading "|...|" parameter list confined to the first line. Anything after the
      # closing pipe (same line or later lines) is body source.
      PARAMS_PATTERN = /\A\|[^|\n]*\|/
      private_constant :PARAMS_PATTERN

      # @param source [#to_s] raw source text; nil is treated as empty
      def initialize(source)
        @source = source.to_s
      end

      # @return [SourceUnit] parameter names and the stripped body source
      # @raise [ArgumentError] if the body or the parameter list fails to parse with Prism
      def parse
        normalized = @source.strip
        return SourceUnit.new(params: [], body: "") if normalized.empty?

        params_source, body_source = split_sections(normalized)
        validate_body!(body_source)

        SourceUnit.new(
          params: parse_params(params_source),
          body: body_source.strip
        )
      end

      private

      # Returns [params_source_or_nil, body_source]. Matching the closing pipe rather
      # than requiring the whole first line to be "|...|" keeps inline bodies such as
      # "|x| x * 2" working.
      def split_sections(source)
        match = PARAMS_PATTERN.match(source)
        return [nil, source] unless match

        [match[0], match.post_match]
      end

      # Rejects bodies Prism cannot parse; blank bodies are allowed.
      def validate_body!(body_source)
        return if body_source.to_s.strip.empty?

        parsed = ::Prism.parse(body_source)
        raise ArgumentError, "Unable to parse source unit body" unless parsed.success?
      end

      # Parses "|a, b|" by wrapping it in a proc and reading the block's required
      # parameter names (Symbols, per Prism's RequiredParameterNode#name).
      def parse_params(params_source)
        return [] unless params_source

        parsed = ::Prism.parse("proc do #{params_source}\nend\n")
        raise ArgumentError, "Unable to parse source unit params" unless parsed.success?

        block = each_node(parsed.value).find { |node| node.is_a?(::Prism::BlockNode) }
        # "||" and "|; x|" produce a BlockParametersNode whose ParametersNode is nil.
        parameters = block&.parameters&.parameters
        return [] unless parameters

        parameters.requireds.map(&:name)
      end

      # Depth-first, pre-order walk over a Prism node tree. Returns an Enumerator
      # when called without a block.
      def each_node(node)
        return enum_for(:each_node, node) unless block_given?
        return unless node

        stack = [node]

        until stack.empty?
          current = stack.pop
          yield current

          children = if current.respond_to?(:compact_child_nodes)
                       current.compact_child_nodes
                     else
                       Array(current.child_nodes).compact
                     end
          # Reverse so the leftmost child is popped (visited) first.
          stack.concat(children.reverse)
        end
      end
    end
  end
end
@@ -0,0 +1,42 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "source_unit/parser"
4
+
module RLSL
  module Prism
    # A shader source fragment: block parameter names plus the body source text.
    #
    # params - parameter names (Symbols when produced from Prism parameter nodes)
    # body   - body source String, stripped of surrounding whitespace
    SourceUnit = Struct.new(:params, :body, keyword_init: true) do
      class << self
        # Parses raw source text ("|a, b|\n<body>") via SourceUnitParser.
        def from_source(source)
          SourceUnitParser.new(source).parse
        end

        # Builds a unit from a Prism block-like node (responds to #parameters and #body).
        # A block without a body yields an empty body string.
        def from_block(block)
          new(
            params: extract_params(block),
            body: block.body&.slice.to_s.strip
          )
        end

        private

        # Names of the required positional block parameters, or [] when none are
        # declared. Numbered (`_1`) and `it` parameter nodes carry no ParametersNode,
        # and `||` / `|; x|` carry a nil one; all of these map to [] instead of
        # raising NoMethodError.
        def extract_params(block)
          declared = block.parameters
          return [] unless declared.respond_to?(:parameters)

          parameters = declared.parameters
          return [] unless parameters

          parameters.requireds.map(&:name)
        end
      end

      # Copy of this unit with the parameter list dropped.
      def without_params
        self.class.new(params: [], body: body)
      end

      # Renders back to "|a, b|\n<body>", omitting empty sections.
      def to_source
        segments = []
        segments << "|#{params.join(', ')}|" unless params.empty?
        segments << body unless body.empty?
        segments.join("\n")
      end
    end
  end
end
@@ -0,0 +1,85 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "../types"
4
+ require_relative "ir/traversal"
5
+ require_relative "type_inference/type_shapes"
6
+
module RLSL
  module Prism
    # Raised when IR uses a type or builtin that the selected backend cannot express.
    class TargetCapabilityError < StandardError; end

    # Walks compiled IR and rejects types and builtins a backend target does not
    # support. Type support is checked through RLSL::UniformTypes; builtin support
    # through Builtins.
    class TargetCapabilityValidator
      include TypeShapes

      # @param node root IR node, walked with IR::Traversal.each
      # @param target [Symbol, String] backend name, normalized with #to_sym
      # @return the +node+ argument, unchanged
      # @raise [TargetCapabilityError] on the first unsupported type or builtin
      def validate!(node, target)
        @target = target.to_sym
        IR::Traversal.each(node) do |visited|
          validate_node!(visited)
        end
        node
      end

      private

      attr_reader :target

      # Checks the node's own type first, then any extra types its kind carries.
      def validate_node!(node)
        return if node.nil?

        validate_type!(node.type, context: short_class_name(node))

        case node
        when IR::FuncCall
          validate_builtin!(node)
        when IR::FunctionDefinition
          validate_function_signature!(node)
        when IR::GlobalDecl
          validate_type!(node.element_type, context: "global #{node.name} element")
        end
      end

      # Last "::" segment of the node's class name, used as an error context label.
      def short_class_name(node)
        node.class.name.split("::").last
      end

      # Return type first, then each declared parameter type.
      def validate_function_signature!(definition)
        label = "function #{definition.name}"
        validate_type!(definition.return_type, context: "#{label} return")
        definition.param_types.each_value do |param_type|
          validate_type!(param_type, context: "#{label} parameter")
        end
      end

      # A known builtin must be available on the target, as must the types it lists.
      def validate_builtin!(call)
        name = call.name
        return unless Builtins.function?(name)

        unsupported = !Builtins.supported_on_target?(name, target)
        raise TargetCapabilityError, "Builtin #{name} is not supported on #{target.to_s.upcase}" if unsupported

        Builtins.explicit_types(name).each do |builtin_type|
          validate_type!(builtin_type, context: "builtin #{name}")
        end
      end

      # Recurses into array element types and tuple members (plain Arrays or
      # IR::TupleType). Symbols naming a known uniform type must be supported on
      # the current target. Anything else is accepted.
      def validate_type!(type, context:)
        return if type.nil?

        if TypeShapes.array?(type)
          validate_type!(TypeShapes.element_type(type), context: "#{context} array element")
        elsif type.is_a?(Array) || type.is_a?(IR::TupleType)
          tuple_members(type).each do |member|
            validate_type!(member, context: "#{context} tuple element")
          end
        elsif uniform_type_symbol?(type)
          ensure_uniform_supported!(type, context)
        end

        nil
      end

      def tuple_members(type)
        type.is_a?(Array) ? type : type.types
      end

      def ensure_uniform_supported!(type, context)
        return if RLSL::UniformTypes.fetch(type).target_supported?(target)

        raise TargetCapabilityError,
              "Unsupported #{target.to_s.upcase} type #{type.inspect} in #{context}"
      end

      def uniform_type_symbol?(type)
        type.is_a?(Symbol) && RLSL::UniformTypes.supported?(type)
      end
    end
  end
end
@@ -3,11 +3,15 @@
3
3
  require "prism"
4
4
 
5
5
  require_relative "ir/nodes"
6
+ require_relative "compilation_unit"
7
+ require_relative "source_unit"
6
8
  require_relative "source_extractor"
7
9
  require_relative "builtins"
8
10
  require_relative "ast_visitor"
9
11
  require_relative "type_inference"
12
+ require_relative "target_capability_validator"
10
13
  require_relative "emitters/base_emitter"
14
+ require_relative "emitters/target_emitter"
11
15
  require_relative "emitters/c_emitter"
12
16
  require_relative "emitters/msl_emitter"
13
17
  require_relative "emitters/wgsl_emitter"
@@ -23,70 +27,91 @@ module RLSL
23
27
  glsl: Emitters::GLSLEmitter
24
28
  }.freeze
25
29
 
26
- attr_reader :ir, :uniforms, :custom_functions
30
+ attr_reader :uniforms, :custom_functions
27
31
 
28
- def initialize(uniforms = {}, custom_functions = {})
32
+ def initialize(uniforms = {}, custom_functions = {}, globals: {})
29
33
  @uniforms = uniforms
30
34
  @custom_functions = custom_functions
35
+ @globals = globals
31
36
  @source_extractor = SourceExtractor.new
32
- @ir = nil
33
37
  end
34
38
 
35
- def parse_block(block)
36
- source = @source_extractor.extract(block)
37
- parse_source(source)
39
+ def compile_block(block)
40
+ compile_unit(@source_extractor.extract_unit(block))
38
41
  end
39
42
 
40
- def parse_source(source)
41
- params, body = extract_block_body(source)
42
-
43
- visitor = ASTVisitor.new(uniforms: @uniforms, params: params)
44
- @ir = visitor.parse(body)
45
-
46
- inference = TypeInference.new(@uniforms, @custom_functions)
47
- inference.register(:frag_coord, :vec2)
48
- inference.register(:resolution, :vec2)
49
- inference.infer(@ir)
50
-
51
- @ir
43
+ def compile_source(source)
44
+ compile_unit(source_unit(source))
52
45
  end
53
46
 
54
- def emit(target, needs_return: true)
55
- raise "No IR parsed yet. Call parse_block or parse_source first." unless @ir
56
-
57
- emitter_class = TARGETS[target.to_sym]
58
- raise "Unknown target: #{target}" unless emitter_class
47
+ def compile_helpers(block, function_signatures = {})
48
+ compile_unit(
49
+ @source_extractor.extract_unit(block).without_params,
50
+ function_signatures: function_signatures
51
+ )
52
+ end
59
53
 
60
- emitter = emitter_class.new
61
- emitter.emit(@ir, needs_return: needs_return)
54
+ def emit(target, compilation:, needs_return: true)
55
+ emitter = resolve_emitter(target)
56
+ validate_target_capabilities!(compilation.ir, target)
57
+ emitter.emit(compilation.ir, needs_return: needs_return)
62
58
  end
63
59
 
64
60
  def transpile(block, target)
65
- parse_block(block)
66
- emit(target)
61
+ emit(target, compilation: compile_block(block))
67
62
  end
68
63
 
69
64
  def transpile_source(source, target)
70
- parse_source(source)
71
- emit(target)
65
+ emit(target, compilation: compile_source(source))
72
66
  end
73
67
 
74
68
  def transpile_helpers(block, target, function_signatures = {})
75
- source = @source_extractor.extract(block)
76
- _, body = extract_block_body(source)
69
+ emit(
70
+ target,
71
+ needs_return: false,
72
+ compilation: compile_helpers(block, function_signatures)
73
+ )
74
+ end
77
75
 
78
- visitor = ASTVisitor.new(uniforms: @uniforms)
79
- @ir = visitor.parse(body)
76
+ private
80
77
 
81
- apply_function_signatures(@ir, function_signatures)
78
+ def source_unit(source)
79
+ SourceUnit.from_source(source)
80
+ end
82
81
 
83
- inference = TypeInference.new(@uniforms, @custom_functions)
84
- inference.infer(@ir)
82
+ def build_ir(unit)
83
+ visitor = ASTVisitor.new(uniforms: @uniforms, params: unit.params)
84
+ visitor.parse(unit.body)
85
+ end
85
86
 
86
- emit(target, needs_return: false)
87
+ def compile_unit(unit, function_signatures: nil)
88
+ ir = build_ir(unit)
89
+ apply_function_signatures(ir, function_signatures || {})
90
+ infer_ir(ir)
91
+ CompilationUnit.new(source_unit: unit, ir: ir)
87
92
  end
88
93
 
89
- private
94
+ def infer_ir(ir)
95
+ inference = TypeInference.new(@uniforms, @custom_functions, globals: @globals)
96
+ register_pipeline_symbols(inference)
97
+ inference.infer(ir)
98
+ end
99
+
100
+ def register_pipeline_symbols(inference)
101
+ inference.register(:frag_coord, :vec2)
102
+ inference.register(:resolution, :vec2)
103
+ end
104
+
105
+ def resolve_emitter(target)
106
+ emitter_class = TARGETS[target.to_sym]
107
+ raise "Unknown target: #{target}" unless emitter_class
108
+
109
+ emitter_class.new
110
+ end
111
+
112
+ def validate_target_capabilities!(ir, target)
113
+ TargetCapabilityValidator.new.validate!(ir, target)
114
+ end
90
115
 
91
116
  def apply_function_signatures(ir, signatures)
92
117
  return unless ir.is_a?(IR::Block)
@@ -101,28 +126,6 @@ module RLSL
101
126
  stmt.param_types = sig[:params] || {}
102
127
  end
103
128
  end
104
-
105
- def extract_block_body(source)
106
- lines = source.strip.lines
107
- params = []
108
-
109
- first_line = lines.first&.strip || ""
110
-
111
- if first_line.start_with?("|")
112
- param_end = first_line.index("|", 1)
113
- if param_end
114
- param_str = first_line[1...param_end]
115
- params = param_str.split(",").map { |p| p.strip.to_sym }
116
- lines[0] = first_line[(param_end + 1)..]
117
- end
118
- end
119
-
120
- lines.shift while lines.first&.strip&.empty?
121
- lines.pop while lines.last&.strip&.empty?
122
-
123
- body = lines.join.strip
124
- [params, body]
125
- end
126
129
  end
127
130
  end
128
131
  end
@@ -0,0 +1,35 @@
1
+ # frozen_string_literal: true
2
+
module RLSL
  module Prism
    # Determines the result type of a function-call IR node, validating its
    # argument types against the matching signature along the way.
    class CallTypeResolver
      # @param custom_functions [Hash] user function signatures, looked up by name.to_sym
      # @param call_validator [#validate_builtin!, #validate_custom!]
      def initialize(custom_functions:, call_validator:)
        @custom_functions = custom_functions
        @call_validator = call_validator
      end

      # Resolution order: builtin signature, then custom function signature, then
      # the receiver's type (nil when the call has no receiver).
      def resolve(node)
        if (builtin_signature = Builtins.function_signature(node.name))
          resolve_builtin(node, builtin_signature)
        elsif (custom_signature = @custom_functions[node.name.to_sym])
          resolve_custom(node, custom_signature)
        else
          node.receiver&.type
        end
      end

      private

      # Builtin return types may depend on the argument types.
      def resolve_builtin(node, signature)
        types = argument_types(node)
        @call_validator.validate_builtin!(node, types, signature)
        Builtins.resolve_return_type(signature[:returns], types)
      end

      # Custom functions declare their return type directly.
      def resolve_custom(node, signature)
        @call_validator.validate_custom!(node.name, argument_types(node), signature)
        signature[:returns]
      end

      def argument_types(node)
        node.args.map(&:type)
      end
    end
  end
end
@@ -0,0 +1,71 @@
1
+ # frozen_string_literal: true
2
+
module RLSL
  module Prism
    # Checks call-site argument counts and types against a signature, raising
    # SignatureError on any mismatch.
    #
    # Builtin signatures read :args (expected types), :variadic and :min_args.
    # Custom signatures read :params (name => type, in declaration order).
    # An expected type of :any accepts anything, and :float also accepts :int.
    class CallValidator
      # Skipped for a variadic builtin called with no arguments whose node
      # already carries a type.
      def validate_builtin!(node, arg_types, signature)
        return if skip_builtin_validation?(node, signature)

        validate_signature!(
          node.name,
          arg_types,
          signature[:args],
          variadic: signature[:variadic],
          min_args: signature[:min_args]
        )
      end

      # Nothing is checked when the signature declares no :params.
      def validate_custom!(name, arg_types, signature)
        declared_params = signature[:params]
        validate_signature!(name, arg_types, declared_params.values) if declared_params
      end

      private

      # Count first, then each argument position (1-based in messages).
      def validate_signature!(name, arg_types, expected_types, variadic: false, min_args: nil)
        validate_argument_count!(name, arg_types.length, expected_types.length, variadic: variadic, min_args: min_args)

        arg_types.each.with_index(1) do |actual_type, position|
          expected_type = expected_types[position - 1]
          next if compatible_argument_type?(expected_type, actual_type)

          raise SignatureError,
                "Invalid argument #{position} for #{name}: expected #{expected_type}, got #{actual_type || :unknown}"
        end
      end

      def validate_argument_count!(name, actual_count, expected_count, variadic:, min_args:)
        unless valid_argument_count?(actual_count, expected_count, variadic: variadic, min_args: min_args)
          expectation = expected_count_description(expected_count, variadic, min_args)
          raise SignatureError, "Wrong number of arguments for #{name}: expected #{expectation}, got #{actual_count}"
        end
      end

      # Exact match unless variadic; variadic accepts min_args (default: all)..expected_count.
      def valid_argument_count?(actual_count, expected_count, variadic:, min_args:)
        if variadic
          ((min_args || expected_count)..expected_count).cover?(actual_count)
        else
          actual_count == expected_count
        end
      end

      # "N" for a fixed count, "MIN..N" for a variadic range.
      def expected_count_description(expected_count, variadic, min_args)
        lower_bound = variadic ? (min_args || expected_count) : expected_count
        lower_bound == expected_count ? lower_bound.to_s : "#{lower_bound}..#{expected_count}"
      end

      def compatible_argument_type?(expected_type, actual_type)
        [:any, actual_type].include?(expected_type) ||
          (expected_type == :float && actual_type == :int)
      end

      def skip_builtin_validation?(node, signature)
        return false unless signature[:variadic]

        node.args.empty? && !node.type.nil?
      end
    end
  end
end
@@ -0,0 +1,80 @@
1
+ # frozen_string_literal: true
2
+
module RLSL
  module Prism
    # Resolves types for array literals, array indexing, global declarations and
    # the targets of multiple assignment.
    class CollectionTypeResolver
      include TypeShapes

      # @param type_environment [#array_element_type] lookup for declared array element types
      # @param custom_functions [Hash] user function signatures, looked up by name.to_sym
      #   (the same normalization CallTypeResolver uses for this hash)
      # @param register [#call] callable(name, type) that records a variable's type
      def initialize(type_environment:, custom_functions:, register:)
        @type_environment = type_environment
        @custom_functions = custom_functions
        @register = register
      end

      # Element type comes from the first element; an empty literal defaults to :float.
      def resolve_array_literal(node)
        element_type = node.elements.first&.type || :float
        TypeShapes.array(element_type)
      end

      # Prefers the indexed expression's array type, then the environment's record
      # for a named array variable, falling back to :float.
      def resolve_array_index(node)
        array_type = node.array.type
        return TypeShapes.element_type(array_type) if TypeShapes.array?(array_type)
        return (@type_environment.array_element_type(node.array.name) || :float) if node.array.is_a?(IR::VarRef)

        :float
      end

      # For array initializers, also fills in node.array_size and node.element_type
      # when they are not already set (mutates +node+).
      def resolve_global_decl(node)
        if node.initializer.is_a?(IR::ArrayLiteral)
          node.array_size ||= node.initializer.elements.length
          first_elem = node.initializer.elements.first
          node.element_type ||= first_elem&.type || :float
          return TypeShapes.array(node.element_type)
        end

        node.initializer&.type
      end

      # Types each assignment target from a tuple value, an array value, or a custom
      # function whose signature declares an Array of return types. Returns nil when
      # none of these apply.
      def assign_multiple_targets(node)
        value_type = node.value.type
        return assign_tuple_targets(node.targets, value_type.types) if value_type.is_a?(IR::TupleType)
        return assign_array_targets(node.targets, TypeShapes.element_type(value_type)) if TypeShapes.array?(value_type)
        return assign_custom_targets(node.targets, custom_signature(node.value)) if custom_multi_return?(node.value)

        nil
      end

      private

      def assign_tuple_targets(targets, types)
        targets.each_with_index do |target, index|
          assign_target(target, types[index])
        end
      end

      def assign_array_targets(targets, type)
        targets.each do |target|
          assign_target(target, type)
        end
      end

      def assign_custom_targets(targets, signature)
        returns = signature[:returns]
        return unless returns.is_a?(Array)

        targets.each_with_index do |target, index|
          assign_target(target, returns[index])
        end
      end

      def assign_target(target, type)
        target.type = type
        @register.call(target.name, type)
      end

      # Normalizes the call name with #to_sym so String- and Symbol-named calls
      # match the same signature, consistent with CallTypeResolver#resolve.
      def custom_multi_return?(value)
        value.is_a?(IR::FuncCall) && @custom_functions.key?(value.name.to_sym)
      end

      def custom_signature(call)
        @custom_functions[call.name.to_sym]
      end
    end
  end
end
@@ -0,0 +1,66 @@
1
+ # frozen_string_literal: true
2
+
module RLSL
  module Prism
    # Type inference for control-flow IR nodes. Child traversal is delegated to
    # three injected callables:
    #   infer             - callable(node) that infers a child node
    #   infer_child_scope - callable(node) for branch/loop bodies (presumably inside
    #                       a nested scope; confirm in TypeInference)
    #   infer_in_scope    - callable(bindings_hash, &block) that runs the block with
    #                       the given name => type bindings available
    # Every public method returns the node it was given.
    class ControlFlowInferer
      def initialize(infer:, infer_child_scope:, infer_in_scope:)
        @infer = infer
        @infer_child_scope = infer_child_scope
        @infer_in_scope = infer_in_scope
      end

      # An if-statement takes the type of its then-branch.
      def infer_if_statement(node)
        @infer.call(node.condition)
        @infer_child_scope.call(node.then_branch)
        else_branch = node.else_branch
        @infer_child_scope.call(else_branch) if else_branch
        assign_type(node, node.then_branch.type)
      end

      # A ternary takes the type of its then-expression.
      def infer_ternary(node)
        %i[condition then_expr else_expr].each do |part|
          @infer.call(node.public_send(part))
        end
        assign_type(node, node.then_expr.type)
      end

      # A bare return has no type.
      def infer_return(node)
        @infer.call(node.expression) if node.expression
        assign_type(node, node.expression&.type)
      end

      # The loop variable is bound as :int while the body is inferred.
      def infer_for_loop(node)
        %i[range_start range_end].each do |bound|
          @infer.call(node.public_send(bound))
        end
        loop_bindings = { node.variable => :int }
        @infer_in_scope.call(loop_bindings) { @infer.call(node.body) }
        assign_type(node, nil)
      end

      def infer_while_loop(node)
        @infer.call(node.condition)
        @infer_child_scope.call(node.body)
        assign_type(node, nil)
      end

      # Parameters are bound while the body is inferred. A missing return type is
      # taken from the body, and the node's type mirrors its return type.
      def infer_function_definition(node)
        @infer_in_scope.call(node.param_types) do
          @infer.call(node.body)
          node.return_type = node.body&.type unless node.return_type
          node.type = node.return_type
        end
        node
      end

      private

      def assign_type(node, type)
        node.type = type
        node
      end
    end
  end
end
@@ -0,0 +1,41 @@
1
+ # frozen_string_literal: true
2
+
module RLSL
  module Prism
    # Type inference for declaration and assignment IR nodes.
    #
    # Collaborators:
    #   infer                    - callable(node) that infers a child node
    #   register                 - callable(name, type) that records a variable's type
    #   collection_type_resolver - object providing #resolve_global_decl and
    #                              #assign_multiple_targets
    # Every public method returns the node it was given.
    class DefinitionInferer
      def initialize(infer:, register:, collection_type_resolver:)
        @infer = infer
        @register = register
        @collection_type_resolver = collection_type_resolver
      end

      # An explicitly declared type wins over the initializer's type.
      def infer_var_decl(node)
        infer_initializer(node)
        node.type = node.initializer&.type unless node.type
        register_if_typed(node)
      end

      # An assignment expression takes the type of its value.
      def infer_assignment(node)
        %i[target value].each do |part|
          @infer.call(node.public_send(part))
        end
        node.type = node.value.type
        node
      end

      # An untyped global asks the collection resolver for its type.
      def infer_global_decl(node)
        infer_initializer(node)
        node.type = @collection_type_resolver.resolve_global_decl(node) unless node.type
        register_if_typed(node)
      end

      # Target typing is delegated; the statement itself has no type.
      def infer_multiple_assignment(node)
        @infer.call(node.value)
        @collection_type_resolver.assign_multiple_targets(node)
        node.type = nil
        node
      end

      private

      def infer_initializer(node)
        @infer.call(node.initializer) if node.initializer
      end

      # Untyped declarations are not registered.
      def register_if_typed(node)
        @register.call(node.name, node.type) if node.type
        node
      end
    end
  end
end