steep 0.47.0 → 0.49.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +23 -0
  3. data/Gemfile +1 -1
  4. data/Gemfile.lock +14 -16
  5. data/lib/steep/ast/types/any.rb +2 -0
  6. data/lib/steep/ast/types/boolean.rb +2 -0
  7. data/lib/steep/ast/types/bot.rb +2 -0
  8. data/lib/steep/ast/types/class.rb +2 -0
  9. data/lib/steep/ast/types/factory.rb +162 -138
  10. data/lib/steep/ast/types/helper.rb +8 -0
  11. data/lib/steep/ast/types/instance.rb +2 -0
  12. data/lib/steep/ast/types/intersection.rb +4 -0
  13. data/lib/steep/ast/types/literal.rb +2 -0
  14. data/lib/steep/ast/types/logic.rb +2 -0
  15. data/lib/steep/ast/types/name.rb +6 -0
  16. data/lib/steep/ast/types/nil.rb +2 -0
  17. data/lib/steep/ast/types/proc.rb +9 -0
  18. data/lib/steep/ast/types/record.rb +4 -0
  19. data/lib/steep/ast/types/self.rb +2 -0
  20. data/lib/steep/ast/types/top.rb +2 -0
  21. data/lib/steep/ast/types/tuple.rb +4 -0
  22. data/lib/steep/ast/types/union.rb +4 -0
  23. data/lib/steep/ast/types/var.rb +16 -3
  24. data/lib/steep/ast/types/void.rb +2 -0
  25. data/lib/steep/diagnostic/ruby.rb +23 -11
  26. data/lib/steep/diagnostic/signature.rb +56 -0
  27. data/lib/steep/interface/function.rb +4 -0
  28. data/lib/steep/interface/method_type.rb +14 -26
  29. data/lib/steep/interface/type_param.rb +103 -0
  30. data/lib/steep/server/base_worker.rb +1 -0
  31. data/lib/steep/server/interaction_worker.rb +1 -1
  32. data/lib/steep/server/type_check_worker.rb +2 -2
  33. data/lib/steep/services/signature_service.rb +2 -2
  34. data/lib/steep/services/type_check_service.rb +2 -1
  35. data/lib/steep/signature/validator.rb +221 -49
  36. data/lib/steep/subtyping/cache.rb +30 -0
  37. data/lib/steep/subtyping/check.rb +600 -705
  38. data/lib/steep/subtyping/constraints.rb +66 -30
  39. data/lib/steep/subtyping/relation.rb +60 -0
  40. data/lib/steep/subtyping/result.rb +190 -16
  41. data/lib/steep/type_construction.rb +543 -395
  42. data/lib/steep/type_inference/block_params.rb +31 -3
  43. data/lib/steep/type_inference/context.rb +37 -3
  44. data/lib/steep/version.rb +1 -1
  45. data/lib/steep.rb +4 -4
  46. data/sample/lib/length.rb +35 -0
  47. data/sample/sig/length.rbs +34 -0
  48. data/smoke/diagnostics-rbs/nonregular-type-alias.rbs +3 -0
  49. data/smoke/diagnostics-rbs/recursive-type-alias.rbs +3 -0
  50. data/smoke/diagnostics-rbs/test_expectations.yml +57 -12
  51. data/smoke/tsort/a.rb +1 -1
  52. data/smoke/tsort/test_expectations.yml +1 -63
  53. data/steep.gemspec +1 -1
  54. metadata +13 -10
  55. data/lib/steep/drivers/trace_printer.rb +0 -29
  56. data/lib/steep/interface/method.rb +0 -78
  57. data/lib/steep/subtyping/trace.rb +0 -71
@@ -63,6 +63,10 @@ module Steep
63
63
  context.type_env
64
64
  end
65
65
 
66
+ def variable_context
67
+ context.variable_context
68
+ end
69
+
66
70
  def initialize(checker:, source:, annotations:, typing:, context:)
67
71
  @checker = checker
68
72
  @source = source
@@ -114,13 +118,15 @@ module Steep
114
118
  def check_relation(sub_type:, super_type:, constraints: Subtyping::Constraints.empty)
115
119
  Steep.logger.debug { "check_relation: self:#{self_type}, instance:#{module_context.instance_type}, class:#{module_context.module_type} |- #{sub_type} <: #{super_type}" }
116
120
  relation = Subtyping::Relation.new(sub_type: sub_type, super_type: super_type)
117
- checker.check(
118
- relation,
119
- self_type: self_type,
120
- instance_type: module_context.instance_type,
121
- class_type: module_context.module_type,
122
- constraints: constraints
123
- )
121
+ checker.push_variable_bounds(variable_context.upper_bounds) do
122
+ checker.check(
123
+ relation,
124
+ self_type: self_type,
125
+ instance_type: module_context.instance_type,
126
+ class_type: module_context.module_type,
127
+ constraints: constraints
128
+ )
129
+ end
124
130
  end
125
131
 
126
132
  def for_new_method(method_name, node, args:, self_type:, definition:)
@@ -169,6 +175,12 @@ module Steep
169
175
  end
170
176
  end
171
177
 
178
+ if definition && method_type
179
+ variable_context = TypeInference::Context::TypeVariableContext.new(method_type.type_params, parent_context: self.variable_context)
180
+ else
181
+ variable_context = self.variable_context
182
+ end
183
+
172
184
  method_context = TypeInference::Context::MethodContext.new(
173
185
  name: method_name,
174
186
  method: definition && definition.methods[method_name],
@@ -245,12 +257,20 @@ module Steep
245
257
  self_type: annots.self_type || self_type,
246
258
  type_env: type_env,
247
259
  lvar_env: lvar_env,
248
- call_context: call_context
260
+ call_context: call_context,
261
+ variable_context: variable_context
249
262
  ),
250
263
  typing: typing,
251
264
  )
252
265
  end
253
266
 
267
+ def with_method_constr(method_name, node, args:, self_type:, definition:)
268
+ constr = for_new_method(method_name, node, args: args, self_type: self_type, definition: definition)
269
+ constr.checker.push_variable_bounds(constr.variable_context.upper_bounds) do
270
+ yield constr
271
+ end
272
+ end
273
+
254
274
  def implement_module(module_name:, super_name: nil, annotations:)
255
275
  if (annotation = annotations.implement_module_annotation)
256
276
  absolute_name(annotation.name.name).yield_self do |absolute_name|
@@ -332,7 +352,10 @@ module Steep
332
352
  module_context = default_module_context(implement_module_name, const_env: module_const_env, current_namespace: new_namespace)
333
353
 
334
354
  unless implement_module_name
335
- module_context = module_context.update(module_type: AST::Builtin::Module.instance_type)
355
+ module_context = module_context.update(
356
+ module_type: AST::Builtin::Module.instance_type,
357
+ instance_type: AST::Builtin::BasicObject.instance_type
358
+ )
336
359
  end
337
360
 
338
361
  if implement_module_name
@@ -389,6 +412,21 @@ module Steep
389
412
  class_type: module_context.module_type
390
413
  ).annotate(annots)
391
414
 
415
+ if implement_module_name
416
+ definition = checker.factory.definition_builder.build_instance(implement_module_name.name)
417
+ type_params = definition.type_params_decl.map do |param|
418
+ Interface::TypeParam.new(
419
+ name: param.name,
420
+ upper_bound: checker.factory.type_opt(param.upper_bound),
421
+ variance: param.variance,
422
+ unchecked: param.unchecked?
423
+ )
424
+ end
425
+ variable_context = TypeInference::Context::TypeVariableContext.new(type_params)
426
+ else
427
+ variable_context = TypeInference::Context::TypeVariableContext.empty
428
+ end
429
+
392
430
  self.class.new(
393
431
  checker: checker,
394
432
  source: source,
@@ -402,11 +440,19 @@ module Steep
402
440
  self_type: module_context.module_type,
403
441
  type_env: module_type_env,
404
442
  lvar_env: lvar_env,
405
- call_context: TypeInference::MethodCall::ModuleContext.new(type_name: module_context.class_name)
443
+ call_context: TypeInference::MethodCall::ModuleContext.new(type_name: module_context.class_name),
444
+ variable_context: variable_context
406
445
  )
407
446
  )
408
447
  end
409
448
 
449
+ def with_module_constr(node)
450
+ constr = for_module(node)
451
+ constr.checker.push_variable_bounds(constr.variable_context.upper_bounds) do
452
+ yield constr
453
+ end
454
+ end
455
+
410
456
  def for_class(node)
411
457
  new_class_name = module_name_from_node(node.children.first) or raise "Unexpected class name: #{node.children.first}"
412
458
  super_class_name = node.children[1] && module_name_from_node(node.children[1])
@@ -422,10 +468,7 @@ module Steep
422
468
 
423
469
  if implement_module_name
424
470
  if super_class_name && implement_module_name.name == absolute_name(super_class_name)
425
- module_context = module_context.update(
426
- instance_definition: nil,
427
- module_definition: nil
428
- )
471
+ module_context = module_context.update(instance_definition: nil, module_definition: nil)
429
472
  end
430
473
  else
431
474
  module_context = module_context.update(
@@ -446,6 +489,18 @@ module Steep
446
489
  module_context = module_context.update(module_type: annots.self_type)
447
490
  end
448
491
 
492
+ definition = checker.factory.definition_builder.build_instance(module_context.class_name)
493
+ type_params = definition.type_params_decl.map do |type_param|
494
+ Interface::TypeParam.new(
495
+ name: type_param.name,
496
+ upper_bound: type_param.upper_bound&.yield_self {|t| checker.factory.type(t) },
497
+ variance: type_param.variance,
498
+ unchecked: type_param.unchecked?,
499
+ location: type_param.location
500
+ )
501
+ end
502
+ variable_context = TypeInference::Context::TypeVariableContext.new(type_params)
503
+
449
504
  class_type_env = TypeInference::TypeEnv.build(annotations: annots,
450
505
  subtyping: checker,
451
506
  const_env: class_const_env,
@@ -466,7 +521,8 @@ module Steep
466
521
  self_type: module_context.module_type,
467
522
  type_env: class_type_env,
468
523
  lvar_env: lvar_env,
469
- call_context: TypeInference::MethodCall::ModuleContext.new(type_name: module_context.class_name)
524
+ call_context: TypeInference::MethodCall::ModuleContext.new(type_name: module_context.class_name),
525
+ variable_context: variable_context
470
526
  )
471
527
 
472
528
  self.class.new(
@@ -478,6 +534,24 @@ module Steep
478
534
  )
479
535
  end
480
536
 
537
+ def with_class_constr(node)
538
+ constr = for_class(node)
539
+
540
+ constr.checker.push_variable_bounds(constr.variable_context.upper_bounds) do
541
+ yield constr
542
+ end
543
+ end
544
+
545
+ def with_sclass_constr(node, type)
546
+ if constr = for_sclass(node, type)
547
+ constr.checker.push_variable_bounds(constr.variable_context.upper_bounds) do
548
+ yield constr
549
+ end
550
+ else
551
+ yield nil
552
+ end
553
+ end
554
+
481
555
  def for_sclass(node, type)
482
556
  annots = source.annotations(block: node, factory: checker.factory, current_module: current_namespace)
483
557
 
@@ -554,7 +628,8 @@ module Steep
554
628
  self_type: module_context.module_type,
555
629
  type_env: type_env,
556
630
  lvar_env: lvar_env,
557
- call_context: TypeInference::MethodCall::ModuleContext.new(type_name: module_context.class_name)
631
+ call_context: TypeInference::MethodCall::ModuleContext.new(type_name: module_context.class_name),
632
+ variable_context: TypeInference::Context::TypeVariableContext.empty # Assuming `::Class` and `::Module` don't have type params
558
633
  )
559
634
 
560
635
  self.class.new(
@@ -930,7 +1005,7 @@ module Steep
930
1005
  yield_self do
931
1006
  send_node, params, body = node.children
932
1007
  if send_node.type == :lambda
933
- type_lambda(node, block_params: params, block_body: body, type_hint: hint)
1008
+ type_lambda(node, params_node: params, body_node: body, type_hint: hint)
934
1009
  else
935
1010
  type_send(node, send_node: send_node, block_params: params, block_body: body, unwrap: send_node.type == :csend)
936
1011
  end
@@ -949,7 +1024,7 @@ module Steep
949
1024
  params = Parser::AST::Node.new(:args, arg_nodes)
950
1025
 
951
1026
  if send_node.type == :lambda
952
- type_lambda(node, block_params: params, block_body: body, type_hint: hint)
1027
+ type_lambda(node, params_node: params, body_node: body, type_hint: hint)
953
1028
  else
954
1029
  type_send(node, send_node: send_node, block_params: params, block_body: body, unwrap: send_node.type == :csend)
955
1030
  end
@@ -959,75 +1034,76 @@ module Steep
959
1034
  yield_self do
960
1035
  name, args_node, body_node = node.children
961
1036
 
962
- new = for_new_method(
1037
+ with_method_constr(
963
1038
  name,
964
1039
  node,
965
1040
  args: args_node.children,
966
1041
  self_type: module_context&.instance_type,
967
1042
  definition: module_context&.instance_definition
968
- )
969
- new.typing.add_context_for_node(node, context: new.context)
970
- new.typing.add_context_for_body(node, context: new.context)
971
-
972
- new.method_context.tap do |method_context|
973
- if method_context.method
974
- method_name = InstanceMethodName.new(type_name: method_context.method.implemented_in, method_name: name)
975
- new.typing.source_index.add_definition(method: method_name, definition: node)
1043
+ ) do |new|
1044
+ new.typing.add_context_for_node(node, context: new.context)
1045
+ new.typing.add_context_for_body(node, context: new.context)
1046
+
1047
+ new.method_context.tap do |method_context|
1048
+ if method_context.method
1049
+ method_name = InstanceMethodName.new(type_name: method_context.method.implemented_in, method_name: name)
1050
+ new.typing.source_index.add_definition(method: method_name, definition: node)
1051
+ end
976
1052
  end
977
- end
978
-
979
- new = new.synthesize_children(args_node)
980
1053
 
981
- body_pair = if body_node
982
- return_type = expand_alias(new.method_context&.return_type)
983
- if return_type && !return_type.is_a?(AST::Types::Void)
984
- new.check(body_node, return_type) do |_, actual_type, result|
985
- typing.add_error(
986
- Diagnostic::Ruby::MethodBodyTypeMismatch.new(
987
- node: node,
988
- expected: new.method_context&.return_type,
989
- actual: actual_type,
990
- result: result
1054
+ new = new.synthesize_children(args_node)
1055
+
1056
+ body_pair = if body_node
1057
+ return_type = expand_alias(new.method_context&.return_type)
1058
+ if return_type && !return_type.is_a?(AST::Types::Void)
1059
+ new.check(body_node, return_type) do |_, actual_type, result|
1060
+ typing.add_error(
1061
+ Diagnostic::Ruby::MethodBodyTypeMismatch.new(
1062
+ node: node,
1063
+ expected: new.method_context&.return_type,
1064
+ actual: actual_type,
1065
+ result: result
1066
+ )
991
1067
  )
992
- )
1068
+ end
1069
+ else
1070
+ new.synthesize(body_node)
993
1071
  end
994
1072
  else
995
- new.synthesize(body_node)
996
- end
997
- else
998
- return_type = expand_alias(new.method_context&.return_type)
999
- if return_type && !return_type.is_a?(AST::Types::Void)
1000
- result = check_relation(sub_type: AST::Builtin.nil_type, super_type: return_type)
1001
- if result.failure?
1002
- typing.add_error(
1003
- Diagnostic::Ruby::MethodBodyTypeMismatch.new(
1004
- node: node,
1005
- expected: new.method_context&.return_type,
1006
- actual: AST::Builtin.nil_type,
1007
- result: result
1073
+ return_type = expand_alias(new.method_context&.return_type)
1074
+ if return_type && !return_type.is_a?(AST::Types::Void)
1075
+ result = check_relation(sub_type: AST::Builtin.nil_type, super_type: return_type)
1076
+ if result.failure?
1077
+ typing.add_error(
1078
+ Diagnostic::Ruby::MethodBodyTypeMismatch.new(
1079
+ node: node,
1080
+ expected: new.method_context&.return_type,
1081
+ actual: AST::Builtin.nil_type,
1082
+ result: result
1083
+ )
1008
1084
  )
1009
- )
1085
+ end
1010
1086
  end
1087
+
1088
+ Pair.new(type: AST::Builtin.nil_type, constr: new)
1011
1089
  end
1012
1090
 
1013
- Pair.new(type: AST::Builtin.nil_type, constr: new)
1014
- end
1091
+ if body_node
1092
+ # Add context to ranges from the end of the method body to the beginning of the `end` keyword
1093
+ if node.loc.end
1094
+ # Skip end-less def
1095
+ begin_pos = body_node.loc.expression.end_pos
1096
+ end_pos = node.loc.end.begin_pos
1097
+ typing.add_context(begin_pos..end_pos, context: body_pair.context)
1098
+ end
1099
+ end
1015
1100
 
1016
- if body_node
1017
- # Add context to ranges from the end of the method body to the beginning of the `end` keyword
1018
- if node.loc.end
1019
- # Skip end-less def
1020
- begin_pos = body_node.loc.expression.end_pos
1021
- end_pos = node.loc.end.begin_pos
1022
- typing.add_context(begin_pos..end_pos, context: body_pair.context)
1101
+ if module_context
1102
+ module_context.defined_instance_methods << node.children[0]
1023
1103
  end
1024
- end
1025
1104
 
1026
- if module_context
1027
- module_context.defined_instance_methods << node.children[0]
1105
+ add_typing(node, type: AST::Builtin::Symbol.instance_type)
1028
1106
  end
1029
-
1030
- add_typing(node, type: AST::Builtin::Symbol.instance_type)
1031
1107
  end
1032
1108
 
1033
1109
  when :defs
@@ -1105,16 +1181,18 @@ module Steep
1105
1181
 
1106
1182
  when :return
1107
1183
  yield_self do
1108
- if node.children.size > 0
1109
- method_return_type = expand_alias(method_context&.return_type)
1184
+ method_return_type = expand_alias(method_context&.return_type)
1110
1185
 
1186
+ if node.children.size > 0
1111
1187
  return_types = node.children.map do |value|
1112
- synthesize(value,
1113
- hint: if method_return_type.is_a?(AST::Types::Void)
1114
- nil
1115
- else
1116
- method_return_type
1117
- end).type
1188
+ synthesize(
1189
+ value,
1190
+ hint: if method_return_type.is_a?(AST::Types::Void)
1191
+ nil
1192
+ else
1193
+ method_return_type
1194
+ end
1195
+ ).type
1118
1196
  end
1119
1197
 
1120
1198
  value_type = if return_types.size == 1
@@ -1123,20 +1201,23 @@ module Steep
1123
1201
  AST::Builtin::Array.instance_type(union_type(*return_types))
1124
1202
  end
1125
1203
 
1126
- if method_return_type
1127
- unless method_return_type.is_a?(AST::Types::Void)
1128
- result = check_relation(sub_type: value_type, super_type: method_return_type)
1204
+ else
1205
+ value_type = AST::Builtin.nil_type
1206
+ end
1129
1207
 
1130
- if result.failure?
1131
- typing.add_error(
1132
- Diagnostic::Ruby::ReturnTypeMismatch.new(
1133
- node: node,
1134
- expected: method_context&.return_type,
1135
- actual: value_type,
1136
- result: result
1137
- )
1208
+ if method_return_type
1209
+ unless method_return_type.is_a?(AST::Types::Void)
1210
+ result = check_relation(sub_type: value_type, super_type: method_return_type)
1211
+
1212
+ if result.failure?
1213
+ typing.add_error(
1214
+ Diagnostic::Ruby::ReturnTypeMismatch.new(
1215
+ node: node,
1216
+ expected: method_context&.return_type,
1217
+ actual: value_type,
1218
+ result: result
1138
1219
  )
1139
- end
1220
+ )
1140
1221
  end
1141
1222
  end
1142
1223
  end
@@ -1427,7 +1508,7 @@ module Steep
1427
1508
  end
1428
1509
  _, constr = constr.synthesize(sup) if sup
1429
1510
 
1430
- constr.for_class(node).tap do |constructor|
1511
+ with_class_constr(node) do |constructor|
1431
1512
  if module_type = constructor.module_context&.module_type
1432
1513
  _, constructor = constructor.add_typing(name, type: module_type)
1433
1514
  else
@@ -1467,7 +1548,7 @@ module Steep
1467
1548
  _, constr = constr.synthesize(name)
1468
1549
  end
1469
1550
 
1470
- for_module(node).yield_self do |constructor|
1551
+ with_module_constr(node) do |constructor|
1471
1552
  if module_type = constructor.module_context&.module_type
1472
1553
  _, constructor = constructor.add_typing(name, type: module_type)
1473
1554
  else
@@ -1495,27 +1576,28 @@ module Steep
1495
1576
  when :sclass
1496
1577
  yield_self do
1497
1578
  type, constr = synthesize(node.children[0])
1498
- constructor = constr.for_sclass(node, type)
1499
1579
 
1500
- unless constructor
1501
- typing.add_error(
1502
- Diagnostic::Ruby::UnsupportedSyntax.new(
1503
- node: node,
1504
- message: "sclass receiver must be instance type or singleton type, but type given `#{type}`"
1580
+ with_sclass_constr(node, type) do |constructor|
1581
+ unless constructor
1582
+ typing.add_error(
1583
+ Diagnostic::Ruby::UnsupportedSyntax.new(
1584
+ node: node,
1585
+ message: "sclass receiver must be instance type or singleton type, but type given `#{type}`"
1586
+ )
1505
1587
  )
1506
- )
1507
- constr.add_typing(node, type: AST::Builtin.nil_type)
1508
- return
1509
- end
1588
+ constr.add_typing(node, type: AST::Builtin.nil_type)
1589
+ return
1590
+ end
1510
1591
 
1511
- constructor.typing.add_context_for_node(node, context: constructor.context)
1512
- constructor.typing.add_context_for_body(node, context: constructor.context)
1592
+ constructor.typing.add_context_for_node(node, context: constructor.context)
1593
+ constructor.typing.add_context_for_body(node, context: constructor.context)
1513
1594
 
1514
- constructor.synthesize(node.children[1]) if node.children[1]
1595
+ constructor.synthesize(node.children[1]) if node.children[1]
1515
1596
 
1516
- if constructor.module_context.instance_definition && module_context.module_definition
1517
- if constructor.module_context.instance_definition.type_name == module_context.module_definition.type_name
1518
- module_context.defined_module_methods.merge(constructor.module_context.defined_instance_methods)
1597
+ if constructor.module_context.instance_definition && module_context.module_definition
1598
+ if constructor.module_context.instance_definition.type_name == module_context.module_definition.type_name
1599
+ module_context.defined_module_methods.merge(constructor.module_context.defined_instance_methods)
1600
+ end
1519
1601
  end
1520
1602
  end
1521
1603
 
@@ -1573,8 +1655,7 @@ module Steep
1573
1655
  instance_type: module_context.instance_type,
1574
1656
  class_type: module_context.module_type
1575
1657
  ) do |error|
1576
- case error
1577
- when Subtyping::Result::Failure
1658
+ if error
1578
1659
  const_type = type_env.get(const: const_name)
1579
1660
  typing.add_error(
1580
1661
  Diagnostic::Ruby::IncompatibleAssignment.new(
@@ -1584,7 +1665,7 @@ module Steep
1584
1665
  result: error
1585
1666
  )
1586
1667
  )
1587
- when nil
1668
+ else
1588
1669
  typing.add_error(
1589
1670
  Diagnostic::Ruby::UnknownConstantAssigned.new(
1590
1671
  node: node,
@@ -2385,7 +2466,7 @@ module Steep
2385
2466
  raise "#synthesize should return an instance of Pair: #{pair.class}, node=#{node.inspect}"
2386
2467
  end
2387
2468
  end
2388
- rescue RBS::ErrorBase => exn
2469
+ rescue RBS::BaseError => exn
2389
2470
  Steep.logger.warn { "Unexpected RBS error: #{exn.message}" }
2390
2471
  exn.backtrace.each {|loc| Steep.logger.warn " #{loc}" }
2391
2472
  typing.add_error(Diagnostic::Ruby::UnexpectedError.new(node: node, error: exn))
@@ -2411,6 +2492,7 @@ module Steep
2411
2492
 
2412
2493
  def type_ivasgn(name, rhs, node)
2413
2494
  rhs_type = synthesize(rhs, hint: type_env.get(ivar: name) { fallback_to_any(node) }).type
2495
+
2414
2496
  ivar_type = type_env.assign(
2415
2497
  ivar: name,
2416
2498
  type: rhs_type,
@@ -2418,8 +2500,7 @@ module Steep
2418
2500
  instance_type: module_context.instance_type,
2419
2501
  class_type: module_context.module_type
2420
2502
  ) do |error|
2421
- case error
2422
- when Subtyping::Result::Failure
2503
+ if error
2423
2504
  type = type_env.get(ivar: name)
2424
2505
  typing.add_error(
2425
2506
  Diagnostic::Ruby::IncompatibleAssignment.new(
@@ -2429,7 +2510,7 @@ module Steep
2429
2510
  result: error
2430
2511
  )
2431
2512
  )
2432
- when nil
2513
+ else
2433
2514
  fallback_to_any node
2434
2515
  end
2435
2516
  end
@@ -2473,8 +2554,7 @@ module Steep
2473
2554
  instance_type: module_context.instance_type,
2474
2555
  class_type: module_context.module_type
2475
2556
  ) do |error|
2476
- case error
2477
- when Subtyping::Result::Failure
2557
+ if error
2478
2558
  var_type = type_env.get(ivar: ivar)
2479
2559
  typing.add_error(
2480
2560
  Diagnostic::Ruby::IncompatibleAssignment.new(
@@ -2484,7 +2564,7 @@ module Steep
2484
2564
  result: error
2485
2565
  )
2486
2566
  )
2487
- when nil
2567
+ else
2488
2568
  fallback_to_any node
2489
2569
  end
2490
2570
  end
@@ -2647,20 +2727,22 @@ module Steep
2647
2727
  end
2648
2728
  end
2649
2729
 
2650
- def type_lambda(node, block_params:, block_body:, type_hint:)
2730
+ def type_lambda(node, params_node:, body_node:, type_hint:)
2651
2731
  block_annotations = source.annotations(block: node, factory: checker.factory, current_module: current_namespace)
2652
- params = TypeInference::BlockParams.from_node(block_params, annotations: block_annotations)
2732
+ params = TypeInference::BlockParams.from_node(params_node, annotations: block_annotations)
2653
2733
 
2654
2734
  case type_hint
2655
2735
  when AST::Types::Proc
2656
2736
  params_hint = type_hint.type.params
2657
2737
  return_hint = type_hint.type.return_type
2738
+ block_hint = type_hint.block
2658
2739
  end
2659
2740
 
2660
2741
  block_constr = for_block(
2661
2742
  block_params: params,
2662
2743
  block_param_hint: params_hint,
2663
2744
  block_type_hint: return_hint,
2745
+ block_block_hint: block_hint,
2664
2746
  block_annotations: block_annotations,
2665
2747
  node_type_hint: nil
2666
2748
  )
@@ -2671,10 +2753,30 @@ module Steep
2671
2753
  _, block_constr = block_constr.synthesize(param.node, hint: param.type)
2672
2754
  end
2673
2755
 
2674
- if block_body
2756
+ block =
2757
+ if block_param = params.block_param
2758
+ if block_param_type = block_param.type
2759
+ case block_param_type
2760
+ when AST::Types::Proc
2761
+ Interface::Block.new(type: block_param_type.type, optional: false)
2762
+ when AST::Types::Union
2763
+ if block_param_type.types.size == 2
2764
+ if block_param_type.types.find {|t| t.is_a?(AST::Types::Nil) }
2765
+ if proc_type = block_param_type.types.find {|t| t.is_a?(AST::Types::Proc) }
2766
+ Interface::Block.new(type: proc_type.type, optional: true)
2767
+ end
2768
+ end
2769
+ end
2770
+ end
2771
+ else
2772
+ type_hint.block
2773
+ end
2774
+ end
2775
+
2776
+ if body_node
2675
2777
  return_type = block_constr.synthesize_block(
2676
2778
  node: node,
2677
- block_body: block_body,
2779
+ block_body: body_node,
2678
2780
  block_type_hint: return_hint
2679
2781
  )
2680
2782
 
@@ -2702,7 +2804,7 @@ module Steep
2702
2804
  return_type: return_type,
2703
2805
  location: nil
2704
2806
  ),
2705
- block: nil
2807
+ block: block
2706
2808
  )
2707
2809
 
2708
2810
  add_typing node, type: block_type
@@ -2830,6 +2932,32 @@ module Steep
2830
2932
  )
2831
2933
  )
2832
2934
 
2935
+ when AST::Types::Var
2936
+ if upper_bound = variable_context[receiver_type.name]
2937
+ interface = calculate_interface(upper_bound, private: false)
2938
+
2939
+ constr.type_send_interface(
2940
+ node,
2941
+ interface: interface,
2942
+ receiver: receiver,
2943
+ receiver_type: receiver_type,
2944
+ method_name: method_name,
2945
+ arguments: arguments,
2946
+ block_params: block_params,
2947
+ block_body: block_body
2948
+ )
2949
+ else
2950
+ constr = constr.synthesize_children(node, skips: [receiver])
2951
+ constr.add_call(
2952
+ TypeInference::MethodCall::NoMethodError.new(
2953
+ node: node,
2954
+ context: context.method_context,
2955
+ method_name: method_name,
2956
+ receiver_type: receiver_type,
2957
+ error: Diagnostic::Ruby::NoMethod.new(node: node, method: method_name, type: receiver_type)
2958
+ )
2959
+ )
2960
+ end
2833
2961
  when AST::Types::Void, AST::Types::Bot, AST::Types::Top, AST::Types::Var
2834
2962
  constr = constr.synthesize_children(node, skips: [receiver])
2835
2963
  constr.add_call(
@@ -2872,6 +3000,7 @@ module Steep
2872
3000
  block_params: block_params,
2873
3001
  block_body: block_body)
2874
3002
  end
3003
+
2875
3004
  else
2876
3005
  interface = calculate_interface(receiver_type, private: private, self_type: receiver_type)
2877
3006
 
@@ -2991,119 +3120,166 @@ module Steep
2991
3120
  end
2992
3121
  end
2993
3122
 
3123
+ def apply_solution(errors, node:, method_type:)
3124
+ subst = yield
3125
+
3126
+ [
3127
+ method_type.subst(subst),
3128
+ true,
3129
+ subst
3130
+ ]
3131
+
3132
+ rescue Subtyping::Constraints::UnsatisfiableConstraint => exn
3133
+ errors << Diagnostic::Ruby::UnsatisfiableConstraint.new(
3134
+ node: node,
3135
+ method_type: method_type,
3136
+ var: exn.var,
3137
+ sub_type: exn.sub_type,
3138
+ super_type: exn.super_type,
3139
+ result: exn.result
3140
+ )
3141
+ [method_type, false, Interface::Substitution.empty]
3142
+ end
3143
+
3144
+ def eliminate_vars(type, variables, to: AST::Builtin.any_type)
3145
+ if variables.empty?
3146
+ type
3147
+ else
3148
+ subst = Interface::Substitution.build(variables, Array.new(variables.size, to))
3149
+ type.subst(subst)
3150
+ end
3151
+ end
3152
+
2994
3153
  def try_method_type(node, receiver_type:, method_name:, method_type:, arguments:, block_params:, block_body:, topdown_hint:)
2995
- fresh_types = method_type.type_params.map {|x| AST::Types::Var.fresh(x)}
2996
- fresh_vars = Set.new(fresh_types.map(&:name))
2997
- instantiation = Interface::Substitution.build(method_type.type_params, fresh_types)
3154
+ type_params, instantiation = Interface::TypeParam.rename(method_type.type_params)
3155
+ type_param_names = type_params.map(&:name)
2998
3156
 
2999
3157
  constr = self
3000
3158
 
3001
3159
  method_type = method_type.instantiate(instantiation)
3002
3160
 
3003
- constraints = Subtyping::Constraints.new(unknowns: fresh_types.map(&:name))
3004
3161
  variance = Subtyping::VariableVariance.from_method_type(method_type)
3005
3162
  occurence = Subtyping::VariableOccurence.from_method_type(method_type)
3163
+ constraints = Subtyping::Constraints.new(unknowns: type_params.map(&:name))
3164
+ ccontext = Subtyping::Constraints::Context.new(
3165
+ self_type: self_type,
3166
+ instance_type: module_context.instance_type,
3167
+ class_type: module_context.module_type,
3168
+ variance: variance
3169
+ )
3006
3170
 
3007
- errors = []
3171
+ upper_bounds = {}
3008
3172
 
3009
- args = TypeInference::SendArgs.new(node: node, arguments: arguments, method_name: method_name, method_type: method_type)
3010
- es = args.each do |arg|
3011
- case arg
3012
- when TypeInference::SendArgs::PositionalArgs::NodeParamPair
3013
- _, constr = constr.type_check_argument(
3014
- arg.node,
3015
- type: arg.param.type,
3016
- receiver_type: receiver_type,
3017
- constraints: constraints,
3018
- errors: errors
3019
- )
3173
+ type_params.each do |param|
3174
+ if ub = param.upper_bound
3175
+ constraints.add(param.name, super_type: ub, skip: true)
3176
+ upper_bounds[param.name] = ub
3177
+ end
3178
+ end
3020
3179
 
3021
- when TypeInference::SendArgs::PositionalArgs::NodeTypePair
3022
- _, constr = bypass_splat(arg.node) do |n|
3023
- constr.type_check_argument(
3024
- n,
3025
- type: arg.node_type,
3180
+ checker.push_variable_bounds(upper_bounds) do
3181
+ errors = []
3182
+
3183
+ args = TypeInference::SendArgs.new(node: node, arguments: arguments, method_name: method_name, method_type: method_type)
3184
+ es = args.each do |arg|
3185
+ case arg
3186
+ when TypeInference::SendArgs::PositionalArgs::NodeParamPair
3187
+ _, constr = constr.type_check_argument(
3188
+ arg.node,
3189
+ type: arg.param.type,
3026
3190
  receiver_type: receiver_type,
3027
3191
  constraints: constraints,
3028
- report_node: arg.node,
3029
3192
  errors: errors
3030
3193
  )
3031
- end
3032
-
3033
- when TypeInference::SendArgs::PositionalArgs::UnexpectedArg
3034
- _, constr = bypass_splat(arg.node) do |n|
3035
- constr.synthesize(n)
3036
- end
3037
-
3038
- when TypeInference::SendArgs::PositionalArgs::SplatArg
3039
- arg_type, _ = constr
3040
- .with_child_typing(range: arg.node.loc.expression.begin_pos ... arg.node.loc.expression.end_pos)
3041
- .try_tuple_type!(arg.node.children[0])
3042
- arg.type = arg_type
3043
3194
 
3044
- when TypeInference::SendArgs::PositionalArgs::MissingArg
3045
- # ignore
3046
-
3047
- when TypeInference::SendArgs::KeywordArgs::ArgTypePairs
3048
- arg.pairs.each do |node, type|
3049
- _, constr = bypass_splat(node) do |node|
3195
+ when TypeInference::SendArgs::PositionalArgs::NodeTypePair
3196
+ _, constr = bypass_splat(arg.node) do |n|
3050
3197
  constr.type_check_argument(
3051
- node,
3052
- type: type,
3198
+ n,
3199
+ type: arg.node_type,
3053
3200
  receiver_type: receiver_type,
3054
3201
  constraints: constraints,
3202
+ report_node: arg.node,
3055
3203
  errors: errors
3056
3204
  )
3057
3205
  end
3058
- end
3059
3206
 
3060
- when TypeInference::SendArgs::KeywordArgs::UnexpectedKeyword
3061
- if arg.node.type == :pair
3062
- arg.node.children.each do |nn|
3063
- _, constr = constr.synthesize(nn)
3064
- end
3065
- else
3207
+ when TypeInference::SendArgs::PositionalArgs::UnexpectedArg
3066
3208
  _, constr = bypass_splat(arg.node) do |n|
3067
3209
  constr.synthesize(n)
3068
3210
  end
3069
- end
3070
3211
 
3071
- when TypeInference::SendArgs::KeywordArgs::SplatArg
3072
- type, _ = bypass_splat(arg.node) do |sp_node|
3073
- if sp_node.type == :hash
3074
- pair = constr.type_hash_record(sp_node, nil) and break pair
3212
+ when TypeInference::SendArgs::PositionalArgs::SplatArg
3213
+ arg_type, _ = constr
3214
+ .with_child_typing(range: arg.node.loc.expression.begin_pos ... arg.node.loc.expression.end_pos)
3215
+ .try_tuple_type!(arg.node.children[0])
3216
+ arg.type = arg_type
3217
+
3218
+ when TypeInference::SendArgs::PositionalArgs::MissingArg
3219
+ # ignore
3220
+
3221
+ when TypeInference::SendArgs::KeywordArgs::ArgTypePairs
3222
+ arg.pairs.each do |node, type|
3223
+ _, constr = bypass_splat(node) do |node|
3224
+ constr.type_check_argument(
3225
+ node,
3226
+ type: type,
3227
+ receiver_type: receiver_type,
3228
+ constraints: constraints,
3229
+ errors: errors
3230
+ )
3231
+ end
3075
3232
  end
3076
3233
 
3077
- constr.synthesize(sp_node)
3078
- end
3234
+ when TypeInference::SendArgs::KeywordArgs::UnexpectedKeyword
3235
+ if arg.node.type == :pair
3236
+ arg.node.children.each do |nn|
3237
+ _, constr = constr.synthesize(nn)
3238
+ end
3239
+ else
3240
+ _, constr = bypass_splat(arg.node) do |n|
3241
+ constr.synthesize(n)
3242
+ end
3243
+ end
3079
3244
 
3080
- arg.type = type
3245
+ when TypeInference::SendArgs::KeywordArgs::SplatArg
3246
+ type, _ = bypass_splat(arg.node) do |sp_node|
3247
+ if sp_node.type == :hash
3248
+ pair = constr.type_hash_record(sp_node, nil) and break pair
3249
+ end
3081
3250
 
3082
- when TypeInference::SendArgs::KeywordArgs::MissingKeyword
3083
- # ignore
3084
- else
3085
- raise arg.inspect
3086
- end
3251
+ constr.synthesize(sp_node)
3252
+ end
3087
3253
 
3088
- constr
3089
- end
3254
+ arg.type = type
3090
3255
 
3091
- errors.push(*es)
3256
+ when TypeInference::SendArgs::KeywordArgs::MissingKeyword
3257
+ # ignore
3258
+ else
3259
+ raise arg.inspect
3260
+ end
3092
3261
 
3093
- if block_params
3094
- # block is given
3095
- block_annotations = source.annotations(block: node, factory: checker.factory, current_module: current_namespace)
3096
- block_params_ = TypeInference::BlockParams.from_node(block_params, annotations: block_annotations)
3262
+ constr
3263
+ end
3097
3264
 
3098
- if method_type.block
3099
- pairs = method_type.block && block_params_&.zip(method_type.block.type.params)
3265
+ errors.push(*es)
3100
3266
 
3101
- if pairs
3102
- begin
3267
+ if block_params
3268
+ # block is given
3269
+ block_annotations = source.annotations(block: node, factory: checker.factory, current_module: current_namespace)
3270
+ block_params_ = TypeInference::BlockParams.from_node(block_params, annotations: block_annotations)
3271
+
3272
+ if method_type.block
3273
+ # Method accepts block
3274
+ pairs = method_type.block && block_params_&.zip(method_type.block.type.params, nil)
3275
+
3276
+ if pairs
3277
+ # Block parameters are compatible with the block type
3103
3278
  block_constr = constr.for_block(
3104
3279
  block_params: block_params_,
3105
3280
  block_param_hint: method_type.block.type.params,
3106
3281
  block_type_hint: method_type.block.type.return_type,
3282
+ block_block_hint: nil,
3107
3283
  block_annotations: block_annotations,
3108
3284
  node_type_hint: method_type.type.return_type
3109
3285
  )
@@ -3131,236 +3307,206 @@ module Steep
3131
3307
  end
3132
3308
  end
3133
3309
 
3134
- s = constraints.solution(
3135
- checker,
3136
- self_type: self_type,
3137
- instance_type: module_context.instance_type,
3138
- class_type: module_context.module_type,
3139
- variance: variance,
3140
- variables: method_type.type.params.free_variables + method_type.block.type.params.free_variables
3141
- )
3142
- method_type = method_type.subst(s)
3143
- block_constr = block_constr.update_lvar_env {|env| env.subst(s) }
3144
- if block_body
3145
- block_body_type = block_constr.synthesize_block(
3146
- node: node,
3147
- block_body: block_body,
3148
- block_type_hint: method_type.block.type.return_type
3310
+ method_type, solved, s = apply_solution(errors, node: node, method_type: method_type) {
3311
+ constraints.solution(
3312
+ checker,
3313
+ variables: method_type.type.params.free_variables + method_type.block.type.params.free_variables,
3314
+ context: ccontext
3149
3315
  )
3150
- else
3151
- block_body_type = AST::Builtin.nil_type
3152
- end
3316
+ }
3317
+
3318
+ if solved
3319
+ # Ready for type check the body of the block
3320
+ block_constr = block_constr.update_lvar_env {|env| env.subst(s) }
3321
+ if block_body
3322
+ block_body_type = block_constr.synthesize_block(
3323
+ node: node,
3324
+ block_body: block_body,
3325
+ block_type_hint: method_type.block.type.return_type
3326
+ )
3327
+ else
3328
+ block_body_type = AST::Builtin.nil_type
3329
+ end
3153
3330
 
3154
- result = check_relation(sub_type: block_body_type,
3155
- super_type: method_type.block.type.return_type,
3156
- constraints: constraints)
3331
+ result = check_relation(sub_type: block_body_type,
3332
+ super_type: method_type.block.type.return_type,
3333
+ constraints: constraints)
3157
3334
 
3158
- case result
3159
- when Subtyping::Result::Success
3160
- s = constraints.solution(
3161
- checker,
3162
- self_type: self_type,
3163
- instance_type: module_context.instance_type,
3164
- class_type: module_context.module_type,
3165
- variance: variance,
3166
- variables: fresh_vars
3167
- )
3168
- method_type = method_type.subst(s)
3335
+ if result.success?
3336
+ # Successfully type checked the body
3337
+ method_type, solved, _ = apply_solution(errors, node: node, method_type: method_type) do
3338
+ constraints.solution(checker, variables: type_param_names, context: ccontext)
3339
+ end
3340
+ method_type = eliminate_vars(method_type, type_param_names) unless solved
3169
3341
 
3170
- return_type = method_type.type.return_type
3171
- if break_type = block_annotations.break_type
3172
- return_type = union_type(break_type, return_type)
3342
+ return_type = method_type.type.return_type
3343
+ if break_type = block_annotations.break_type
3344
+ return_type = union_type(break_type, return_type)
3345
+ end
3346
+ else
3347
+ # The block body has incompatible type
3348
+ errors << Diagnostic::Ruby::BlockBodyTypeMismatch.new(
3349
+ node: node,
3350
+ expected: method_type.block.type.return_type,
3351
+ actual: block_body_type,
3352
+ result: result
3353
+ )
3354
+
3355
+ method_type = eliminate_vars(method_type, type_param_names)
3356
+ return_type = method_type.type.return_type
3173
3357
  end
3174
3358
 
3175
- when Subtyping::Result::Failure
3176
- errors << Diagnostic::Ruby::BlockBodyTypeMismatch.new(
3177
- node: node,
3178
- expected: method_type.block.type.return_type,
3179
- actual: block_body_type,
3180
- result: result
3181
- )
3359
+ block_constr.typing.save!
3360
+ else
3361
+ # Failed to infer the type of block parameters
3362
+ constr.type_block_without_hint(node: node, block_annotations: block_annotations, block_params: block_params_, block_body: block_body) do |error|
3363
+ errors << error
3364
+ end
3182
3365
 
3366
+ method_type = eliminate_vars(method_type, type_param_names)
3183
3367
  return_type = method_type.type.return_type
3184
3368
  end
3185
-
3186
- block_constr.typing.save!
3187
-
3188
- rescue Subtyping::Constraints::UnsatisfiableConstraint => exn
3189
- errors << Diagnostic::Ruby::UnsatisfiableConstraint.new(
3190
- node: node,
3191
- method_type: method_type,
3192
- var: exn.var,
3193
- sub_type: exn.sub_type,
3194
- super_type: exn.super_type,
3195
- result: exn.result
3369
+ else
3370
+ # Block parameters are unsupported syntax
3371
+ errors << Diagnostic::Ruby::UnsupportedSyntax.new(
3372
+ node: block_params,
3373
+ message: "Unsupported block params pattern, probably masgn?"
3196
3374
  )
3197
3375
 
3198
- constr.type_block_without_hint(node: node, block_annotations: block_annotations, block_params: block_params_, block_body: block_body) do |error|
3199
- errors << error
3200
- end
3376
+ method_type, solved, _ = apply_solution(errors, node: node, method_type: method_type) {
3377
+ constraints.solution(checker, variables: type_param_names, context: ccontext)
3378
+ }
3379
+ method_type = eliminate_vars(method_type, type_param_names) unless solved
3201
3380
 
3202
- s = Interface::Substitution.build(method_type.free_variables,
3203
- Array.new(method_type.free_variables.size, AST::Builtin.any_type))
3204
- method_type = method_type.subst(s)
3381
+ return_type = method_type.type.return_type
3205
3382
  end
3206
3383
  else
3207
- errors << Diagnostic::Ruby::UnsupportedSyntax.new(
3208
- node: block_params,
3209
- message: "Unsupported block params pattern, probably masgn?"
3210
- )
3384
+ # Block is given but method doesn't accept
3385
+ #
3386
+ constr.type_block_without_hint(node: node, block_annotations: block_annotations, block_params: block_params_, block_body: block_body) do |error|
3387
+ errors << error
3388
+ end
3211
3389
 
3212
- s = constraints.solution(
3213
- checker,
3214
- variance: variance,
3215
- variables: fresh_vars,
3216
- self_type: self_type,
3217
- instance_type: module_context.instance_type,
3218
- class_type: module_context.module_type
3390
+ errors << Diagnostic::Ruby::UnexpectedBlockGiven.new(
3391
+ node: node,
3392
+ method_type: method_type
3219
3393
  )
3220
- method_type = method_type.subst(s)
3394
+
3395
+ method_type = eliminate_vars(method_type, type_param_names)
3396
+ return_type = method_type.type.return_type
3221
3397
  end
3222
3398
  else
3223
- # Block is given but method doesn't accept
3224
- #
3225
- constr.type_block_without_hint(node: node, block_annotations: block_annotations, block_params: block_params_, block_body: block_body) do |error|
3226
- errors << error
3227
- end
3399
+ # Block syntax is not given
3400
+ arg = args.block_pass_arg
3401
+
3402
+ case
3403
+ when arg.compatible?
3404
+ if arg.node
3405
+ # Block pass (&block) is given
3406
+ node_type, constr = constr.synthesize(arg.node, hint: arg.node_type)
3407
+
3408
+ nil_given =
3409
+ constr.check_relation(sub_type: node_type, super_type: AST::Builtin.nil_type).success? &&
3410
+ !node_type.is_a?(AST::Types::Any)
3411
+
3412
+ if nil_given
3413
+ # nil is given ==> no block arg node is given
3414
+ method_type, solved, _ = apply_solution(errors, node: node, method_type: method_type) {
3415
+ constraints.solution(checker, variables: method_type.free_variables, context: ccontext)
3416
+ }
3417
+ method_type = eliminate_vars(method_type, type_param_names) unless solved
3228
3418
 
3229
- errors << Diagnostic::Ruby::UnexpectedBlockGiven.new(
3230
- node: node,
3231
- method_type: method_type
3232
- )
3233
- end
3234
- else
3235
- arg = args.block_pass_arg
3419
+ # Passing no block
3420
+ errors << Diagnostic::Ruby::RequiredBlockMissing.new(
3421
+ node: node,
3422
+ method_type: method_type
3423
+ )
3424
+ else
3425
+ # non-nil value is given
3426
+ constr.check_relation(sub_type: node_type, super_type: arg.node_type, constraints: constraints).else do |result|
3427
+ errors << Diagnostic::Ruby::BlockTypeMismatch.new(
3428
+ node: arg.node,
3429
+ expected: arg.node_type,
3430
+ actual: node_type,
3431
+ result: result
3432
+ )
3433
+ end
3236
3434
 
3237
- case
3238
- when arg.compatible?
3239
- if arg.node
3240
- subst = constraints.solution(
3241
- checker,
3242
- self_type: self_type,
3243
- instance_type: module_context.instance_type,
3244
- class_type: module_context.module_type,
3245
- variance: variance,
3246
- variables: occurence.params
3247
- )
3435
+ method_type, solved, _ = apply_solution(errors, node: node, method_type: method_type) {
3436
+ constraints.solution(checker, variables: method_type.free_variables, context: ccontext)
3437
+ }
3438
+ method_type = eliminate_vars(method_type, type_param_names) unless solved
3439
+ end
3440
+ else
3441
+ # Block is not given
3442
+ method_type, solved, _ = apply_solution(errors, node: node, method_type: method_type) {
3443
+ constraints.solution(checker, variables: method_type.free_variables, context: ccontext)
3444
+ }
3445
+ method_type = eliminate_vars(method_type, type_param_names) unless solved
3446
+ end
3248
3447
 
3249
- block_type = arg.node_type.subst(subst)
3448
+ return_type = method_type.type.return_type
3250
3449
 
3251
- node_type, constr = constr.synthesize(arg.node, hint: block_type)
3252
- nil_block =
3253
- constr.check_relation(sub_type: node_type, super_type: AST::Builtin.nil_type).success? &&
3254
- !node_type.is_a?(AST::Types::Any)
3450
+ when arg.block_missing?
3451
+ # Block is required but not given
3452
+ method_type, solved, _ = apply_solution(errors, node: node, method_type: method_type) {
3453
+ constraints.solution(checker, variables: method_type.free_variables, context: ccontext)
3454
+ }
3255
3455
 
3256
- unless nil_block
3257
- constr.check_relation(sub_type: node_type, super_type: block_type, constraints: constraints).else do |result|
3258
- errors << Diagnostic::Ruby::BlockTypeMismatch.new(
3259
- node: arg.node,
3260
- expected: block_type,
3261
- actual: node_type,
3262
- result: result
3263
- )
3264
- end
3265
- end
3456
+ method_type = eliminate_vars(method_type, type_param_names) unless solved
3457
+ return_type = method_type.type.return_type
3266
3458
 
3267
- subst = constraints.solution(
3268
- checker,
3269
- self_type: self_type,
3270
- instance_type: module_context.instance_type,
3271
- class_type: module_context.module_type,
3272
- variance: variance,
3273
- variables: method_type.free_variables
3459
+ errors << Diagnostic::Ruby::RequiredBlockMissing.new(
3460
+ node: node,
3461
+ method_type: method_type
3274
3462
  )
3275
3463
 
3276
- method_type = method_type.subst(subst)
3464
+ when arg.unexpected_block?
3465
+ # Unexpected block is given
3466
+ method_type, solved, _ = apply_solution(errors, node: node, method_type: method_type) {
3467
+ constraints.solution(checker, variables: method_type.free_variables, context: ccontext)
3468
+ }
3469
+ method_type = eliminate_vars(method_type, type_param_names) unless solved
3470
+ return_type = method_type.type.return_type
3277
3471
 
3278
- if nil_block && arg.block.required?
3279
- # Passing no block
3280
- errors << Diagnostic::Ruby::RequiredBlockMissing.new(
3472
+ node_type, constr = constr.synthesize(arg.node)
3473
+
3474
+ unless constr.check_relation(sub_type: node_type, super_type: AST::Builtin.nil_type).success?
3475
+ errors << Diagnostic::Ruby::UnexpectedBlockGiven.new(
3281
3476
  node: node,
3282
3477
  method_type: method_type
3283
3478
  )
3284
3479
  end
3285
- else
3286
- subst = constraints.solution(
3287
- checker,
3288
- self_type: self_type,
3289
- instance_type: module_context.instance_type,
3290
- class_type: module_context.module_type,
3291
- variance: variance,
3292
- variables: method_type.free_variables
3293
- )
3294
-
3295
- method_type = method_type.subst(subst)
3296
- end
3297
-
3298
- when arg.block_missing?
3299
- subst = constraints.solution(
3300
- checker,
3301
- self_type: self_type,
3302
- instance_type: module_context.instance_type,
3303
- class_type: module_context.module_type,
3304
- variance: variance,
3305
- variables: method_type.free_variables
3306
- )
3307
-
3308
- method_type = method_type.subst(subst)
3309
-
3310
- errors << Diagnostic::Ruby::RequiredBlockMissing.new(
3311
- node: node,
3312
- method_type: method_type
3313
- )
3314
-
3315
- when arg.unexpected_block?
3316
- subst = constraints.solution(
3317
- checker,
3318
- self_type: self_type,
3319
- instance_type: module_context.instance_type,
3320
- class_type: module_context.module_type,
3321
- variance: variance,
3322
- variables: method_type.free_variables
3323
- )
3324
-
3325
- method_type = method_type.subst(subst)
3326
-
3327
- node_type, constr = constr.synthesize(arg.node)
3328
-
3329
- unless constr.check_relation(sub_type: node_type, super_type: AST::Builtin.nil_type).success?
3330
- errors << Diagnostic::Ruby::UnexpectedBlockGiven.new(
3331
- node: node,
3332
- method_type: method_type
3333
- )
3334
3480
  end
3335
3481
  end
3336
- end
3337
-
3338
- call = if errors.empty?
3339
- TypeInference::MethodCall::Typed.new(
3340
- node: node,
3341
- context: context.method_context,
3342
- receiver_type: receiver_type,
3343
- method_name: method_name,
3344
- actual_method_type: method_type,
3345
- return_type: return_type || method_type.type.return_type,
3346
- method_decls: method_type.method_decls
3347
- )
3348
- else
3349
- TypeInference::MethodCall::Error.new(
3350
- node: node,
3351
- context: context.method_context,
3352
- receiver_type: receiver_type,
3353
- method_name: method_name,
3354
- return_type: return_type || method_type.type.return_type,
3355
- method_decls: method_type.method_decls,
3356
- errors: errors
3357
- )
3358
- end
3359
3482
 
3360
- [
3361
- call,
3362
- constr
3363
- ]
3483
+ call = if errors.empty?
3484
+ TypeInference::MethodCall::Typed.new(
3485
+ node: node,
3486
+ context: context.method_context,
3487
+ receiver_type: receiver_type,
3488
+ method_name: method_name,
3489
+ actual_method_type: method_type,
3490
+ return_type: return_type || method_type.type.return_type,
3491
+ method_decls: method_type.method_decls
3492
+ )
3493
+ else
3494
+ TypeInference::MethodCall::Error.new(
3495
+ node: node,
3496
+ context: context.method_context,
3497
+ receiver_type: receiver_type,
3498
+ method_name: method_name,
3499
+ return_type: return_type || method_type.type.return_type,
3500
+ method_decls: method_type.method_decls,
3501
+ errors: errors
3502
+ )
3503
+ end
3504
+
3505
+ [
3506
+ call,
3507
+ constr
3508
+ ]
3509
+ end
3364
3510
  end
3365
3511
 
3366
3512
  def type_check_argument(node, receiver_type:, type:, constraints:, report_node: node, errors:)
@@ -3383,13 +3529,14 @@ module Steep
3383
3529
  message: "Unsupported block params pattern, probably masgn?"
3384
3530
  )
3385
3531
  )
3386
- block_params = TypeInference::BlockParams.new(leading_params: [], optional_params: [], rest_param: nil, trailing_params: [])
3532
+ block_params = TypeInference::BlockParams.new(leading_params: [], optional_params: [], rest_param: nil, trailing_params: [], block_param: nil)
3387
3533
  end
3388
3534
 
3389
3535
  block_constr = for_block(
3390
3536
  block_params: block_params,
3391
3537
  block_param_hint: nil,
3392
3538
  block_type_hint: AST::Builtin.any_type,
3539
+ block_block_hint: nil,
3393
3540
  block_annotations: block_annotations,
3394
3541
  node_type_hint: AST::Builtin.any_type
3395
3542
  )
@@ -3416,8 +3563,8 @@ module Steep
3416
3563
  end
3417
3564
  end
3418
3565
 
3419
- def for_block(block_params:, block_param_hint:, block_type_hint:, block_annotations:, node_type_hint:)
3420
- block_param_pairs = block_param_hint && block_params.zip(block_param_hint)
3566
+ def for_block(block_params:, block_param_hint:, block_type_hint:, block_block_hint:, block_annotations:, node_type_hint:)
3567
+ block_param_pairs = block_param_hint && block_params.zip(block_param_hint, block_block_hint)
3421
3568
 
3422
3569
  param_types_hash = {}
3423
3570
  if block_param_pairs
@@ -3485,7 +3632,8 @@ module Steep
3485
3632
  self_type: self_type,
3486
3633
  type_env: type_env.dup,
3487
3634
  lvar_env: lvar_env,
3488
- call_context: self.context.call_context
3635
+ call_context: self.context.call_context,
3636
+ variable_context: variable_context
3489
3637
  )
3490
3638
  )
3491
3639
  end