steep 0.47.1 → 0.49.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +24 -0
  3. data/Gemfile.lock +7 -7
  4. data/lib/steep/ast/types/any.rb +2 -0
  5. data/lib/steep/ast/types/boolean.rb +2 -0
  6. data/lib/steep/ast/types/bot.rb +2 -0
  7. data/lib/steep/ast/types/class.rb +2 -0
  8. data/lib/steep/ast/types/factory.rb +162 -138
  9. data/lib/steep/ast/types/helper.rb +8 -0
  10. data/lib/steep/ast/types/instance.rb +2 -0
  11. data/lib/steep/ast/types/intersection.rb +4 -0
  12. data/lib/steep/ast/types/literal.rb +2 -0
  13. data/lib/steep/ast/types/logic.rb +2 -0
  14. data/lib/steep/ast/types/name.rb +6 -0
  15. data/lib/steep/ast/types/nil.rb +2 -0
  16. data/lib/steep/ast/types/proc.rb +9 -0
  17. data/lib/steep/ast/types/record.rb +4 -0
  18. data/lib/steep/ast/types/self.rb +2 -0
  19. data/lib/steep/ast/types/top.rb +2 -0
  20. data/lib/steep/ast/types/tuple.rb +4 -0
  21. data/lib/steep/ast/types/union.rb +4 -0
  22. data/lib/steep/ast/types/var.rb +16 -3
  23. data/lib/steep/ast/types/void.rb +2 -0
  24. data/lib/steep/diagnostic/ruby.rb +36 -11
  25. data/lib/steep/diagnostic/signature.rb +56 -0
  26. data/lib/steep/drivers/check.rb +2 -2
  27. data/lib/steep/interface/function.rb +4 -0
  28. data/lib/steep/interface/method_type.rb +14 -26
  29. data/lib/steep/interface/type_param.rb +103 -0
  30. data/lib/steep/server/base_worker.rb +1 -0
  31. data/lib/steep/server/interaction_worker.rb +1 -1
  32. data/lib/steep/server/type_check_worker.rb +2 -2
  33. data/lib/steep/services/signature_service.rb +2 -2
  34. data/lib/steep/services/type_check_service.rb +4 -1
  35. data/lib/steep/signature/validator.rb +228 -49
  36. data/lib/steep/subtyping/cache.rb +30 -0
  37. data/lib/steep/subtyping/check.rb +600 -705
  38. data/lib/steep/subtyping/constraints.rb +66 -30
  39. data/lib/steep/subtyping/relation.rb +60 -0
  40. data/lib/steep/subtyping/result.rb +190 -16
  41. data/lib/steep/type_construction.rb +576 -395
  42. data/lib/steep/type_inference/block_params.rb +31 -3
  43. data/lib/steep/type_inference/context.rb +37 -3
  44. data/lib/steep/version.rb +1 -1
  45. data/lib/steep.rb +3 -3
  46. data/sample/lib/length.rb +35 -0
  47. data/sample/sig/length.rbs +34 -0
  48. data/smoke/diagnostics/proc_type_expected.rb +3 -0
  49. data/smoke/diagnostics/test_expectations.yml +12 -0
  50. data/smoke/diagnostics-rbs/nonregular-type-alias.rbs +3 -0
  51. data/smoke/diagnostics-rbs/recursive-type-alias.rbs +3 -0
  52. data/smoke/diagnostics-rbs/test_expectations.yml +57 -12
  53. data/smoke/tsort/a.rb +1 -1
  54. data/smoke/tsort/test_expectations.yml +1 -63
  55. data/steep.gemspec +1 -1
  56. metadata +14 -10
  57. data/lib/steep/drivers/trace_printer.rb +0 -29
  58. data/lib/steep/interface/method.rb +0 -78
  59. data/lib/steep/subtyping/trace.rb +0 -71
@@ -63,6 +63,10 @@ module Steep
63
63
  context.type_env
64
64
  end
65
65
 
66
+ def variable_context
67
+ context.variable_context
68
+ end
69
+
66
70
  def initialize(checker:, source:, annotations:, typing:, context:)
67
71
  @checker = checker
68
72
  @source = source
@@ -114,13 +118,15 @@ module Steep
114
118
  def check_relation(sub_type:, super_type:, constraints: Subtyping::Constraints.empty)
115
119
  Steep.logger.debug { "check_relation: self:#{self_type}, instance:#{module_context.instance_type}, class:#{module_context.module_type} |- #{sub_type} <: #{super_type}" }
116
120
  relation = Subtyping::Relation.new(sub_type: sub_type, super_type: super_type)
117
- checker.check(
118
- relation,
119
- self_type: self_type,
120
- instance_type: module_context.instance_type,
121
- class_type: module_context.module_type,
122
- constraints: constraints
123
- )
121
+ checker.push_variable_bounds(variable_context.upper_bounds) do
122
+ checker.check(
123
+ relation,
124
+ self_type: self_type,
125
+ instance_type: module_context.instance_type,
126
+ class_type: module_context.module_type,
127
+ constraints: constraints
128
+ )
129
+ end
124
130
  end
125
131
 
126
132
  def for_new_method(method_name, node, args:, self_type:, definition:)
@@ -169,6 +175,12 @@ module Steep
169
175
  end
170
176
  end
171
177
 
178
+ if definition && method_type
179
+ variable_context = TypeInference::Context::TypeVariableContext.new(method_type.type_params, parent_context: self.variable_context)
180
+ else
181
+ variable_context = self.variable_context
182
+ end
183
+
172
184
  method_context = TypeInference::Context::MethodContext.new(
173
185
  name: method_name,
174
186
  method: definition && definition.methods[method_name],
@@ -245,12 +257,20 @@ module Steep
245
257
  self_type: annots.self_type || self_type,
246
258
  type_env: type_env,
247
259
  lvar_env: lvar_env,
248
- call_context: call_context
260
+ call_context: call_context,
261
+ variable_context: variable_context
249
262
  ),
250
263
  typing: typing,
251
264
  )
252
265
  end
253
266
 
267
+ def with_method_constr(method_name, node, args:, self_type:, definition:)
268
+ constr = for_new_method(method_name, node, args: args, self_type: self_type, definition: definition)
269
+ constr.checker.push_variable_bounds(constr.variable_context.upper_bounds) do
270
+ yield constr
271
+ end
272
+ end
273
+
254
274
  def implement_module(module_name:, super_name: nil, annotations:)
255
275
  if (annotation = annotations.implement_module_annotation)
256
276
  absolute_name(annotation.name.name).yield_self do |absolute_name|
@@ -332,7 +352,10 @@ module Steep
332
352
  module_context = default_module_context(implement_module_name, const_env: module_const_env, current_namespace: new_namespace)
333
353
 
334
354
  unless implement_module_name
335
- module_context = module_context.update(module_type: AST::Builtin::Module.instance_type)
355
+ module_context = module_context.update(
356
+ module_type: AST::Builtin::Module.instance_type,
357
+ instance_type: AST::Builtin::BasicObject.instance_type
358
+ )
336
359
  end
337
360
 
338
361
  if implement_module_name
@@ -389,6 +412,21 @@ module Steep
389
412
  class_type: module_context.module_type
390
413
  ).annotate(annots)
391
414
 
415
+ if implement_module_name
416
+ definition = checker.factory.definition_builder.build_instance(implement_module_name.name)
417
+ type_params = definition.type_params_decl.map do |param|
418
+ Interface::TypeParam.new(
419
+ name: param.name,
420
+ upper_bound: checker.factory.type_opt(param.upper_bound),
421
+ variance: param.variance,
422
+ unchecked: param.unchecked?
423
+ )
424
+ end
425
+ variable_context = TypeInference::Context::TypeVariableContext.new(type_params)
426
+ else
427
+ variable_context = TypeInference::Context::TypeVariableContext.empty
428
+ end
429
+
392
430
  self.class.new(
393
431
  checker: checker,
394
432
  source: source,
@@ -402,11 +440,19 @@ module Steep
402
440
  self_type: module_context.module_type,
403
441
  type_env: module_type_env,
404
442
  lvar_env: lvar_env,
405
- call_context: TypeInference::MethodCall::ModuleContext.new(type_name: module_context.class_name)
443
+ call_context: TypeInference::MethodCall::ModuleContext.new(type_name: module_context.class_name),
444
+ variable_context: variable_context
406
445
  )
407
446
  )
408
447
  end
409
448
 
449
+ def with_module_constr(node)
450
+ constr = for_module(node)
451
+ constr.checker.push_variable_bounds(constr.variable_context.upper_bounds) do
452
+ yield constr
453
+ end
454
+ end
455
+
410
456
  def for_class(node)
411
457
  new_class_name = module_name_from_node(node.children.first) or raise "Unexpected class name: #{node.children.first}"
412
458
  super_class_name = node.children[1] && module_name_from_node(node.children[1])
@@ -422,10 +468,7 @@ module Steep
422
468
 
423
469
  if implement_module_name
424
470
  if super_class_name && implement_module_name.name == absolute_name(super_class_name)
425
- module_context = module_context.update(
426
- instance_definition: nil,
427
- module_definition: nil
428
- )
471
+ module_context = module_context.update(instance_definition: nil, module_definition: nil)
429
472
  end
430
473
  else
431
474
  module_context = module_context.update(
@@ -446,6 +489,18 @@ module Steep
446
489
  module_context = module_context.update(module_type: annots.self_type)
447
490
  end
448
491
 
492
+ definition = checker.factory.definition_builder.build_instance(module_context.class_name)
493
+ type_params = definition.type_params_decl.map do |type_param|
494
+ Interface::TypeParam.new(
495
+ name: type_param.name,
496
+ upper_bound: type_param.upper_bound&.yield_self {|t| checker.factory.type(t) },
497
+ variance: type_param.variance,
498
+ unchecked: type_param.unchecked?,
499
+ location: type_param.location
500
+ )
501
+ end
502
+ variable_context = TypeInference::Context::TypeVariableContext.new(type_params)
503
+
449
504
  class_type_env = TypeInference::TypeEnv.build(annotations: annots,
450
505
  subtyping: checker,
451
506
  const_env: class_const_env,
@@ -466,7 +521,8 @@ module Steep
466
521
  self_type: module_context.module_type,
467
522
  type_env: class_type_env,
468
523
  lvar_env: lvar_env,
469
- call_context: TypeInference::MethodCall::ModuleContext.new(type_name: module_context.class_name)
524
+ call_context: TypeInference::MethodCall::ModuleContext.new(type_name: module_context.class_name),
525
+ variable_context: variable_context
470
526
  )
471
527
 
472
528
  self.class.new(
@@ -478,6 +534,24 @@ module Steep
478
534
  )
479
535
  end
480
536
 
537
+ def with_class_constr(node)
538
+ constr = for_class(node)
539
+
540
+ constr.checker.push_variable_bounds(constr.variable_context.upper_bounds) do
541
+ yield constr
542
+ end
543
+ end
544
+
545
+ def with_sclass_constr(node, type)
546
+ if constr = for_sclass(node, type)
547
+ constr.checker.push_variable_bounds(constr.variable_context.upper_bounds) do
548
+ yield constr
549
+ end
550
+ else
551
+ yield nil
552
+ end
553
+ end
554
+
481
555
  def for_sclass(node, type)
482
556
  annots = source.annotations(block: node, factory: checker.factory, current_module: current_namespace)
483
557
 
@@ -554,7 +628,8 @@ module Steep
554
628
  self_type: module_context.module_type,
555
629
  type_env: type_env,
556
630
  lvar_env: lvar_env,
557
- call_context: TypeInference::MethodCall::ModuleContext.new(type_name: module_context.class_name)
631
+ call_context: TypeInference::MethodCall::ModuleContext.new(type_name: module_context.class_name),
632
+ variable_context: TypeInference::Context::TypeVariableContext.empty # Assuming `::Class` and `::Module` don't have type params
558
633
  )
559
634
 
560
635
  self.class.new(
@@ -930,7 +1005,7 @@ module Steep
930
1005
  yield_self do
931
1006
  send_node, params, body = node.children
932
1007
  if send_node.type == :lambda
933
- type_lambda(node, block_params: params, block_body: body, type_hint: hint)
1008
+ type_lambda(node, params_node: params, body_node: body, type_hint: hint)
934
1009
  else
935
1010
  type_send(node, send_node: send_node, block_params: params, block_body: body, unwrap: send_node.type == :csend)
936
1011
  end
@@ -949,7 +1024,7 @@ module Steep
949
1024
  params = Parser::AST::Node.new(:args, arg_nodes)
950
1025
 
951
1026
  if send_node.type == :lambda
952
- type_lambda(node, block_params: params, block_body: body, type_hint: hint)
1027
+ type_lambda(node, params_node: params, body_node: body, type_hint: hint)
953
1028
  else
954
1029
  type_send(node, send_node: send_node, block_params: params, block_body: body, unwrap: send_node.type == :csend)
955
1030
  end
@@ -959,75 +1034,76 @@ module Steep
959
1034
  yield_self do
960
1035
  name, args_node, body_node = node.children
961
1036
 
962
- new = for_new_method(
1037
+ with_method_constr(
963
1038
  name,
964
1039
  node,
965
1040
  args: args_node.children,
966
1041
  self_type: module_context&.instance_type,
967
1042
  definition: module_context&.instance_definition
968
- )
969
- new.typing.add_context_for_node(node, context: new.context)
970
- new.typing.add_context_for_body(node, context: new.context)
971
-
972
- new.method_context.tap do |method_context|
973
- if method_context.method
974
- method_name = InstanceMethodName.new(type_name: method_context.method.implemented_in, method_name: name)
975
- new.typing.source_index.add_definition(method: method_name, definition: node)
1043
+ ) do |new|
1044
+ new.typing.add_context_for_node(node, context: new.context)
1045
+ new.typing.add_context_for_body(node, context: new.context)
1046
+
1047
+ new.method_context.tap do |method_context|
1048
+ if method_context.method
1049
+ method_name = InstanceMethodName.new(type_name: method_context.method.implemented_in, method_name: name)
1050
+ new.typing.source_index.add_definition(method: method_name, definition: node)
1051
+ end
976
1052
  end
977
- end
978
-
979
- new = new.synthesize_children(args_node)
980
1053
 
981
- body_pair = if body_node
982
- return_type = expand_alias(new.method_context&.return_type)
983
- if return_type && !return_type.is_a?(AST::Types::Void)
984
- new.check(body_node, return_type) do |_, actual_type, result|
985
- typing.add_error(
986
- Diagnostic::Ruby::MethodBodyTypeMismatch.new(
987
- node: node,
988
- expected: new.method_context&.return_type,
989
- actual: actual_type,
990
- result: result
1054
+ new = new.synthesize_children(args_node)
1055
+
1056
+ body_pair = if body_node
1057
+ return_type = expand_alias(new.method_context&.return_type)
1058
+ if return_type && !return_type.is_a?(AST::Types::Void)
1059
+ new.check(body_node, return_type) do |_, actual_type, result|
1060
+ typing.add_error(
1061
+ Diagnostic::Ruby::MethodBodyTypeMismatch.new(
1062
+ node: node,
1063
+ expected: new.method_context&.return_type,
1064
+ actual: actual_type,
1065
+ result: result
1066
+ )
991
1067
  )
992
- )
1068
+ end
1069
+ else
1070
+ new.synthesize(body_node)
993
1071
  end
994
1072
  else
995
- new.synthesize(body_node)
996
- end
997
- else
998
- return_type = expand_alias(new.method_context&.return_type)
999
- if return_type && !return_type.is_a?(AST::Types::Void)
1000
- result = check_relation(sub_type: AST::Builtin.nil_type, super_type: return_type)
1001
- if result.failure?
1002
- typing.add_error(
1003
- Diagnostic::Ruby::MethodBodyTypeMismatch.new(
1004
- node: node,
1005
- expected: new.method_context&.return_type,
1006
- actual: AST::Builtin.nil_type,
1007
- result: result
1073
+ return_type = expand_alias(new.method_context&.return_type)
1074
+ if return_type && !return_type.is_a?(AST::Types::Void)
1075
+ result = check_relation(sub_type: AST::Builtin.nil_type, super_type: return_type)
1076
+ if result.failure?
1077
+ typing.add_error(
1078
+ Diagnostic::Ruby::MethodBodyTypeMismatch.new(
1079
+ node: node,
1080
+ expected: new.method_context&.return_type,
1081
+ actual: AST::Builtin.nil_type,
1082
+ result: result
1083
+ )
1008
1084
  )
1009
- )
1085
+ end
1010
1086
  end
1087
+
1088
+ Pair.new(type: AST::Builtin.nil_type, constr: new)
1011
1089
  end
1012
1090
 
1013
- Pair.new(type: AST::Builtin.nil_type, constr: new)
1014
- end
1091
+ if body_node
1092
+ # Add context to ranges from the end of the method body to the beginning of the `end` keyword
1093
+ if node.loc.end
1094
+ # Skip end-less def
1095
+ begin_pos = body_node.loc.expression.end_pos
1096
+ end_pos = node.loc.end.begin_pos
1097
+ typing.add_context(begin_pos..end_pos, context: body_pair.context)
1098
+ end
1099
+ end
1015
1100
 
1016
- if body_node
1017
- # Add context to ranges from the end of the method body to the beginning of the `end` keyword
1018
- if node.loc.end
1019
- # Skip end-less def
1020
- begin_pos = body_node.loc.expression.end_pos
1021
- end_pos = node.loc.end.begin_pos
1022
- typing.add_context(begin_pos..end_pos, context: body_pair.context)
1101
+ if module_context
1102
+ module_context.defined_instance_methods << node.children[0]
1023
1103
  end
1024
- end
1025
1104
 
1026
- if module_context
1027
- module_context.defined_instance_methods << node.children[0]
1105
+ add_typing(node, type: AST::Builtin::Symbol.instance_type)
1028
1106
  end
1029
-
1030
- add_typing(node, type: AST::Builtin::Symbol.instance_type)
1031
1107
  end
1032
1108
 
1033
1109
  when :defs
@@ -1105,16 +1181,18 @@ module Steep
1105
1181
 
1106
1182
  when :return
1107
1183
  yield_self do
1108
- if node.children.size > 0
1109
- method_return_type = expand_alias(method_context&.return_type)
1184
+ method_return_type = expand_alias(method_context&.return_type)
1110
1185
 
1186
+ if node.children.size > 0
1111
1187
  return_types = node.children.map do |value|
1112
- synthesize(value,
1113
- hint: if method_return_type.is_a?(AST::Types::Void)
1114
- nil
1115
- else
1116
- method_return_type
1117
- end).type
1188
+ synthesize(
1189
+ value,
1190
+ hint: if method_return_type.is_a?(AST::Types::Void)
1191
+ nil
1192
+ else
1193
+ method_return_type
1194
+ end
1195
+ ).type
1118
1196
  end
1119
1197
 
1120
1198
  value_type = if return_types.size == 1
@@ -1123,20 +1201,23 @@ module Steep
1123
1201
  AST::Builtin::Array.instance_type(union_type(*return_types))
1124
1202
  end
1125
1203
 
1126
- if method_return_type
1127
- unless method_return_type.is_a?(AST::Types::Void)
1128
- result = check_relation(sub_type: value_type, super_type: method_return_type)
1204
+ else
1205
+ value_type = AST::Builtin.nil_type
1206
+ end
1207
+
1208
+ if method_return_type
1209
+ unless method_return_type.is_a?(AST::Types::Void)
1210
+ result = check_relation(sub_type: value_type, super_type: method_return_type)
1129
1211
 
1130
- if result.failure?
1131
- typing.add_error(
1132
- Diagnostic::Ruby::ReturnTypeMismatch.new(
1133
- node: node,
1134
- expected: method_context&.return_type,
1135
- actual: value_type,
1136
- result: result
1137
- )
1212
+ if result.failure?
1213
+ typing.add_error(
1214
+ Diagnostic::Ruby::ReturnTypeMismatch.new(
1215
+ node: node,
1216
+ expected: method_context&.return_type,
1217
+ actual: value_type,
1218
+ result: result
1138
1219
  )
1139
- end
1220
+ )
1140
1221
  end
1141
1222
  end
1142
1223
  end
@@ -1427,7 +1508,7 @@ module Steep
1427
1508
  end
1428
1509
  _, constr = constr.synthesize(sup) if sup
1429
1510
 
1430
- constr.for_class(node).tap do |constructor|
1511
+ with_class_constr(node) do |constructor|
1431
1512
  if module_type = constructor.module_context&.module_type
1432
1513
  _, constructor = constructor.add_typing(name, type: module_type)
1433
1514
  else
@@ -1467,7 +1548,7 @@ module Steep
1467
1548
  _, constr = constr.synthesize(name)
1468
1549
  end
1469
1550
 
1470
- for_module(node).yield_self do |constructor|
1551
+ with_module_constr(node) do |constructor|
1471
1552
  if module_type = constructor.module_context&.module_type
1472
1553
  _, constructor = constructor.add_typing(name, type: module_type)
1473
1554
  else
@@ -1495,27 +1576,28 @@ module Steep
1495
1576
  when :sclass
1496
1577
  yield_self do
1497
1578
  type, constr = synthesize(node.children[0])
1498
- constructor = constr.for_sclass(node, type)
1499
1579
 
1500
- unless constructor
1501
- typing.add_error(
1502
- Diagnostic::Ruby::UnsupportedSyntax.new(
1503
- node: node,
1504
- message: "sclass receiver must be instance type or singleton type, but type given `#{type}`"
1580
+ with_sclass_constr(node, type) do |constructor|
1581
+ unless constructor
1582
+ typing.add_error(
1583
+ Diagnostic::Ruby::UnsupportedSyntax.new(
1584
+ node: node,
1585
+ message: "sclass receiver must be instance type or singleton type, but type given `#{type}`"
1586
+ )
1505
1587
  )
1506
- )
1507
- constr.add_typing(node, type: AST::Builtin.nil_type)
1508
- return
1509
- end
1588
+ constr.add_typing(node, type: AST::Builtin.nil_type)
1589
+ return
1590
+ end
1510
1591
 
1511
- constructor.typing.add_context_for_node(node, context: constructor.context)
1512
- constructor.typing.add_context_for_body(node, context: constructor.context)
1592
+ constructor.typing.add_context_for_node(node, context: constructor.context)
1593
+ constructor.typing.add_context_for_body(node, context: constructor.context)
1513
1594
 
1514
- constructor.synthesize(node.children[1]) if node.children[1]
1595
+ constructor.synthesize(node.children[1]) if node.children[1]
1515
1596
 
1516
- if constructor.module_context.instance_definition && module_context.module_definition
1517
- if constructor.module_context.instance_definition.type_name == module_context.module_definition.type_name
1518
- module_context.defined_module_methods.merge(constructor.module_context.defined_instance_methods)
1597
+ if constructor.module_context.instance_definition && module_context.module_definition
1598
+ if constructor.module_context.instance_definition.type_name == module_context.module_definition.type_name
1599
+ module_context.defined_module_methods.merge(constructor.module_context.defined_instance_methods)
1600
+ end
1519
1601
  end
1520
1602
  end
1521
1603
 
@@ -1573,8 +1655,7 @@ module Steep
1573
1655
  instance_type: module_context.instance_type,
1574
1656
  class_type: module_context.module_type
1575
1657
  ) do |error|
1576
- case error
1577
- when Subtyping::Result::Failure
1658
+ if error
1578
1659
  const_type = type_env.get(const: const_name)
1579
1660
  typing.add_error(
1580
1661
  Diagnostic::Ruby::IncompatibleAssignment.new(
@@ -1584,7 +1665,7 @@ module Steep
1584
1665
  result: error
1585
1666
  )
1586
1667
  )
1587
- when nil
1668
+ else
1588
1669
  typing.add_error(
1589
1670
  Diagnostic::Ruby::UnknownConstantAssigned.new(
1590
1671
  node: node,
@@ -2385,7 +2466,7 @@ module Steep
2385
2466
  raise "#synthesize should return an instance of Pair: #{pair.class}, node=#{node.inspect}"
2386
2467
  end
2387
2468
  end
2388
- rescue RBS::ErrorBase => exn
2469
+ rescue RBS::BaseError => exn
2389
2470
  Steep.logger.warn { "Unexpected RBS error: #{exn.message}" }
2390
2471
  exn.backtrace.each {|loc| Steep.logger.warn " #{loc}" }
2391
2472
  typing.add_error(Diagnostic::Ruby::UnexpectedError.new(node: node, error: exn))
@@ -2411,6 +2492,7 @@ module Steep
2411
2492
 
2412
2493
  def type_ivasgn(name, rhs, node)
2413
2494
  rhs_type = synthesize(rhs, hint: type_env.get(ivar: name) { fallback_to_any(node) }).type
2495
+
2414
2496
  ivar_type = type_env.assign(
2415
2497
  ivar: name,
2416
2498
  type: rhs_type,
@@ -2418,8 +2500,7 @@ module Steep
2418
2500
  instance_type: module_context.instance_type,
2419
2501
  class_type: module_context.module_type
2420
2502
  ) do |error|
2421
- case error
2422
- when Subtyping::Result::Failure
2503
+ if error
2423
2504
  type = type_env.get(ivar: name)
2424
2505
  typing.add_error(
2425
2506
  Diagnostic::Ruby::IncompatibleAssignment.new(
@@ -2429,7 +2510,7 @@ module Steep
2429
2510
  result: error
2430
2511
  )
2431
2512
  )
2432
- when nil
2513
+ else
2433
2514
  fallback_to_any node
2434
2515
  end
2435
2516
  end
@@ -2473,8 +2554,7 @@ module Steep
2473
2554
  instance_type: module_context.instance_type,
2474
2555
  class_type: module_context.module_type
2475
2556
  ) do |error|
2476
- case error
2477
- when Subtyping::Result::Failure
2557
+ if error
2478
2558
  var_type = type_env.get(ivar: ivar)
2479
2559
  typing.add_error(
2480
2560
  Diagnostic::Ruby::IncompatibleAssignment.new(
@@ -2484,7 +2564,7 @@ module Steep
2484
2564
  result: error
2485
2565
  )
2486
2566
  )
2487
- when nil
2567
+ else
2488
2568
  fallback_to_any node
2489
2569
  end
2490
2570
  end
@@ -2647,34 +2727,89 @@ module Steep
2647
2727
  end
2648
2728
  end
2649
2729
 
2650
- def type_lambda(node, block_params:, block_body:, type_hint:)
2730
+ def optional_proc?(type)
2731
+ if type.is_a?(AST::Types::Union)
2732
+ if type.types.size == 2
2733
+ if type.types.find {|t| t.is_a?(AST::Types::Nil) }
2734
+ if proc_type = type.types.find {|t| t.is_a?(AST::Types::Proc) }
2735
+ proc_type
2736
+ end
2737
+ end
2738
+ end
2739
+ end
2740
+ end
2741
+
2742
+ def type_lambda(node, params_node:, body_node:, type_hint:)
2651
2743
  block_annotations = source.annotations(block: node, factory: checker.factory, current_module: current_namespace)
2652
- params = TypeInference::BlockParams.from_node(block_params, annotations: block_annotations)
2744
+ params = TypeInference::BlockParams.from_node(params_node, annotations: block_annotations)
2745
+
2746
+ type_hint = deep_expand_alias(type_hint) if type_hint
2653
2747
 
2654
2748
  case type_hint
2655
2749
  when AST::Types::Proc
2656
2750
  params_hint = type_hint.type.params
2657
2751
  return_hint = type_hint.type.return_type
2752
+ block_hint = type_hint.block
2658
2753
  end
2659
2754
 
2660
2755
  block_constr = for_block(
2661
2756
  block_params: params,
2662
2757
  block_param_hint: params_hint,
2663
2758
  block_type_hint: return_hint,
2759
+ block_block_hint: block_hint,
2664
2760
  block_annotations: block_annotations,
2665
2761
  node_type_hint: nil
2666
2762
  )
2667
2763
 
2668
2764
  block_constr.typing.add_context_for_body(node, context: block_constr.context)
2669
2765
 
2766
+ default_proc_function =
2767
+ Interface::Function.new(
2768
+ params: Interface::Function::Params.empty,
2769
+ return_type: AST::Builtin.any_type,
2770
+ location: nil
2771
+ )
2772
+
2670
2773
  params.params.each do |param|
2671
2774
  _, block_constr = block_constr.synthesize(param.node, hint: param.type)
2672
2775
  end
2673
2776
 
2674
- if block_body
2777
+ block =
2778
+ if block_param = params.block_param
2779
+ if block_param_type = block_param.type
2780
+ case block_param_type
2781
+ when AST::Types::Proc
2782
+ Interface::Block.new(type: block_param_type.type, optional: false)
2783
+ else
2784
+ if proc_type = optional_proc?(block_param_type)
2785
+ Interface::Block.new(type: proc_type.type, optional: true)
2786
+ else
2787
+ block_constr.typing.add_error(
2788
+ Diagnostic::Ruby::ProcTypeExpected.new(
2789
+ node: block_param.node,
2790
+ type: block_param_type
2791
+ )
2792
+ )
2793
+
2794
+ Interface::Block.new(
2795
+ type: Interface::Function.new(
2796
+ params: Interface::Function::Params.empty,
2797
+ return_type: AST::Builtin.any_type,
2798
+ location: nil
2799
+ ),
2800
+ optional: false
2801
+ )
2802
+ end
2803
+ end
2804
+ else
2805
+ block_hint
2806
+ end
2807
+ end
2808
+
2809
+ if body_node
2675
2810
  return_type = block_constr.synthesize_block(
2676
2811
  node: node,
2677
- block_body: block_body,
2812
+ block_body: body_node,
2678
2813
  block_type_hint: return_hint
2679
2814
  )
2680
2815
 
@@ -2702,7 +2837,7 @@ module Steep
2702
2837
  return_type: return_type,
2703
2838
  location: nil
2704
2839
  ),
2705
- block: nil
2840
+ block: block
2706
2841
  )
2707
2842
 
2708
2843
  add_typing node, type: block_type
@@ -2830,6 +2965,32 @@ module Steep
2830
2965
  )
2831
2966
  )
2832
2967
 
2968
+ when AST::Types::Var
2969
+ if upper_bound = variable_context[receiver_type.name]
2970
+ interface = calculate_interface(upper_bound, private: false)
2971
+
2972
+ constr.type_send_interface(
2973
+ node,
2974
+ interface: interface,
2975
+ receiver: receiver,
2976
+ receiver_type: receiver_type,
2977
+ method_name: method_name,
2978
+ arguments: arguments,
2979
+ block_params: block_params,
2980
+ block_body: block_body
2981
+ )
2982
+ else
2983
+ constr = constr.synthesize_children(node, skips: [receiver])
2984
+ constr.add_call(
2985
+ TypeInference::MethodCall::NoMethodError.new(
2986
+ node: node,
2987
+ context: context.method_context,
2988
+ method_name: method_name,
2989
+ receiver_type: receiver_type,
2990
+ error: Diagnostic::Ruby::NoMethod.new(node: node, method: method_name, type: receiver_type)
2991
+ )
2992
+ )
2993
+ end
2833
2994
  when AST::Types::Void, AST::Types::Bot, AST::Types::Top, AST::Types::Var
2834
2995
  constr = constr.synthesize_children(node, skips: [receiver])
2835
2996
  constr.add_call(
@@ -2872,6 +3033,7 @@ module Steep
2872
3033
  block_params: block_params,
2873
3034
  block_body: block_body)
2874
3035
  end
3036
+
2875
3037
  else
2876
3038
  interface = calculate_interface(receiver_type, private: private, self_type: receiver_type)
2877
3039
 
@@ -2991,119 +3153,166 @@ module Steep
2991
3153
  end
2992
3154
  end
2993
3155
 
3156
+ def apply_solution(errors, node:, method_type:)
3157
+ subst = yield
3158
+
3159
+ [
3160
+ method_type.subst(subst),
3161
+ true,
3162
+ subst
3163
+ ]
3164
+
3165
+ rescue Subtyping::Constraints::UnsatisfiableConstraint => exn
3166
+ errors << Diagnostic::Ruby::UnsatisfiableConstraint.new(
3167
+ node: node,
3168
+ method_type: method_type,
3169
+ var: exn.var,
3170
+ sub_type: exn.sub_type,
3171
+ super_type: exn.super_type,
3172
+ result: exn.result
3173
+ )
3174
+ [method_type, false, Interface::Substitution.empty]
3175
+ end
3176
+
3177
+ def eliminate_vars(type, variables, to: AST::Builtin.any_type)
3178
+ if variables.empty?
3179
+ type
3180
+ else
3181
+ subst = Interface::Substitution.build(variables, Array.new(variables.size, to))
3182
+ type.subst(subst)
3183
+ end
3184
+ end
3185
+
2994
3186
  def try_method_type(node, receiver_type:, method_name:, method_type:, arguments:, block_params:, block_body:, topdown_hint:)
2995
- fresh_types = method_type.type_params.map {|x| AST::Types::Var.fresh(x)}
2996
- fresh_vars = Set.new(fresh_types.map(&:name))
2997
- instantiation = Interface::Substitution.build(method_type.type_params, fresh_types)
3187
+ type_params, instantiation = Interface::TypeParam.rename(method_type.type_params)
3188
+ type_param_names = type_params.map(&:name)
2998
3189
 
2999
3190
  constr = self
3000
3191
 
3001
3192
  method_type = method_type.instantiate(instantiation)
3002
3193
 
3003
- constraints = Subtyping::Constraints.new(unknowns: fresh_types.map(&:name))
3004
3194
  variance = Subtyping::VariableVariance.from_method_type(method_type)
3005
3195
  occurence = Subtyping::VariableOccurence.from_method_type(method_type)
3196
+ constraints = Subtyping::Constraints.new(unknowns: type_params.map(&:name))
3197
+ ccontext = Subtyping::Constraints::Context.new(
3198
+ self_type: self_type,
3199
+ instance_type: module_context.instance_type,
3200
+ class_type: module_context.module_type,
3201
+ variance: variance
3202
+ )
3006
3203
 
3007
- errors = []
3204
+ upper_bounds = {}
3008
3205
 
3009
- args = TypeInference::SendArgs.new(node: node, arguments: arguments, method_name: method_name, method_type: method_type)
3010
- es = args.each do |arg|
3011
- case arg
3012
- when TypeInference::SendArgs::PositionalArgs::NodeParamPair
3013
- _, constr = constr.type_check_argument(
3014
- arg.node,
3015
- type: arg.param.type,
3016
- receiver_type: receiver_type,
3017
- constraints: constraints,
3018
- errors: errors
3019
- )
3206
+ type_params.each do |param|
3207
+ if ub = param.upper_bound
3208
+ constraints.add(param.name, super_type: ub, skip: true)
3209
+ upper_bounds[param.name] = ub
3210
+ end
3211
+ end
3020
3212
 
3021
- when TypeInference::SendArgs::PositionalArgs::NodeTypePair
3022
- _, constr = bypass_splat(arg.node) do |n|
3023
- constr.type_check_argument(
3024
- n,
3025
- type: arg.node_type,
3213
+ checker.push_variable_bounds(upper_bounds) do
3214
+ errors = []
3215
+
3216
+ args = TypeInference::SendArgs.new(node: node, arguments: arguments, method_name: method_name, method_type: method_type)
3217
+ es = args.each do |arg|
3218
+ case arg
3219
+ when TypeInference::SendArgs::PositionalArgs::NodeParamPair
3220
+ _, constr = constr.type_check_argument(
3221
+ arg.node,
3222
+ type: arg.param.type,
3026
3223
  receiver_type: receiver_type,
3027
3224
  constraints: constraints,
3028
- report_node: arg.node,
3029
3225
  errors: errors
3030
3226
  )
3031
- end
3032
3227
 
3033
- when TypeInference::SendArgs::PositionalArgs::UnexpectedArg
3034
- _, constr = bypass_splat(arg.node) do |n|
3035
- constr.synthesize(n)
3036
- end
3037
-
3038
- when TypeInference::SendArgs::PositionalArgs::SplatArg
3039
- arg_type, _ = constr
3040
- .with_child_typing(range: arg.node.loc.expression.begin_pos ... arg.node.loc.expression.end_pos)
3041
- .try_tuple_type!(arg.node.children[0])
3042
- arg.type = arg_type
3043
-
3044
- when TypeInference::SendArgs::PositionalArgs::MissingArg
3045
- # ignore
3046
-
3047
- when TypeInference::SendArgs::KeywordArgs::ArgTypePairs
3048
- arg.pairs.each do |node, type|
3049
- _, constr = bypass_splat(node) do |node|
3228
+ when TypeInference::SendArgs::PositionalArgs::NodeTypePair
3229
+ _, constr = bypass_splat(arg.node) do |n|
3050
3230
  constr.type_check_argument(
3051
- node,
3052
- type: type,
3231
+ n,
3232
+ type: arg.node_type,
3053
3233
  receiver_type: receiver_type,
3054
3234
  constraints: constraints,
3235
+ report_node: arg.node,
3055
3236
  errors: errors
3056
3237
  )
3057
3238
  end
3058
- end
3059
3239
 
3060
- when TypeInference::SendArgs::KeywordArgs::UnexpectedKeyword
3061
- if arg.node.type == :pair
3062
- arg.node.children.each do |nn|
3063
- _, constr = constr.synthesize(nn)
3064
- end
3065
- else
3240
+ when TypeInference::SendArgs::PositionalArgs::UnexpectedArg
3066
3241
  _, constr = bypass_splat(arg.node) do |n|
3067
3242
  constr.synthesize(n)
3068
3243
  end
3069
- end
3070
3244
 
3071
- when TypeInference::SendArgs::KeywordArgs::SplatArg
3072
- type, _ = bypass_splat(arg.node) do |sp_node|
3073
- if sp_node.type == :hash
3074
- pair = constr.type_hash_record(sp_node, nil) and break pair
3245
+ when TypeInference::SendArgs::PositionalArgs::SplatArg
3246
+ arg_type, _ = constr
3247
+ .with_child_typing(range: arg.node.loc.expression.begin_pos ... arg.node.loc.expression.end_pos)
3248
+ .try_tuple_type!(arg.node.children[0])
3249
+ arg.type = arg_type
3250
+
3251
+ when TypeInference::SendArgs::PositionalArgs::MissingArg
3252
+ # ignore
3253
+
3254
+ when TypeInference::SendArgs::KeywordArgs::ArgTypePairs
3255
+ arg.pairs.each do |node, type|
3256
+ _, constr = bypass_splat(node) do |node|
3257
+ constr.type_check_argument(
3258
+ node,
3259
+ type: type,
3260
+ receiver_type: receiver_type,
3261
+ constraints: constraints,
3262
+ errors: errors
3263
+ )
3264
+ end
3075
3265
  end
3076
3266
 
3077
- constr.synthesize(sp_node)
3078
- end
3267
+ when TypeInference::SendArgs::KeywordArgs::UnexpectedKeyword
3268
+ if arg.node.type == :pair
3269
+ arg.node.children.each do |nn|
3270
+ _, constr = constr.synthesize(nn)
3271
+ end
3272
+ else
3273
+ _, constr = bypass_splat(arg.node) do |n|
3274
+ constr.synthesize(n)
3275
+ end
3276
+ end
3277
+
3278
+ when TypeInference::SendArgs::KeywordArgs::SplatArg
3279
+ type, _ = bypass_splat(arg.node) do |sp_node|
3280
+ if sp_node.type == :hash
3281
+ pair = constr.type_hash_record(sp_node, nil) and break pair
3282
+ end
3079
3283
 
3080
- arg.type = type
3284
+ constr.synthesize(sp_node)
3285
+ end
3081
3286
 
3082
- when TypeInference::SendArgs::KeywordArgs::MissingKeyword
3083
- # ignore
3084
- else
3085
- raise arg.inspect
3086
- end
3287
+ arg.type = type
3087
3288
 
3088
- constr
3089
- end
3289
+ when TypeInference::SendArgs::KeywordArgs::MissingKeyword
3290
+ # ignore
3291
+ else
3292
+ raise arg.inspect
3293
+ end
3090
3294
 
3091
- errors.push(*es)
3295
+ constr
3296
+ end
3297
+
3298
+ errors.push(*es)
3092
3299
 
3093
- if block_params
3094
- # block is given
3095
- block_annotations = source.annotations(block: node, factory: checker.factory, current_module: current_namespace)
3096
- block_params_ = TypeInference::BlockParams.from_node(block_params, annotations: block_annotations)
3300
+ if block_params
3301
+ # block is given
3302
+ block_annotations = source.annotations(block: node, factory: checker.factory, current_module: current_namespace)
3303
+ block_params_ = TypeInference::BlockParams.from_node(block_params, annotations: block_annotations)
3097
3304
 
3098
- if method_type.block
3099
- pairs = method_type.block && block_params_&.zip(method_type.block.type.params)
3305
+ if method_type.block
3306
+ # Method accepts block
3307
+ pairs = method_type.block && block_params_&.zip(method_type.block.type.params, nil)
3100
3308
 
3101
- if pairs
3102
- begin
3309
+ if pairs
3310
+ # Block parameters are compatible with the block type
3103
3311
  block_constr = constr.for_block(
3104
3312
  block_params: block_params_,
3105
3313
  block_param_hint: method_type.block.type.params,
3106
3314
  block_type_hint: method_type.block.type.return_type,
3315
+ block_block_hint: nil,
3107
3316
  block_annotations: block_annotations,
3108
3317
  node_type_hint: method_type.type.return_type
3109
3318
  )
@@ -3131,236 +3340,206 @@ module Steep
3131
3340
  end
3132
3341
  end
3133
3342
 
3134
- s = constraints.solution(
3135
- checker,
3136
- self_type: self_type,
3137
- instance_type: module_context.instance_type,
3138
- class_type: module_context.module_type,
3139
- variance: variance,
3140
- variables: method_type.type.params.free_variables + method_type.block.type.params.free_variables
3141
- )
3142
- method_type = method_type.subst(s)
3143
- block_constr = block_constr.update_lvar_env {|env| env.subst(s) }
3144
- if block_body
3145
- block_body_type = block_constr.synthesize_block(
3146
- node: node,
3147
- block_body: block_body,
3148
- block_type_hint: method_type.block.type.return_type
3343
+ method_type, solved, s = apply_solution(errors, node: node, method_type: method_type) {
3344
+ constraints.solution(
3345
+ checker,
3346
+ variables: method_type.type.params.free_variables + method_type.block.type.params.free_variables,
3347
+ context: ccontext
3149
3348
  )
3150
- else
3151
- block_body_type = AST::Builtin.nil_type
3152
- end
3349
+ }
3153
3350
 
3154
- result = check_relation(sub_type: block_body_type,
3155
- super_type: method_type.block.type.return_type,
3156
- constraints: constraints)
3351
+ if solved
3352
+ # Ready for type check the body of the block
3353
+ block_constr = block_constr.update_lvar_env {|env| env.subst(s) }
3354
+ if block_body
3355
+ block_body_type = block_constr.synthesize_block(
3356
+ node: node,
3357
+ block_body: block_body,
3358
+ block_type_hint: method_type.block.type.return_type
3359
+ )
3360
+ else
3361
+ block_body_type = AST::Builtin.nil_type
3362
+ end
3157
3363
 
3158
- case result
3159
- when Subtyping::Result::Success
3160
- s = constraints.solution(
3161
- checker,
3162
- self_type: self_type,
3163
- instance_type: module_context.instance_type,
3164
- class_type: module_context.module_type,
3165
- variance: variance,
3166
- variables: fresh_vars
3167
- )
3168
- method_type = method_type.subst(s)
3364
+ result = check_relation(sub_type: block_body_type,
3365
+ super_type: method_type.block.type.return_type,
3366
+ constraints: constraints)
3169
3367
 
3170
- return_type = method_type.type.return_type
3171
- if break_type = block_annotations.break_type
3172
- return_type = union_type(break_type, return_type)
3368
+ if result.success?
3369
+ # Successfully type checked the body
3370
+ method_type, solved, _ = apply_solution(errors, node: node, method_type: method_type) do
3371
+ constraints.solution(checker, variables: type_param_names, context: ccontext)
3372
+ end
3373
+ method_type = eliminate_vars(method_type, type_param_names) unless solved
3374
+
3375
+ return_type = method_type.type.return_type
3376
+ if break_type = block_annotations.break_type
3377
+ return_type = union_type(break_type, return_type)
3378
+ end
3379
+ else
3380
+ # The block body has incompatible type
3381
+ errors << Diagnostic::Ruby::BlockBodyTypeMismatch.new(
3382
+ node: node,
3383
+ expected: method_type.block.type.return_type,
3384
+ actual: block_body_type,
3385
+ result: result
3386
+ )
3387
+
3388
+ method_type = eliminate_vars(method_type, type_param_names)
3389
+ return_type = method_type.type.return_type
3173
3390
  end
3174
3391
 
3175
- when Subtyping::Result::Failure
3176
- errors << Diagnostic::Ruby::BlockBodyTypeMismatch.new(
3177
- node: node,
3178
- expected: method_type.block.type.return_type,
3179
- actual: block_body_type,
3180
- result: result
3181
- )
3392
+ block_constr.typing.save!
3393
+ else
3394
+ # Failed to infer the type of block parameters
3395
+ constr.type_block_without_hint(node: node, block_annotations: block_annotations, block_params: block_params_, block_body: block_body) do |error|
3396
+ errors << error
3397
+ end
3182
3398
 
3399
+ method_type = eliminate_vars(method_type, type_param_names)
3183
3400
  return_type = method_type.type.return_type
3184
3401
  end
3185
-
3186
- block_constr.typing.save!
3187
-
3188
- rescue Subtyping::Constraints::UnsatisfiableConstraint => exn
3189
- errors << Diagnostic::Ruby::UnsatisfiableConstraint.new(
3190
- node: node,
3191
- method_type: method_type,
3192
- var: exn.var,
3193
- sub_type: exn.sub_type,
3194
- super_type: exn.super_type,
3195
- result: exn.result
3402
+ else
3403
+ # Block parameters are unsupported syntax
3404
+ errors << Diagnostic::Ruby::UnsupportedSyntax.new(
3405
+ node: block_params,
3406
+ message: "Unsupported block params pattern, probably masgn?"
3196
3407
  )
3197
3408
 
3198
- constr.type_block_without_hint(node: node, block_annotations: block_annotations, block_params: block_params_, block_body: block_body) do |error|
3199
- errors << error
3200
- end
3409
+ method_type, solved, _ = apply_solution(errors, node: node, method_type: method_type) {
3410
+ constraints.solution(checker, variables: type_param_names, context: ccontext)
3411
+ }
3412
+ method_type = eliminate_vars(method_type, type_param_names) unless solved
3201
3413
 
3202
- s = Interface::Substitution.build(method_type.free_variables,
3203
- Array.new(method_type.free_variables.size, AST::Builtin.any_type))
3204
- method_type = method_type.subst(s)
3414
+ return_type = method_type.type.return_type
3205
3415
  end
3206
3416
  else
3207
- errors << Diagnostic::Ruby::UnsupportedSyntax.new(
3208
- node: block_params,
3209
- message: "Unsupported block params pattern, probably masgn?"
3210
- )
3417
+ # Block is given but method doesn't accept
3418
+ #
3419
+ constr.type_block_without_hint(node: node, block_annotations: block_annotations, block_params: block_params_, block_body: block_body) do |error|
3420
+ errors << error
3421
+ end
3211
3422
 
3212
- s = constraints.solution(
3213
- checker,
3214
- variance: variance,
3215
- variables: fresh_vars,
3216
- self_type: self_type,
3217
- instance_type: module_context.instance_type,
3218
- class_type: module_context.module_type
3423
+ errors << Diagnostic::Ruby::UnexpectedBlockGiven.new(
3424
+ node: node,
3425
+ method_type: method_type
3219
3426
  )
3220
- method_type = method_type.subst(s)
3427
+
3428
+ method_type = eliminate_vars(method_type, type_param_names)
3429
+ return_type = method_type.type.return_type
3221
3430
  end
3222
3431
  else
3223
- # Block is given but method doesn't accept
3224
- #
3225
- constr.type_block_without_hint(node: node, block_annotations: block_annotations, block_params: block_params_, block_body: block_body) do |error|
3226
- errors << error
3227
- end
3432
+ # Block syntax is not given
3433
+ arg = args.block_pass_arg
3434
+
3435
+ case
3436
+ when arg.compatible?
3437
+ if arg.node
3438
+ # Block pass (&block) is given
3439
+ node_type, constr = constr.synthesize(arg.node, hint: arg.node_type)
3440
+
3441
+ nil_given =
3442
+ constr.check_relation(sub_type: node_type, super_type: AST::Builtin.nil_type).success? &&
3443
+ !node_type.is_a?(AST::Types::Any)
3444
+
3445
+ if nil_given
3446
+ # nil is given ==> no block arg node is given
3447
+ method_type, solved, _ = apply_solution(errors, node: node, method_type: method_type) {
3448
+ constraints.solution(checker, variables: method_type.free_variables, context: ccontext)
3449
+ }
3450
+ method_type = eliminate_vars(method_type, type_param_names) unless solved
3228
3451
 
3229
- errors << Diagnostic::Ruby::UnexpectedBlockGiven.new(
3230
- node: node,
3231
- method_type: method_type
3232
- )
3233
- end
3234
- else
3235
- arg = args.block_pass_arg
3452
+ # Passing no block
3453
+ errors << Diagnostic::Ruby::RequiredBlockMissing.new(
3454
+ node: node,
3455
+ method_type: method_type
3456
+ )
3457
+ else
3458
+ # non-nil value is given
3459
+ constr.check_relation(sub_type: node_type, super_type: arg.node_type, constraints: constraints).else do |result|
3460
+ errors << Diagnostic::Ruby::BlockTypeMismatch.new(
3461
+ node: arg.node,
3462
+ expected: arg.node_type,
3463
+ actual: node_type,
3464
+ result: result
3465
+ )
3466
+ end
3236
3467
 
3237
- case
3238
- when arg.compatible?
3239
- if arg.node
3240
- subst = constraints.solution(
3241
- checker,
3242
- self_type: self_type,
3243
- instance_type: module_context.instance_type,
3244
- class_type: module_context.module_type,
3245
- variance: variance,
3246
- variables: occurence.params
3247
- )
3468
+ method_type, solved, _ = apply_solution(errors, node: node, method_type: method_type) {
3469
+ constraints.solution(checker, variables: method_type.free_variables, context: ccontext)
3470
+ }
3471
+ method_type = eliminate_vars(method_type, type_param_names) unless solved
3472
+ end
3473
+ else
3474
+ # Block is not given
3475
+ method_type, solved, _ = apply_solution(errors, node: node, method_type: method_type) {
3476
+ constraints.solution(checker, variables: method_type.free_variables, context: ccontext)
3477
+ }
3478
+ method_type = eliminate_vars(method_type, type_param_names) unless solved
3479
+ end
3248
3480
 
3249
- block_type = arg.node_type.subst(subst)
3481
+ return_type = method_type.type.return_type
3250
3482
 
3251
- node_type, constr = constr.synthesize(arg.node, hint: block_type)
3252
- nil_block =
3253
- constr.check_relation(sub_type: node_type, super_type: AST::Builtin.nil_type).success? &&
3254
- !node_type.is_a?(AST::Types::Any)
3483
+ when arg.block_missing?
3484
+ # Block is required but not given
3485
+ method_type, solved, _ = apply_solution(errors, node: node, method_type: method_type) {
3486
+ constraints.solution(checker, variables: method_type.free_variables, context: ccontext)
3487
+ }
3255
3488
 
3256
- unless nil_block
3257
- constr.check_relation(sub_type: node_type, super_type: block_type, constraints: constraints).else do |result|
3258
- errors << Diagnostic::Ruby::BlockTypeMismatch.new(
3259
- node: arg.node,
3260
- expected: block_type,
3261
- actual: node_type,
3262
- result: result
3263
- )
3264
- end
3265
- end
3489
+ method_type = eliminate_vars(method_type, type_param_names) unless solved
3490
+ return_type = method_type.type.return_type
3266
3491
 
3267
- subst = constraints.solution(
3268
- checker,
3269
- self_type: self_type,
3270
- instance_type: module_context.instance_type,
3271
- class_type: module_context.module_type,
3272
- variance: variance,
3273
- variables: method_type.free_variables
3492
+ errors << Diagnostic::Ruby::RequiredBlockMissing.new(
3493
+ node: node,
3494
+ method_type: method_type
3274
3495
  )
3275
3496
 
3276
- method_type = method_type.subst(subst)
3497
+ when arg.unexpected_block?
3498
+ # Unexpected block is given
3499
+ method_type, solved, _ = apply_solution(errors, node: node, method_type: method_type) {
3500
+ constraints.solution(checker, variables: method_type.free_variables, context: ccontext)
3501
+ }
3502
+ method_type = eliminate_vars(method_type, type_param_names) unless solved
3503
+ return_type = method_type.type.return_type
3277
3504
 
3278
- if nil_block && arg.block.required?
3279
- # Passing no block
3280
- errors << Diagnostic::Ruby::RequiredBlockMissing.new(
3505
+ node_type, constr = constr.synthesize(arg.node)
3506
+
3507
+ unless constr.check_relation(sub_type: node_type, super_type: AST::Builtin.nil_type).success?
3508
+ errors << Diagnostic::Ruby::UnexpectedBlockGiven.new(
3281
3509
  node: node,
3282
3510
  method_type: method_type
3283
3511
  )
3284
3512
  end
3285
- else
3286
- subst = constraints.solution(
3287
- checker,
3288
- self_type: self_type,
3289
- instance_type: module_context.instance_type,
3290
- class_type: module_context.module_type,
3291
- variance: variance,
3292
- variables: method_type.free_variables
3293
- )
3294
-
3295
- method_type = method_type.subst(subst)
3296
- end
3297
-
3298
- when arg.block_missing?
3299
- subst = constraints.solution(
3300
- checker,
3301
- self_type: self_type,
3302
- instance_type: module_context.instance_type,
3303
- class_type: module_context.module_type,
3304
- variance: variance,
3305
- variables: method_type.free_variables
3306
- )
3307
-
3308
- method_type = method_type.subst(subst)
3309
-
3310
- errors << Diagnostic::Ruby::RequiredBlockMissing.new(
3311
- node: node,
3312
- method_type: method_type
3313
- )
3314
-
3315
- when arg.unexpected_block?
3316
- subst = constraints.solution(
3317
- checker,
3318
- self_type: self_type,
3319
- instance_type: module_context.instance_type,
3320
- class_type: module_context.module_type,
3321
- variance: variance,
3322
- variables: method_type.free_variables
3323
- )
3324
-
3325
- method_type = method_type.subst(subst)
3326
-
3327
- node_type, constr = constr.synthesize(arg.node)
3328
-
3329
- unless constr.check_relation(sub_type: node_type, super_type: AST::Builtin.nil_type).success?
3330
- errors << Diagnostic::Ruby::UnexpectedBlockGiven.new(
3331
- node: node,
3332
- method_type: method_type
3333
- )
3334
3513
  end
3335
3514
  end
3336
- end
3337
-
3338
- call = if errors.empty?
3339
- TypeInference::MethodCall::Typed.new(
3340
- node: node,
3341
- context: context.method_context,
3342
- receiver_type: receiver_type,
3343
- method_name: method_name,
3344
- actual_method_type: method_type,
3345
- return_type: return_type || method_type.type.return_type,
3346
- method_decls: method_type.method_decls
3347
- )
3348
- else
3349
- TypeInference::MethodCall::Error.new(
3350
- node: node,
3351
- context: context.method_context,
3352
- receiver_type: receiver_type,
3353
- method_name: method_name,
3354
- return_type: return_type || method_type.type.return_type,
3355
- method_decls: method_type.method_decls,
3356
- errors: errors
3357
- )
3358
- end
3359
3515
 
3360
- [
3361
- call,
3362
- constr
3363
- ]
3516
+ call = if errors.empty?
3517
+ TypeInference::MethodCall::Typed.new(
3518
+ node: node,
3519
+ context: context.method_context,
3520
+ receiver_type: receiver_type,
3521
+ method_name: method_name,
3522
+ actual_method_type: method_type,
3523
+ return_type: return_type || method_type.type.return_type,
3524
+ method_decls: method_type.method_decls
3525
+ )
3526
+ else
3527
+ TypeInference::MethodCall::Error.new(
3528
+ node: node,
3529
+ context: context.method_context,
3530
+ receiver_type: receiver_type,
3531
+ method_name: method_name,
3532
+ return_type: return_type || method_type.type.return_type,
3533
+ method_decls: method_type.method_decls,
3534
+ errors: errors
3535
+ )
3536
+ end
3537
+
3538
+ [
3539
+ call,
3540
+ constr
3541
+ ]
3542
+ end
3364
3543
  end
3365
3544
 
3366
3545
  def type_check_argument(node, receiver_type:, type:, constraints:, report_node: node, errors:)
@@ -3383,13 +3562,14 @@ module Steep
3383
3562
  message: "Unsupported block params pattern, probably masgn?"
3384
3563
  )
3385
3564
  )
3386
- block_params = TypeInference::BlockParams.new(leading_params: [], optional_params: [], rest_param: nil, trailing_params: [])
3565
+ block_params = TypeInference::BlockParams.new(leading_params: [], optional_params: [], rest_param: nil, trailing_params: [], block_param: nil)
3387
3566
  end
3388
3567
 
3389
3568
  block_constr = for_block(
3390
3569
  block_params: block_params,
3391
3570
  block_param_hint: nil,
3392
3571
  block_type_hint: AST::Builtin.any_type,
3572
+ block_block_hint: nil,
3393
3573
  block_annotations: block_annotations,
3394
3574
  node_type_hint: AST::Builtin.any_type
3395
3575
  )
@@ -3416,8 +3596,8 @@ module Steep
3416
3596
  end
3417
3597
  end
3418
3598
 
3419
- def for_block(block_params:, block_param_hint:, block_type_hint:, block_annotations:, node_type_hint:)
3420
- block_param_pairs = block_param_hint && block_params.zip(block_param_hint)
3599
+ def for_block(block_params:, block_param_hint:, block_type_hint:, block_block_hint:, block_annotations:, node_type_hint:)
3600
+ block_param_pairs = block_param_hint && block_params.zip(block_param_hint, block_block_hint)
3421
3601
 
3422
3602
  param_types_hash = {}
3423
3603
  if block_param_pairs
@@ -3485,7 +3665,8 @@ module Steep
3485
3665
  self_type: self_type,
3486
3666
  type_env: type_env.dup,
3487
3667
  lvar_env: lvar_env,
3488
- call_context: self.context.call_context
3668
+ call_context: self.context.call_context,
3669
+ variable_context: variable_context
3489
3670
  )
3490
3671
  )
3491
3672
  end