steep 0.47.1 → 0.49.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (59):
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +24 -0
  3. data/Gemfile.lock +7 -7
  4. data/lib/steep/ast/types/any.rb +2 -0
  5. data/lib/steep/ast/types/boolean.rb +2 -0
  6. data/lib/steep/ast/types/bot.rb +2 -0
  7. data/lib/steep/ast/types/class.rb +2 -0
  8. data/lib/steep/ast/types/factory.rb +162 -138
  9. data/lib/steep/ast/types/helper.rb +8 -0
  10. data/lib/steep/ast/types/instance.rb +2 -0
  11. data/lib/steep/ast/types/intersection.rb +4 -0
  12. data/lib/steep/ast/types/literal.rb +2 -0
  13. data/lib/steep/ast/types/logic.rb +2 -0
  14. data/lib/steep/ast/types/name.rb +6 -0
  15. data/lib/steep/ast/types/nil.rb +2 -0
  16. data/lib/steep/ast/types/proc.rb +9 -0
  17. data/lib/steep/ast/types/record.rb +4 -0
  18. data/lib/steep/ast/types/self.rb +2 -0
  19. data/lib/steep/ast/types/top.rb +2 -0
  20. data/lib/steep/ast/types/tuple.rb +4 -0
  21. data/lib/steep/ast/types/union.rb +4 -0
  22. data/lib/steep/ast/types/var.rb +16 -3
  23. data/lib/steep/ast/types/void.rb +2 -0
  24. data/lib/steep/diagnostic/ruby.rb +36 -11
  25. data/lib/steep/diagnostic/signature.rb +56 -0
  26. data/lib/steep/drivers/check.rb +2 -2
  27. data/lib/steep/interface/function.rb +4 -0
  28. data/lib/steep/interface/method_type.rb +14 -26
  29. data/lib/steep/interface/type_param.rb +103 -0
  30. data/lib/steep/server/base_worker.rb +1 -0
  31. data/lib/steep/server/interaction_worker.rb +1 -1
  32. data/lib/steep/server/type_check_worker.rb +2 -2
  33. data/lib/steep/services/signature_service.rb +2 -2
  34. data/lib/steep/services/type_check_service.rb +4 -1
  35. data/lib/steep/signature/validator.rb +228 -49
  36. data/lib/steep/subtyping/cache.rb +30 -0
  37. data/lib/steep/subtyping/check.rb +600 -705
  38. data/lib/steep/subtyping/constraints.rb +66 -30
  39. data/lib/steep/subtyping/relation.rb +60 -0
  40. data/lib/steep/subtyping/result.rb +190 -16
  41. data/lib/steep/type_construction.rb +576 -395
  42. data/lib/steep/type_inference/block_params.rb +31 -3
  43. data/lib/steep/type_inference/context.rb +37 -3
  44. data/lib/steep/version.rb +1 -1
  45. data/lib/steep.rb +3 -3
  46. data/sample/lib/length.rb +35 -0
  47. data/sample/sig/length.rbs +34 -0
  48. data/smoke/diagnostics/proc_type_expected.rb +3 -0
  49. data/smoke/diagnostics/test_expectations.yml +12 -0
  50. data/smoke/diagnostics-rbs/nonregular-type-alias.rbs +3 -0
  51. data/smoke/diagnostics-rbs/recursive-type-alias.rbs +3 -0
  52. data/smoke/diagnostics-rbs/test_expectations.yml +57 -12
  53. data/smoke/tsort/a.rb +1 -1
  54. data/smoke/tsort/test_expectations.yml +1 -63
  55. data/steep.gemspec +1 -1
  56. metadata +14 -10
  57. data/lib/steep/drivers/trace_printer.rb +0 -29
  58. data/lib/steep/interface/method.rb +0 -78
  59. data/lib/steep/subtyping/trace.rb +0 -71
@@ -63,6 +63,10 @@ module Steep
63
63
  context.type_env
64
64
  end
65
65
 
66
+ def variable_context
67
+ context.variable_context
68
+ end
69
+
66
70
  def initialize(checker:, source:, annotations:, typing:, context:)
67
71
  @checker = checker
68
72
  @source = source
@@ -114,13 +118,15 @@ module Steep
114
118
  def check_relation(sub_type:, super_type:, constraints: Subtyping::Constraints.empty)
115
119
  Steep.logger.debug { "check_relation: self:#{self_type}, instance:#{module_context.instance_type}, class:#{module_context.module_type} |- #{sub_type} <: #{super_type}" }
116
120
  relation = Subtyping::Relation.new(sub_type: sub_type, super_type: super_type)
117
- checker.check(
118
- relation,
119
- self_type: self_type,
120
- instance_type: module_context.instance_type,
121
- class_type: module_context.module_type,
122
- constraints: constraints
123
- )
121
+ checker.push_variable_bounds(variable_context.upper_bounds) do
122
+ checker.check(
123
+ relation,
124
+ self_type: self_type,
125
+ instance_type: module_context.instance_type,
126
+ class_type: module_context.module_type,
127
+ constraints: constraints
128
+ )
129
+ end
124
130
  end
125
131
 
126
132
  def for_new_method(method_name, node, args:, self_type:, definition:)
@@ -169,6 +175,12 @@ module Steep
169
175
  end
170
176
  end
171
177
 
178
+ if definition && method_type
179
+ variable_context = TypeInference::Context::TypeVariableContext.new(method_type.type_params, parent_context: self.variable_context)
180
+ else
181
+ variable_context = self.variable_context
182
+ end
183
+
172
184
  method_context = TypeInference::Context::MethodContext.new(
173
185
  name: method_name,
174
186
  method: definition && definition.methods[method_name],
@@ -245,12 +257,20 @@ module Steep
245
257
  self_type: annots.self_type || self_type,
246
258
  type_env: type_env,
247
259
  lvar_env: lvar_env,
248
- call_context: call_context
260
+ call_context: call_context,
261
+ variable_context: variable_context
249
262
  ),
250
263
  typing: typing,
251
264
  )
252
265
  end
253
266
 
267
+ def with_method_constr(method_name, node, args:, self_type:, definition:)
268
+ constr = for_new_method(method_name, node, args: args, self_type: self_type, definition: definition)
269
+ constr.checker.push_variable_bounds(constr.variable_context.upper_bounds) do
270
+ yield constr
271
+ end
272
+ end
273
+
254
274
  def implement_module(module_name:, super_name: nil, annotations:)
255
275
  if (annotation = annotations.implement_module_annotation)
256
276
  absolute_name(annotation.name.name).yield_self do |absolute_name|
@@ -332,7 +352,10 @@ module Steep
332
352
  module_context = default_module_context(implement_module_name, const_env: module_const_env, current_namespace: new_namespace)
333
353
 
334
354
  unless implement_module_name
335
- module_context = module_context.update(module_type: AST::Builtin::Module.instance_type)
355
+ module_context = module_context.update(
356
+ module_type: AST::Builtin::Module.instance_type,
357
+ instance_type: AST::Builtin::BasicObject.instance_type
358
+ )
336
359
  end
337
360
 
338
361
  if implement_module_name
@@ -389,6 +412,21 @@ module Steep
389
412
  class_type: module_context.module_type
390
413
  ).annotate(annots)
391
414
 
415
+ if implement_module_name
416
+ definition = checker.factory.definition_builder.build_instance(implement_module_name.name)
417
+ type_params = definition.type_params_decl.map do |param|
418
+ Interface::TypeParam.new(
419
+ name: param.name,
420
+ upper_bound: checker.factory.type_opt(param.upper_bound),
421
+ variance: param.variance,
422
+ unchecked: param.unchecked?
423
+ )
424
+ end
425
+ variable_context = TypeInference::Context::TypeVariableContext.new(type_params)
426
+ else
427
+ variable_context = TypeInference::Context::TypeVariableContext.empty
428
+ end
429
+
392
430
  self.class.new(
393
431
  checker: checker,
394
432
  source: source,
@@ -402,11 +440,19 @@ module Steep
402
440
  self_type: module_context.module_type,
403
441
  type_env: module_type_env,
404
442
  lvar_env: lvar_env,
405
- call_context: TypeInference::MethodCall::ModuleContext.new(type_name: module_context.class_name)
443
+ call_context: TypeInference::MethodCall::ModuleContext.new(type_name: module_context.class_name),
444
+ variable_context: variable_context
406
445
  )
407
446
  )
408
447
  end
409
448
 
449
+ def with_module_constr(node)
450
+ constr = for_module(node)
451
+ constr.checker.push_variable_bounds(constr.variable_context.upper_bounds) do
452
+ yield constr
453
+ end
454
+ end
455
+
410
456
  def for_class(node)
411
457
  new_class_name = module_name_from_node(node.children.first) or raise "Unexpected class name: #{node.children.first}"
412
458
  super_class_name = node.children[1] && module_name_from_node(node.children[1])
@@ -422,10 +468,7 @@ module Steep
422
468
 
423
469
  if implement_module_name
424
470
  if super_class_name && implement_module_name.name == absolute_name(super_class_name)
425
- module_context = module_context.update(
426
- instance_definition: nil,
427
- module_definition: nil
428
- )
471
+ module_context = module_context.update(instance_definition: nil, module_definition: nil)
429
472
  end
430
473
  else
431
474
  module_context = module_context.update(
@@ -446,6 +489,18 @@ module Steep
446
489
  module_context = module_context.update(module_type: annots.self_type)
447
490
  end
448
491
 
492
+ definition = checker.factory.definition_builder.build_instance(module_context.class_name)
493
+ type_params = definition.type_params_decl.map do |type_param|
494
+ Interface::TypeParam.new(
495
+ name: type_param.name,
496
+ upper_bound: type_param.upper_bound&.yield_self {|t| checker.factory.type(t) },
497
+ variance: type_param.variance,
498
+ unchecked: type_param.unchecked?,
499
+ location: type_param.location
500
+ )
501
+ end
502
+ variable_context = TypeInference::Context::TypeVariableContext.new(type_params)
503
+
449
504
  class_type_env = TypeInference::TypeEnv.build(annotations: annots,
450
505
  subtyping: checker,
451
506
  const_env: class_const_env,
@@ -466,7 +521,8 @@ module Steep
466
521
  self_type: module_context.module_type,
467
522
  type_env: class_type_env,
468
523
  lvar_env: lvar_env,
469
- call_context: TypeInference::MethodCall::ModuleContext.new(type_name: module_context.class_name)
524
+ call_context: TypeInference::MethodCall::ModuleContext.new(type_name: module_context.class_name),
525
+ variable_context: variable_context
470
526
  )
471
527
 
472
528
  self.class.new(
@@ -478,6 +534,24 @@ module Steep
478
534
  )
479
535
  end
480
536
 
537
+ def with_class_constr(node)
538
+ constr = for_class(node)
539
+
540
+ constr.checker.push_variable_bounds(constr.variable_context.upper_bounds) do
541
+ yield constr
542
+ end
543
+ end
544
+
545
+ def with_sclass_constr(node, type)
546
+ if constr = for_sclass(node, type)
547
+ constr.checker.push_variable_bounds(constr.variable_context.upper_bounds) do
548
+ yield constr
549
+ end
550
+ else
551
+ yield nil
552
+ end
553
+ end
554
+
481
555
  def for_sclass(node, type)
482
556
  annots = source.annotations(block: node, factory: checker.factory, current_module: current_namespace)
483
557
 
@@ -554,7 +628,8 @@ module Steep
554
628
  self_type: module_context.module_type,
555
629
  type_env: type_env,
556
630
  lvar_env: lvar_env,
557
- call_context: TypeInference::MethodCall::ModuleContext.new(type_name: module_context.class_name)
631
+ call_context: TypeInference::MethodCall::ModuleContext.new(type_name: module_context.class_name),
632
+ variable_context: TypeInference::Context::TypeVariableContext.empty # Assuming `::Class` and `::Module` don't have type params
558
633
  )
559
634
 
560
635
  self.class.new(
@@ -930,7 +1005,7 @@ module Steep
930
1005
  yield_self do
931
1006
  send_node, params, body = node.children
932
1007
  if send_node.type == :lambda
933
- type_lambda(node, block_params: params, block_body: body, type_hint: hint)
1008
+ type_lambda(node, params_node: params, body_node: body, type_hint: hint)
934
1009
  else
935
1010
  type_send(node, send_node: send_node, block_params: params, block_body: body, unwrap: send_node.type == :csend)
936
1011
  end
@@ -949,7 +1024,7 @@ module Steep
949
1024
  params = Parser::AST::Node.new(:args, arg_nodes)
950
1025
 
951
1026
  if send_node.type == :lambda
952
- type_lambda(node, block_params: params, block_body: body, type_hint: hint)
1027
+ type_lambda(node, params_node: params, body_node: body, type_hint: hint)
953
1028
  else
954
1029
  type_send(node, send_node: send_node, block_params: params, block_body: body, unwrap: send_node.type == :csend)
955
1030
  end
@@ -959,75 +1034,76 @@ module Steep
959
1034
  yield_self do
960
1035
  name, args_node, body_node = node.children
961
1036
 
962
- new = for_new_method(
1037
+ with_method_constr(
963
1038
  name,
964
1039
  node,
965
1040
  args: args_node.children,
966
1041
  self_type: module_context&.instance_type,
967
1042
  definition: module_context&.instance_definition
968
- )
969
- new.typing.add_context_for_node(node, context: new.context)
970
- new.typing.add_context_for_body(node, context: new.context)
971
-
972
- new.method_context.tap do |method_context|
973
- if method_context.method
974
- method_name = InstanceMethodName.new(type_name: method_context.method.implemented_in, method_name: name)
975
- new.typing.source_index.add_definition(method: method_name, definition: node)
1043
+ ) do |new|
1044
+ new.typing.add_context_for_node(node, context: new.context)
1045
+ new.typing.add_context_for_body(node, context: new.context)
1046
+
1047
+ new.method_context.tap do |method_context|
1048
+ if method_context.method
1049
+ method_name = InstanceMethodName.new(type_name: method_context.method.implemented_in, method_name: name)
1050
+ new.typing.source_index.add_definition(method: method_name, definition: node)
1051
+ end
976
1052
  end
977
- end
978
-
979
- new = new.synthesize_children(args_node)
980
1053
 
981
- body_pair = if body_node
982
- return_type = expand_alias(new.method_context&.return_type)
983
- if return_type && !return_type.is_a?(AST::Types::Void)
984
- new.check(body_node, return_type) do |_, actual_type, result|
985
- typing.add_error(
986
- Diagnostic::Ruby::MethodBodyTypeMismatch.new(
987
- node: node,
988
- expected: new.method_context&.return_type,
989
- actual: actual_type,
990
- result: result
1054
+ new = new.synthesize_children(args_node)
1055
+
1056
+ body_pair = if body_node
1057
+ return_type = expand_alias(new.method_context&.return_type)
1058
+ if return_type && !return_type.is_a?(AST::Types::Void)
1059
+ new.check(body_node, return_type) do |_, actual_type, result|
1060
+ typing.add_error(
1061
+ Diagnostic::Ruby::MethodBodyTypeMismatch.new(
1062
+ node: node,
1063
+ expected: new.method_context&.return_type,
1064
+ actual: actual_type,
1065
+ result: result
1066
+ )
991
1067
  )
992
- )
1068
+ end
1069
+ else
1070
+ new.synthesize(body_node)
993
1071
  end
994
1072
  else
995
- new.synthesize(body_node)
996
- end
997
- else
998
- return_type = expand_alias(new.method_context&.return_type)
999
- if return_type && !return_type.is_a?(AST::Types::Void)
1000
- result = check_relation(sub_type: AST::Builtin.nil_type, super_type: return_type)
1001
- if result.failure?
1002
- typing.add_error(
1003
- Diagnostic::Ruby::MethodBodyTypeMismatch.new(
1004
- node: node,
1005
- expected: new.method_context&.return_type,
1006
- actual: AST::Builtin.nil_type,
1007
- result: result
1073
+ return_type = expand_alias(new.method_context&.return_type)
1074
+ if return_type && !return_type.is_a?(AST::Types::Void)
1075
+ result = check_relation(sub_type: AST::Builtin.nil_type, super_type: return_type)
1076
+ if result.failure?
1077
+ typing.add_error(
1078
+ Diagnostic::Ruby::MethodBodyTypeMismatch.new(
1079
+ node: node,
1080
+ expected: new.method_context&.return_type,
1081
+ actual: AST::Builtin.nil_type,
1082
+ result: result
1083
+ )
1008
1084
  )
1009
- )
1085
+ end
1010
1086
  end
1087
+
1088
+ Pair.new(type: AST::Builtin.nil_type, constr: new)
1011
1089
  end
1012
1090
 
1013
- Pair.new(type: AST::Builtin.nil_type, constr: new)
1014
- end
1091
+ if body_node
1092
+ # Add context to ranges from the end of the method body to the beginning of the `end` keyword
1093
+ if node.loc.end
1094
+ # Skip end-less def
1095
+ begin_pos = body_node.loc.expression.end_pos
1096
+ end_pos = node.loc.end.begin_pos
1097
+ typing.add_context(begin_pos..end_pos, context: body_pair.context)
1098
+ end
1099
+ end
1015
1100
 
1016
- if body_node
1017
- # Add context to ranges from the end of the method body to the beginning of the `end` keyword
1018
- if node.loc.end
1019
- # Skip end-less def
1020
- begin_pos = body_node.loc.expression.end_pos
1021
- end_pos = node.loc.end.begin_pos
1022
- typing.add_context(begin_pos..end_pos, context: body_pair.context)
1101
+ if module_context
1102
+ module_context.defined_instance_methods << node.children[0]
1023
1103
  end
1024
- end
1025
1104
 
1026
- if module_context
1027
- module_context.defined_instance_methods << node.children[0]
1105
+ add_typing(node, type: AST::Builtin::Symbol.instance_type)
1028
1106
  end
1029
-
1030
- add_typing(node, type: AST::Builtin::Symbol.instance_type)
1031
1107
  end
1032
1108
 
1033
1109
  when :defs
@@ -1105,16 +1181,18 @@ module Steep
1105
1181
 
1106
1182
  when :return
1107
1183
  yield_self do
1108
- if node.children.size > 0
1109
- method_return_type = expand_alias(method_context&.return_type)
1184
+ method_return_type = expand_alias(method_context&.return_type)
1110
1185
 
1186
+ if node.children.size > 0
1111
1187
  return_types = node.children.map do |value|
1112
- synthesize(value,
1113
- hint: if method_return_type.is_a?(AST::Types::Void)
1114
- nil
1115
- else
1116
- method_return_type
1117
- end).type
1188
+ synthesize(
1189
+ value,
1190
+ hint: if method_return_type.is_a?(AST::Types::Void)
1191
+ nil
1192
+ else
1193
+ method_return_type
1194
+ end
1195
+ ).type
1118
1196
  end
1119
1197
 
1120
1198
  value_type = if return_types.size == 1
@@ -1123,20 +1201,23 @@ module Steep
1123
1201
  AST::Builtin::Array.instance_type(union_type(*return_types))
1124
1202
  end
1125
1203
 
1126
- if method_return_type
1127
- unless method_return_type.is_a?(AST::Types::Void)
1128
- result = check_relation(sub_type: value_type, super_type: method_return_type)
1204
+ else
1205
+ value_type = AST::Builtin.nil_type
1206
+ end
1207
+
1208
+ if method_return_type
1209
+ unless method_return_type.is_a?(AST::Types::Void)
1210
+ result = check_relation(sub_type: value_type, super_type: method_return_type)
1129
1211
 
1130
- if result.failure?
1131
- typing.add_error(
1132
- Diagnostic::Ruby::ReturnTypeMismatch.new(
1133
- node: node,
1134
- expected: method_context&.return_type,
1135
- actual: value_type,
1136
- result: result
1137
- )
1212
+ if result.failure?
1213
+ typing.add_error(
1214
+ Diagnostic::Ruby::ReturnTypeMismatch.new(
1215
+ node: node,
1216
+ expected: method_context&.return_type,
1217
+ actual: value_type,
1218
+ result: result
1138
1219
  )
1139
- end
1220
+ )
1140
1221
  end
1141
1222
  end
1142
1223
  end
@@ -1427,7 +1508,7 @@ module Steep
1427
1508
  end
1428
1509
  _, constr = constr.synthesize(sup) if sup
1429
1510
 
1430
- constr.for_class(node).tap do |constructor|
1511
+ with_class_constr(node) do |constructor|
1431
1512
  if module_type = constructor.module_context&.module_type
1432
1513
  _, constructor = constructor.add_typing(name, type: module_type)
1433
1514
  else
@@ -1467,7 +1548,7 @@ module Steep
1467
1548
  _, constr = constr.synthesize(name)
1468
1549
  end
1469
1550
 
1470
- for_module(node).yield_self do |constructor|
1551
+ with_module_constr(node) do |constructor|
1471
1552
  if module_type = constructor.module_context&.module_type
1472
1553
  _, constructor = constructor.add_typing(name, type: module_type)
1473
1554
  else
@@ -1495,27 +1576,28 @@ module Steep
1495
1576
  when :sclass
1496
1577
  yield_self do
1497
1578
  type, constr = synthesize(node.children[0])
1498
- constructor = constr.for_sclass(node, type)
1499
1579
 
1500
- unless constructor
1501
- typing.add_error(
1502
- Diagnostic::Ruby::UnsupportedSyntax.new(
1503
- node: node,
1504
- message: "sclass receiver must be instance type or singleton type, but type given `#{type}`"
1580
+ with_sclass_constr(node, type) do |constructor|
1581
+ unless constructor
1582
+ typing.add_error(
1583
+ Diagnostic::Ruby::UnsupportedSyntax.new(
1584
+ node: node,
1585
+ message: "sclass receiver must be instance type or singleton type, but type given `#{type}`"
1586
+ )
1505
1587
  )
1506
- )
1507
- constr.add_typing(node, type: AST::Builtin.nil_type)
1508
- return
1509
- end
1588
+ constr.add_typing(node, type: AST::Builtin.nil_type)
1589
+ return
1590
+ end
1510
1591
 
1511
- constructor.typing.add_context_for_node(node, context: constructor.context)
1512
- constructor.typing.add_context_for_body(node, context: constructor.context)
1592
+ constructor.typing.add_context_for_node(node, context: constructor.context)
1593
+ constructor.typing.add_context_for_body(node, context: constructor.context)
1513
1594
 
1514
- constructor.synthesize(node.children[1]) if node.children[1]
1595
+ constructor.synthesize(node.children[1]) if node.children[1]
1515
1596
 
1516
- if constructor.module_context.instance_definition && module_context.module_definition
1517
- if constructor.module_context.instance_definition.type_name == module_context.module_definition.type_name
1518
- module_context.defined_module_methods.merge(constructor.module_context.defined_instance_methods)
1597
+ if constructor.module_context.instance_definition && module_context.module_definition
1598
+ if constructor.module_context.instance_definition.type_name == module_context.module_definition.type_name
1599
+ module_context.defined_module_methods.merge(constructor.module_context.defined_instance_methods)
1600
+ end
1519
1601
  end
1520
1602
  end
1521
1603
 
@@ -1573,8 +1655,7 @@ module Steep
1573
1655
  instance_type: module_context.instance_type,
1574
1656
  class_type: module_context.module_type
1575
1657
  ) do |error|
1576
- case error
1577
- when Subtyping::Result::Failure
1658
+ if error
1578
1659
  const_type = type_env.get(const: const_name)
1579
1660
  typing.add_error(
1580
1661
  Diagnostic::Ruby::IncompatibleAssignment.new(
@@ -1584,7 +1665,7 @@ module Steep
1584
1665
  result: error
1585
1666
  )
1586
1667
  )
1587
- when nil
1668
+ else
1588
1669
  typing.add_error(
1589
1670
  Diagnostic::Ruby::UnknownConstantAssigned.new(
1590
1671
  node: node,
@@ -2385,7 +2466,7 @@ module Steep
2385
2466
  raise "#synthesize should return an instance of Pair: #{pair.class}, node=#{node.inspect}"
2386
2467
  end
2387
2468
  end
2388
- rescue RBS::ErrorBase => exn
2469
+ rescue RBS::BaseError => exn
2389
2470
  Steep.logger.warn { "Unexpected RBS error: #{exn.message}" }
2390
2471
  exn.backtrace.each {|loc| Steep.logger.warn " #{loc}" }
2391
2472
  typing.add_error(Diagnostic::Ruby::UnexpectedError.new(node: node, error: exn))
@@ -2411,6 +2492,7 @@ module Steep
2411
2492
 
2412
2493
  def type_ivasgn(name, rhs, node)
2413
2494
  rhs_type = synthesize(rhs, hint: type_env.get(ivar: name) { fallback_to_any(node) }).type
2495
+
2414
2496
  ivar_type = type_env.assign(
2415
2497
  ivar: name,
2416
2498
  type: rhs_type,
@@ -2418,8 +2500,7 @@ module Steep
2418
2500
  instance_type: module_context.instance_type,
2419
2501
  class_type: module_context.module_type
2420
2502
  ) do |error|
2421
- case error
2422
- when Subtyping::Result::Failure
2503
+ if error
2423
2504
  type = type_env.get(ivar: name)
2424
2505
  typing.add_error(
2425
2506
  Diagnostic::Ruby::IncompatibleAssignment.new(
@@ -2429,7 +2510,7 @@ module Steep
2429
2510
  result: error
2430
2511
  )
2431
2512
  )
2432
- when nil
2513
+ else
2433
2514
  fallback_to_any node
2434
2515
  end
2435
2516
  end
@@ -2473,8 +2554,7 @@ module Steep
2473
2554
  instance_type: module_context.instance_type,
2474
2555
  class_type: module_context.module_type
2475
2556
  ) do |error|
2476
- case error
2477
- when Subtyping::Result::Failure
2557
+ if error
2478
2558
  var_type = type_env.get(ivar: ivar)
2479
2559
  typing.add_error(
2480
2560
  Diagnostic::Ruby::IncompatibleAssignment.new(
@@ -2484,7 +2564,7 @@ module Steep
2484
2564
  result: error
2485
2565
  )
2486
2566
  )
2487
- when nil
2567
+ else
2488
2568
  fallback_to_any node
2489
2569
  end
2490
2570
  end
@@ -2647,34 +2727,89 @@ module Steep
2647
2727
  end
2648
2728
  end
2649
2729
 
2650
- def type_lambda(node, block_params:, block_body:, type_hint:)
2730
+ def optional_proc?(type)
2731
+ if type.is_a?(AST::Types::Union)
2732
+ if type.types.size == 2
2733
+ if type.types.find {|t| t.is_a?(AST::Types::Nil) }
2734
+ if proc_type = type.types.find {|t| t.is_a?(AST::Types::Proc) }
2735
+ proc_type
2736
+ end
2737
+ end
2738
+ end
2739
+ end
2740
+ end
2741
+
2742
+ def type_lambda(node, params_node:, body_node:, type_hint:)
2651
2743
  block_annotations = source.annotations(block: node, factory: checker.factory, current_module: current_namespace)
2652
- params = TypeInference::BlockParams.from_node(block_params, annotations: block_annotations)
2744
+ params = TypeInference::BlockParams.from_node(params_node, annotations: block_annotations)
2745
+
2746
+ type_hint = deep_expand_alias(type_hint) if type_hint
2653
2747
 
2654
2748
  case type_hint
2655
2749
  when AST::Types::Proc
2656
2750
  params_hint = type_hint.type.params
2657
2751
  return_hint = type_hint.type.return_type
2752
+ block_hint = type_hint.block
2658
2753
  end
2659
2754
 
2660
2755
  block_constr = for_block(
2661
2756
  block_params: params,
2662
2757
  block_param_hint: params_hint,
2663
2758
  block_type_hint: return_hint,
2759
+ block_block_hint: block_hint,
2664
2760
  block_annotations: block_annotations,
2665
2761
  node_type_hint: nil
2666
2762
  )
2667
2763
 
2668
2764
  block_constr.typing.add_context_for_body(node, context: block_constr.context)
2669
2765
 
2766
+ default_proc_function =
2767
+ Interface::Function.new(
2768
+ params: Interface::Function::Params.empty,
2769
+ return_type: AST::Builtin.any_type,
2770
+ location: nil
2771
+ )
2772
+
2670
2773
  params.params.each do |param|
2671
2774
  _, block_constr = block_constr.synthesize(param.node, hint: param.type)
2672
2775
  end
2673
2776
 
2674
- if block_body
2777
+ block =
2778
+ if block_param = params.block_param
2779
+ if block_param_type = block_param.type
2780
+ case block_param_type
2781
+ when AST::Types::Proc
2782
+ Interface::Block.new(type: block_param_type.type, optional: false)
2783
+ else
2784
+ if proc_type = optional_proc?(block_param_type)
2785
+ Interface::Block.new(type: proc_type.type, optional: true)
2786
+ else
2787
+ block_constr.typing.add_error(
2788
+ Diagnostic::Ruby::ProcTypeExpected.new(
2789
+ node: block_param.node,
2790
+ type: block_param_type
2791
+ )
2792
+ )
2793
+
2794
+ Interface::Block.new(
2795
+ type: Interface::Function.new(
2796
+ params: Interface::Function::Params.empty,
2797
+ return_type: AST::Builtin.any_type,
2798
+ location: nil
2799
+ ),
2800
+ optional: false
2801
+ )
2802
+ end
2803
+ end
2804
+ else
2805
+ block_hint
2806
+ end
2807
+ end
2808
+
2809
+ if body_node
2675
2810
  return_type = block_constr.synthesize_block(
2676
2811
  node: node,
2677
- block_body: block_body,
2812
+ block_body: body_node,
2678
2813
  block_type_hint: return_hint
2679
2814
  )
2680
2815
 
@@ -2702,7 +2837,7 @@ module Steep
2702
2837
  return_type: return_type,
2703
2838
  location: nil
2704
2839
  ),
2705
- block: nil
2840
+ block: block
2706
2841
  )
2707
2842
 
2708
2843
  add_typing node, type: block_type
@@ -2830,6 +2965,32 @@ module Steep
2830
2965
  )
2831
2966
  )
2832
2967
 
2968
+ when AST::Types::Var
2969
+ if upper_bound = variable_context[receiver_type.name]
2970
+ interface = calculate_interface(upper_bound, private: false)
2971
+
2972
+ constr.type_send_interface(
2973
+ node,
2974
+ interface: interface,
2975
+ receiver: receiver,
2976
+ receiver_type: receiver_type,
2977
+ method_name: method_name,
2978
+ arguments: arguments,
2979
+ block_params: block_params,
2980
+ block_body: block_body
2981
+ )
2982
+ else
2983
+ constr = constr.synthesize_children(node, skips: [receiver])
2984
+ constr.add_call(
2985
+ TypeInference::MethodCall::NoMethodError.new(
2986
+ node: node,
2987
+ context: context.method_context,
2988
+ method_name: method_name,
2989
+ receiver_type: receiver_type,
2990
+ error: Diagnostic::Ruby::NoMethod.new(node: node, method: method_name, type: receiver_type)
2991
+ )
2992
+ )
2993
+ end
2833
2994
  when AST::Types::Void, AST::Types::Bot, AST::Types::Top, AST::Types::Var
2834
2995
  constr = constr.synthesize_children(node, skips: [receiver])
2835
2996
  constr.add_call(
@@ -2872,6 +3033,7 @@ module Steep
2872
3033
  block_params: block_params,
2873
3034
  block_body: block_body)
2874
3035
  end
3036
+
2875
3037
  else
2876
3038
  interface = calculate_interface(receiver_type, private: private, self_type: receiver_type)
2877
3039
 
@@ -2991,119 +3153,166 @@ module Steep
2991
3153
  end
2992
3154
  end
2993
3155
 
3156
+ def apply_solution(errors, node:, method_type:)
3157
+ subst = yield
3158
+
3159
+ [
3160
+ method_type.subst(subst),
3161
+ true,
3162
+ subst
3163
+ ]
3164
+
3165
+ rescue Subtyping::Constraints::UnsatisfiableConstraint => exn
3166
+ errors << Diagnostic::Ruby::UnsatisfiableConstraint.new(
3167
+ node: node,
3168
+ method_type: method_type,
3169
+ var: exn.var,
3170
+ sub_type: exn.sub_type,
3171
+ super_type: exn.super_type,
3172
+ result: exn.result
3173
+ )
3174
+ [method_type, false, Interface::Substitution.empty]
3175
+ end
3176
+
3177
+ def eliminate_vars(type, variables, to: AST::Builtin.any_type)
3178
+ if variables.empty?
3179
+ type
3180
+ else
3181
+ subst = Interface::Substitution.build(variables, Array.new(variables.size, to))
3182
+ type.subst(subst)
3183
+ end
3184
+ end
3185
+
2994
3186
  def try_method_type(node, receiver_type:, method_name:, method_type:, arguments:, block_params:, block_body:, topdown_hint:)
2995
- fresh_types = method_type.type_params.map {|x| AST::Types::Var.fresh(x)}
2996
- fresh_vars = Set.new(fresh_types.map(&:name))
2997
- instantiation = Interface::Substitution.build(method_type.type_params, fresh_types)
3187
+ type_params, instantiation = Interface::TypeParam.rename(method_type.type_params)
3188
+ type_param_names = type_params.map(&:name)
2998
3189
 
2999
3190
  constr = self
3000
3191
 
3001
3192
  method_type = method_type.instantiate(instantiation)
3002
3193
 
3003
- constraints = Subtyping::Constraints.new(unknowns: fresh_types.map(&:name))
3004
3194
  variance = Subtyping::VariableVariance.from_method_type(method_type)
3005
3195
  occurence = Subtyping::VariableOccurence.from_method_type(method_type)
3196
+ constraints = Subtyping::Constraints.new(unknowns: type_params.map(&:name))
3197
+ ccontext = Subtyping::Constraints::Context.new(
3198
+ self_type: self_type,
3199
+ instance_type: module_context.instance_type,
3200
+ class_type: module_context.module_type,
3201
+ variance: variance
3202
+ )
3006
3203
 
3007
- errors = []
3204
+ upper_bounds = {}
3008
3205
 
3009
- args = TypeInference::SendArgs.new(node: node, arguments: arguments, method_name: method_name, method_type: method_type)
3010
- es = args.each do |arg|
3011
- case arg
3012
- when TypeInference::SendArgs::PositionalArgs::NodeParamPair
3013
- _, constr = constr.type_check_argument(
3014
- arg.node,
3015
- type: arg.param.type,
3016
- receiver_type: receiver_type,
3017
- constraints: constraints,
3018
- errors: errors
3019
- )
3206
+ type_params.each do |param|
3207
+ if ub = param.upper_bound
3208
+ constraints.add(param.name, super_type: ub, skip: true)
3209
+ upper_bounds[param.name] = ub
3210
+ end
3211
+ end
3020
3212
 
3021
- when TypeInference::SendArgs::PositionalArgs::NodeTypePair
3022
- _, constr = bypass_splat(arg.node) do |n|
3023
- constr.type_check_argument(
3024
- n,
3025
- type: arg.node_type,
3213
+ checker.push_variable_bounds(upper_bounds) do
3214
+ errors = []
3215
+
3216
+ args = TypeInference::SendArgs.new(node: node, arguments: arguments, method_name: method_name, method_type: method_type)
3217
+ es = args.each do |arg|
3218
+ case arg
3219
+ when TypeInference::SendArgs::PositionalArgs::NodeParamPair
3220
+ _, constr = constr.type_check_argument(
3221
+ arg.node,
3222
+ type: arg.param.type,
3026
3223
  receiver_type: receiver_type,
3027
3224
  constraints: constraints,
3028
- report_node: arg.node,
3029
3225
  errors: errors
3030
3226
  )
3031
- end
3032
3227
 
3033
- when TypeInference::SendArgs::PositionalArgs::UnexpectedArg
3034
- _, constr = bypass_splat(arg.node) do |n|
3035
- constr.synthesize(n)
3036
- end
3037
-
3038
- when TypeInference::SendArgs::PositionalArgs::SplatArg
3039
- arg_type, _ = constr
3040
- .with_child_typing(range: arg.node.loc.expression.begin_pos ... arg.node.loc.expression.end_pos)
3041
- .try_tuple_type!(arg.node.children[0])
3042
- arg.type = arg_type
3043
-
3044
- when TypeInference::SendArgs::PositionalArgs::MissingArg
3045
- # ignore
3046
-
3047
- when TypeInference::SendArgs::KeywordArgs::ArgTypePairs
3048
- arg.pairs.each do |node, type|
3049
- _, constr = bypass_splat(node) do |node|
3228
+ when TypeInference::SendArgs::PositionalArgs::NodeTypePair
3229
+ _, constr = bypass_splat(arg.node) do |n|
3050
3230
  constr.type_check_argument(
3051
- node,
3052
- type: type,
3231
+ n,
3232
+ type: arg.node_type,
3053
3233
  receiver_type: receiver_type,
3054
3234
  constraints: constraints,
3235
+ report_node: arg.node,
3055
3236
  errors: errors
3056
3237
  )
3057
3238
  end
3058
- end
3059
3239
 
3060
- when TypeInference::SendArgs::KeywordArgs::UnexpectedKeyword
3061
- if arg.node.type == :pair
3062
- arg.node.children.each do |nn|
3063
- _, constr = constr.synthesize(nn)
3064
- end
3065
- else
3240
+ when TypeInference::SendArgs::PositionalArgs::UnexpectedArg
3066
3241
  _, constr = bypass_splat(arg.node) do |n|
3067
3242
  constr.synthesize(n)
3068
3243
  end
3069
- end
3070
3244
 
3071
- when TypeInference::SendArgs::KeywordArgs::SplatArg
3072
- type, _ = bypass_splat(arg.node) do |sp_node|
3073
- if sp_node.type == :hash
3074
- pair = constr.type_hash_record(sp_node, nil) and break pair
3245
+ when TypeInference::SendArgs::PositionalArgs::SplatArg
3246
+ arg_type, _ = constr
3247
+ .with_child_typing(range: arg.node.loc.expression.begin_pos ... arg.node.loc.expression.end_pos)
3248
+ .try_tuple_type!(arg.node.children[0])
3249
+ arg.type = arg_type
3250
+
3251
+ when TypeInference::SendArgs::PositionalArgs::MissingArg
3252
+ # ignore
3253
+
3254
+ when TypeInference::SendArgs::KeywordArgs::ArgTypePairs
3255
+ arg.pairs.each do |node, type|
3256
+ _, constr = bypass_splat(node) do |node|
3257
+ constr.type_check_argument(
3258
+ node,
3259
+ type: type,
3260
+ receiver_type: receiver_type,
3261
+ constraints: constraints,
3262
+ errors: errors
3263
+ )
3264
+ end
3075
3265
  end
3076
3266
 
3077
- constr.synthesize(sp_node)
3078
- end
3267
+ when TypeInference::SendArgs::KeywordArgs::UnexpectedKeyword
3268
+ if arg.node.type == :pair
3269
+ arg.node.children.each do |nn|
3270
+ _, constr = constr.synthesize(nn)
3271
+ end
3272
+ else
3273
+ _, constr = bypass_splat(arg.node) do |n|
3274
+ constr.synthesize(n)
3275
+ end
3276
+ end
3277
+
3278
+ when TypeInference::SendArgs::KeywordArgs::SplatArg
3279
+ type, _ = bypass_splat(arg.node) do |sp_node|
3280
+ if sp_node.type == :hash
3281
+ pair = constr.type_hash_record(sp_node, nil) and break pair
3282
+ end
3079
3283
 
3080
- arg.type = type
3284
+ constr.synthesize(sp_node)
3285
+ end
3081
3286
 
3082
- when TypeInference::SendArgs::KeywordArgs::MissingKeyword
3083
- # ignore
3084
- else
3085
- raise arg.inspect
3086
- end
3287
+ arg.type = type
3087
3288
 
3088
- constr
3089
- end
3289
+ when TypeInference::SendArgs::KeywordArgs::MissingKeyword
3290
+ # ignore
3291
+ else
3292
+ raise arg.inspect
3293
+ end
3090
3294
 
3091
- errors.push(*es)
3295
+ constr
3296
+ end
3297
+
3298
+ errors.push(*es)
3092
3299
 
3093
- if block_params
3094
- # block is given
3095
- block_annotations = source.annotations(block: node, factory: checker.factory, current_module: current_namespace)
3096
- block_params_ = TypeInference::BlockParams.from_node(block_params, annotations: block_annotations)
3300
+ if block_params
3301
+ # block is given
3302
+ block_annotations = source.annotations(block: node, factory: checker.factory, current_module: current_namespace)
3303
+ block_params_ = TypeInference::BlockParams.from_node(block_params, annotations: block_annotations)
3097
3304
 
3098
- if method_type.block
3099
- pairs = method_type.block && block_params_&.zip(method_type.block.type.params)
3305
+ if method_type.block
3306
+ # Method accepts block
3307
+ pairs = method_type.block && block_params_&.zip(method_type.block.type.params, nil)
3100
3308
 
3101
- if pairs
3102
- begin
3309
+ if pairs
3310
+ # Block parameters are compatible with the block type
3103
3311
  block_constr = constr.for_block(
3104
3312
  block_params: block_params_,
3105
3313
  block_param_hint: method_type.block.type.params,
3106
3314
  block_type_hint: method_type.block.type.return_type,
3315
+ block_block_hint: nil,
3107
3316
  block_annotations: block_annotations,
3108
3317
  node_type_hint: method_type.type.return_type
3109
3318
  )
@@ -3131,236 +3340,206 @@ module Steep
3131
3340
  end
3132
3341
  end
3133
3342
 
3134
- s = constraints.solution(
3135
- checker,
3136
- self_type: self_type,
3137
- instance_type: module_context.instance_type,
3138
- class_type: module_context.module_type,
3139
- variance: variance,
3140
- variables: method_type.type.params.free_variables + method_type.block.type.params.free_variables
3141
- )
3142
- method_type = method_type.subst(s)
3143
- block_constr = block_constr.update_lvar_env {|env| env.subst(s) }
3144
- if block_body
3145
- block_body_type = block_constr.synthesize_block(
3146
- node: node,
3147
- block_body: block_body,
3148
- block_type_hint: method_type.block.type.return_type
3343
+ method_type, solved, s = apply_solution(errors, node: node, method_type: method_type) {
3344
+ constraints.solution(
3345
+ checker,
3346
+ variables: method_type.type.params.free_variables + method_type.block.type.params.free_variables,
3347
+ context: ccontext
3149
3348
  )
3150
- else
3151
- block_body_type = AST::Builtin.nil_type
3152
- end
3349
+ }
3153
3350
 
3154
- result = check_relation(sub_type: block_body_type,
3155
- super_type: method_type.block.type.return_type,
3156
- constraints: constraints)
3351
+ if solved
3352
+ # Ready for type check the body of the block
3353
+ block_constr = block_constr.update_lvar_env {|env| env.subst(s) }
3354
+ if block_body
3355
+ block_body_type = block_constr.synthesize_block(
3356
+ node: node,
3357
+ block_body: block_body,
3358
+ block_type_hint: method_type.block.type.return_type
3359
+ )
3360
+ else
3361
+ block_body_type = AST::Builtin.nil_type
3362
+ end
3157
3363
 
3158
- case result
3159
- when Subtyping::Result::Success
3160
- s = constraints.solution(
3161
- checker,
3162
- self_type: self_type,
3163
- instance_type: module_context.instance_type,
3164
- class_type: module_context.module_type,
3165
- variance: variance,
3166
- variables: fresh_vars
3167
- )
3168
- method_type = method_type.subst(s)
3364
+ result = check_relation(sub_type: block_body_type,
3365
+ super_type: method_type.block.type.return_type,
3366
+ constraints: constraints)
3169
3367
 
3170
- return_type = method_type.type.return_type
3171
- if break_type = block_annotations.break_type
3172
- return_type = union_type(break_type, return_type)
3368
+ if result.success?
3369
+ # Successfully type checked the body
3370
+ method_type, solved, _ = apply_solution(errors, node: node, method_type: method_type) do
3371
+ constraints.solution(checker, variables: type_param_names, context: ccontext)
3372
+ end
3373
+ method_type = eliminate_vars(method_type, type_param_names) unless solved
3374
+
3375
+ return_type = method_type.type.return_type
3376
+ if break_type = block_annotations.break_type
3377
+ return_type = union_type(break_type, return_type)
3378
+ end
3379
+ else
3380
+ # The block body has incompatible type
3381
+ errors << Diagnostic::Ruby::BlockBodyTypeMismatch.new(
3382
+ node: node,
3383
+ expected: method_type.block.type.return_type,
3384
+ actual: block_body_type,
3385
+ result: result
3386
+ )
3387
+
3388
+ method_type = eliminate_vars(method_type, type_param_names)
3389
+ return_type = method_type.type.return_type
3173
3390
  end
3174
3391
 
3175
- when Subtyping::Result::Failure
3176
- errors << Diagnostic::Ruby::BlockBodyTypeMismatch.new(
3177
- node: node,
3178
- expected: method_type.block.type.return_type,
3179
- actual: block_body_type,
3180
- result: result
3181
- )
3392
+ block_constr.typing.save!
3393
+ else
3394
+ # Failed to infer the type of block parameters
3395
+ constr.type_block_without_hint(node: node, block_annotations: block_annotations, block_params: block_params_, block_body: block_body) do |error|
3396
+ errors << error
3397
+ end
3182
3398
 
3399
+ method_type = eliminate_vars(method_type, type_param_names)
3183
3400
  return_type = method_type.type.return_type
3184
3401
  end
3185
-
3186
- block_constr.typing.save!
3187
-
3188
- rescue Subtyping::Constraints::UnsatisfiableConstraint => exn
3189
- errors << Diagnostic::Ruby::UnsatisfiableConstraint.new(
3190
- node: node,
3191
- method_type: method_type,
3192
- var: exn.var,
3193
- sub_type: exn.sub_type,
3194
- super_type: exn.super_type,
3195
- result: exn.result
3402
+ else
3403
+ # Block parameters are unsupported syntax
3404
+ errors << Diagnostic::Ruby::UnsupportedSyntax.new(
3405
+ node: block_params,
3406
+ message: "Unsupported block params pattern, probably masgn?"
3196
3407
  )
3197
3408
 
3198
- constr.type_block_without_hint(node: node, block_annotations: block_annotations, block_params: block_params_, block_body: block_body) do |error|
3199
- errors << error
3200
- end
3409
+ method_type, solved, _ = apply_solution(errors, node: node, method_type: method_type) {
3410
+ constraints.solution(checker, variables: type_param_names, context: ccontext)
3411
+ }
3412
+ method_type = eliminate_vars(method_type, type_param_names) unless solved
3201
3413
 
3202
- s = Interface::Substitution.build(method_type.free_variables,
3203
- Array.new(method_type.free_variables.size, AST::Builtin.any_type))
3204
- method_type = method_type.subst(s)
3414
+ return_type = method_type.type.return_type
3205
3415
  end
3206
3416
  else
3207
- errors << Diagnostic::Ruby::UnsupportedSyntax.new(
3208
- node: block_params,
3209
- message: "Unsupported block params pattern, probably masgn?"
3210
- )
3417
+ # Block is given but method doesn't accept
3418
+ #
3419
+ constr.type_block_without_hint(node: node, block_annotations: block_annotations, block_params: block_params_, block_body: block_body) do |error|
3420
+ errors << error
3421
+ end
3211
3422
 
3212
- s = constraints.solution(
3213
- checker,
3214
- variance: variance,
3215
- variables: fresh_vars,
3216
- self_type: self_type,
3217
- instance_type: module_context.instance_type,
3218
- class_type: module_context.module_type
3423
+ errors << Diagnostic::Ruby::UnexpectedBlockGiven.new(
3424
+ node: node,
3425
+ method_type: method_type
3219
3426
  )
3220
- method_type = method_type.subst(s)
3427
+
3428
+ method_type = eliminate_vars(method_type, type_param_names)
3429
+ return_type = method_type.type.return_type
3221
3430
  end
3222
3431
  else
3223
- # Block is given but method doesn't accept
3224
- #
3225
- constr.type_block_without_hint(node: node, block_annotations: block_annotations, block_params: block_params_, block_body: block_body) do |error|
3226
- errors << error
3227
- end
3432
+ # Block syntax is not given
3433
+ arg = args.block_pass_arg
3434
+
3435
+ case
3436
+ when arg.compatible?
3437
+ if arg.node
3438
+ # Block pass (&block) is given
3439
+ node_type, constr = constr.synthesize(arg.node, hint: arg.node_type)
3440
+
3441
+ nil_given =
3442
+ constr.check_relation(sub_type: node_type, super_type: AST::Builtin.nil_type).success? &&
3443
+ !node_type.is_a?(AST::Types::Any)
3444
+
3445
+ if nil_given
3446
+ # nil is given ==> no block arg node is given
3447
+ method_type, solved, _ = apply_solution(errors, node: node, method_type: method_type) {
3448
+ constraints.solution(checker, variables: method_type.free_variables, context: ccontext)
3449
+ }
3450
+ method_type = eliminate_vars(method_type, type_param_names) unless solved
3228
3451
 
3229
- errors << Diagnostic::Ruby::UnexpectedBlockGiven.new(
3230
- node: node,
3231
- method_type: method_type
3232
- )
3233
- end
3234
- else
3235
- arg = args.block_pass_arg
3452
+ # Passing no block
3453
+ errors << Diagnostic::Ruby::RequiredBlockMissing.new(
3454
+ node: node,
3455
+ method_type: method_type
3456
+ )
3457
+ else
3458
+ # non-nil value is given
3459
+ constr.check_relation(sub_type: node_type, super_type: arg.node_type, constraints: constraints).else do |result|
3460
+ errors << Diagnostic::Ruby::BlockTypeMismatch.new(
3461
+ node: arg.node,
3462
+ expected: arg.node_type,
3463
+ actual: node_type,
3464
+ result: result
3465
+ )
3466
+ end
3236
3467
 
3237
- case
3238
- when arg.compatible?
3239
- if arg.node
3240
- subst = constraints.solution(
3241
- checker,
3242
- self_type: self_type,
3243
- instance_type: module_context.instance_type,
3244
- class_type: module_context.module_type,
3245
- variance: variance,
3246
- variables: occurence.params
3247
- )
3468
+ method_type, solved, _ = apply_solution(errors, node: node, method_type: method_type) {
3469
+ constraints.solution(checker, variables: method_type.free_variables, context: ccontext)
3470
+ }
3471
+ method_type = eliminate_vars(method_type, type_param_names) unless solved
3472
+ end
3473
+ else
3474
+ # Block is not given
3475
+ method_type, solved, _ = apply_solution(errors, node: node, method_type: method_type) {
3476
+ constraints.solution(checker, variables: method_type.free_variables, context: ccontext)
3477
+ }
3478
+ method_type = eliminate_vars(method_type, type_param_names) unless solved
3479
+ end
3248
3480
 
3249
- block_type = arg.node_type.subst(subst)
3481
+ return_type = method_type.type.return_type
3250
3482
 
3251
- node_type, constr = constr.synthesize(arg.node, hint: block_type)
3252
- nil_block =
3253
- constr.check_relation(sub_type: node_type, super_type: AST::Builtin.nil_type).success? &&
3254
- !node_type.is_a?(AST::Types::Any)
3483
+ when arg.block_missing?
3484
+ # Block is required but not given
3485
+ method_type, solved, _ = apply_solution(errors, node: node, method_type: method_type) {
3486
+ constraints.solution(checker, variables: method_type.free_variables, context: ccontext)
3487
+ }
3255
3488
 
3256
- unless nil_block
3257
- constr.check_relation(sub_type: node_type, super_type: block_type, constraints: constraints).else do |result|
3258
- errors << Diagnostic::Ruby::BlockTypeMismatch.new(
3259
- node: arg.node,
3260
- expected: block_type,
3261
- actual: node_type,
3262
- result: result
3263
- )
3264
- end
3265
- end
3489
+ method_type = eliminate_vars(method_type, type_param_names) unless solved
3490
+ return_type = method_type.type.return_type
3266
3491
 
3267
- subst = constraints.solution(
3268
- checker,
3269
- self_type: self_type,
3270
- instance_type: module_context.instance_type,
3271
- class_type: module_context.module_type,
3272
- variance: variance,
3273
- variables: method_type.free_variables
3492
+ errors << Diagnostic::Ruby::RequiredBlockMissing.new(
3493
+ node: node,
3494
+ method_type: method_type
3274
3495
  )
3275
3496
 
3276
- method_type = method_type.subst(subst)
3497
+ when arg.unexpected_block?
3498
+ # Unexpected block is given
3499
+ method_type, solved, _ = apply_solution(errors, node: node, method_type: method_type) {
3500
+ constraints.solution(checker, variables: method_type.free_variables, context: ccontext)
3501
+ }
3502
+ method_type = eliminate_vars(method_type, type_param_names) unless solved
3503
+ return_type = method_type.type.return_type
3277
3504
 
3278
- if nil_block && arg.block.required?
3279
- # Passing no block
3280
- errors << Diagnostic::Ruby::RequiredBlockMissing.new(
3505
+ node_type, constr = constr.synthesize(arg.node)
3506
+
3507
+ unless constr.check_relation(sub_type: node_type, super_type: AST::Builtin.nil_type).success?
3508
+ errors << Diagnostic::Ruby::UnexpectedBlockGiven.new(
3281
3509
  node: node,
3282
3510
  method_type: method_type
3283
3511
  )
3284
3512
  end
3285
- else
3286
- subst = constraints.solution(
3287
- checker,
3288
- self_type: self_type,
3289
- instance_type: module_context.instance_type,
3290
- class_type: module_context.module_type,
3291
- variance: variance,
3292
- variables: method_type.free_variables
3293
- )
3294
-
3295
- method_type = method_type.subst(subst)
3296
- end
3297
-
3298
- when arg.block_missing?
3299
- subst = constraints.solution(
3300
- checker,
3301
- self_type: self_type,
3302
- instance_type: module_context.instance_type,
3303
- class_type: module_context.module_type,
3304
- variance: variance,
3305
- variables: method_type.free_variables
3306
- )
3307
-
3308
- method_type = method_type.subst(subst)
3309
-
3310
- errors << Diagnostic::Ruby::RequiredBlockMissing.new(
3311
- node: node,
3312
- method_type: method_type
3313
- )
3314
-
3315
- when arg.unexpected_block?
3316
- subst = constraints.solution(
3317
- checker,
3318
- self_type: self_type,
3319
- instance_type: module_context.instance_type,
3320
- class_type: module_context.module_type,
3321
- variance: variance,
3322
- variables: method_type.free_variables
3323
- )
3324
-
3325
- method_type = method_type.subst(subst)
3326
-
3327
- node_type, constr = constr.synthesize(arg.node)
3328
-
3329
- unless constr.check_relation(sub_type: node_type, super_type: AST::Builtin.nil_type).success?
3330
- errors << Diagnostic::Ruby::UnexpectedBlockGiven.new(
3331
- node: node,
3332
- method_type: method_type
3333
- )
3334
3513
  end
3335
3514
  end
3336
- end
3337
-
3338
- call = if errors.empty?
3339
- TypeInference::MethodCall::Typed.new(
3340
- node: node,
3341
- context: context.method_context,
3342
- receiver_type: receiver_type,
3343
- method_name: method_name,
3344
- actual_method_type: method_type,
3345
- return_type: return_type || method_type.type.return_type,
3346
- method_decls: method_type.method_decls
3347
- )
3348
- else
3349
- TypeInference::MethodCall::Error.new(
3350
- node: node,
3351
- context: context.method_context,
3352
- receiver_type: receiver_type,
3353
- method_name: method_name,
3354
- return_type: return_type || method_type.type.return_type,
3355
- method_decls: method_type.method_decls,
3356
- errors: errors
3357
- )
3358
- end
3359
3515
 
3360
- [
3361
- call,
3362
- constr
3363
- ]
3516
+ call = if errors.empty?
3517
+ TypeInference::MethodCall::Typed.new(
3518
+ node: node,
3519
+ context: context.method_context,
3520
+ receiver_type: receiver_type,
3521
+ method_name: method_name,
3522
+ actual_method_type: method_type,
3523
+ return_type: return_type || method_type.type.return_type,
3524
+ method_decls: method_type.method_decls
3525
+ )
3526
+ else
3527
+ TypeInference::MethodCall::Error.new(
3528
+ node: node,
3529
+ context: context.method_context,
3530
+ receiver_type: receiver_type,
3531
+ method_name: method_name,
3532
+ return_type: return_type || method_type.type.return_type,
3533
+ method_decls: method_type.method_decls,
3534
+ errors: errors
3535
+ )
3536
+ end
3537
+
3538
+ [
3539
+ call,
3540
+ constr
3541
+ ]
3542
+ end
3364
3543
  end
3365
3544
 
3366
3545
  def type_check_argument(node, receiver_type:, type:, constraints:, report_node: node, errors:)
@@ -3383,13 +3562,14 @@ module Steep
3383
3562
  message: "Unsupported block params pattern, probably masgn?"
3384
3563
  )
3385
3564
  )
3386
- block_params = TypeInference::BlockParams.new(leading_params: [], optional_params: [], rest_param: nil, trailing_params: [])
3565
+ block_params = TypeInference::BlockParams.new(leading_params: [], optional_params: [], rest_param: nil, trailing_params: [], block_param: nil)
3387
3566
  end
3388
3567
 
3389
3568
  block_constr = for_block(
3390
3569
  block_params: block_params,
3391
3570
  block_param_hint: nil,
3392
3571
  block_type_hint: AST::Builtin.any_type,
3572
+ block_block_hint: nil,
3393
3573
  block_annotations: block_annotations,
3394
3574
  node_type_hint: AST::Builtin.any_type
3395
3575
  )
@@ -3416,8 +3596,8 @@ module Steep
3416
3596
  end
3417
3597
  end
3418
3598
 
3419
- def for_block(block_params:, block_param_hint:, block_type_hint:, block_annotations:, node_type_hint:)
3420
- block_param_pairs = block_param_hint && block_params.zip(block_param_hint)
3599
+ def for_block(block_params:, block_param_hint:, block_type_hint:, block_block_hint:, block_annotations:, node_type_hint:)
3600
+ block_param_pairs = block_param_hint && block_params.zip(block_param_hint, block_block_hint)
3421
3601
 
3422
3602
  param_types_hash = {}
3423
3603
  if block_param_pairs
@@ -3485,7 +3665,8 @@ module Steep
3485
3665
  self_type: self_type,
3486
3666
  type_env: type_env.dup,
3487
3667
  lvar_env: lvar_env,
3488
- call_context: self.context.call_context
3668
+ call_context: self.context.call_context,
3669
+ variable_context: variable_context
3489
3670
  )
3490
3671
  )
3491
3672
  end