jaclang-0.2.4-py3-none-any.whl → jaclang-0.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of jaclang might be problematic; the original registry page links to further details.

Files changed (82):
  1. jaclang/__init__.py +9 -3
  2. jaclang/cli/__init__.py +0 -1
  3. jaclang/cli/__jac_gen__/cli.py +6 -6
  4. jaclang/cli/__jac_gen__/cli_impl.py +2 -2
  5. jaclang/cli/__jac_gen__/cmds.py +2 -3
  6. jaclang/cli/__jac_gen__/cmds_impl.py +2 -3
  7. jaclang/cli/cmds.jac +1 -1
  8. jaclang/cli/cmds_impl.jac +2 -3
  9. jaclang/core/__init__.py +5 -11
  10. jaclang/core/__jac_gen__/corelib.py +289 -0
  11. jaclang/core/__jac_gen__/corelib_impl.py +220 -0
  12. jaclang/core/corelib.jac +21 -34
  13. jaclang/core/corelib_impl.jac +317 -0
  14. jaclang/jac/__init__.py +1 -0
  15. jaclang/jac/__jac_gen__/jac_parser.py +2 -2
  16. jaclang/jac/absyntree.py +32 -62
  17. jaclang/jac/constant.py +3 -7
  18. jaclang/jac/importer.py +1 -1
  19. jaclang/jac/parser.py +14 -10
  20. jaclang/jac/passes/main/__init__.py +2 -0
  21. jaclang/jac/passes/main/def_use_pass.py +4 -7
  22. jaclang/jac/passes/main/pyast_gen_pass.py +116 -35
  23. jaclang/jac/passes/main/schedules.py +6 -0
  24. jaclang/jac/passes/main/sym_tab_build_pass.py +40 -19
  25. jaclang/jac/passes/main/tests/test_jac_format_pass.py +22 -4
  26. jaclang/jac/passes/main/tests/test_pyast_gen_pass.py +3 -1
  27. jaclang/jac/passes/main/tests/test_type_check_pass.py +42 -0
  28. jaclang/jac/passes/main/type_check_pass.py +103 -0
  29. jaclang/jac/passes/tool/ast_printer_pass.py +8 -2
  30. jaclang/jac/passes/tool/fuse_comments_pass.py +57 -39
  31. jaclang/jac/passes/tool/jac_formatter_pass.py +419 -192
  32. jaclang/jac/passes/tool/sym_tab_printer_pass.py +10 -93
  33. jaclang/jac/passes/tool/tests/test_ast_print_pass.py +2 -1
  34. jaclang/jac/passes/transform.py +0 -39
  35. jaclang/jac/passes/utils/__init__.py +1 -0
  36. jaclang/jac/passes/utils/mypy_ast_build.py +302 -0
  37. jaclang/jac/plugin/__init__.py +5 -2
  38. jaclang/jac/plugin/default.py +20 -4
  39. jaclang/jac/plugin/feature.py +16 -7
  40. jaclang/jac/plugin/spec.py +34 -6
  41. jaclang/jac/symtable.py +6 -0
  42. jaclang/jac/tests/test_workspace.py +55 -1
  43. jaclang/jac/transpiler.py +4 -9
  44. jaclang/utils/helpers.py +0 -33
  45. jaclang/utils/lang_tools.py +3 -0
  46. jaclang/utils/test.py +3 -1
  47. jaclang/utils/treeprinter.py +171 -0
  48. jaclang/vendor/lark/py.typed +0 -0
  49. jaclang/vendor/mypy/checker.py +19 -12
  50. jaclang/vendor/mypy/checkexpr.py +31 -10
  51. jaclang/vendor/mypy/constraints.py +56 -38
  52. jaclang/vendor/mypy/expandtype.py +1 -0
  53. jaclang/vendor/mypy/meet.py +10 -1
  54. jaclang/vendor/mypy/messages.py +16 -4
  55. jaclang/vendor/mypy/moduleinspect.py +10 -4
  56. jaclang/vendor/mypy/py.typed +1 -0
  57. jaclang/vendor/mypy/semanal.py +18 -17
  58. jaclang/vendor/mypy/semanal_enum.py +7 -4
  59. jaclang/vendor/mypy/semanal_namedtuple.py +11 -1
  60. jaclang/vendor/mypy/semanal_typeddict.py +25 -11
  61. jaclang/vendor/mypy/stubdoc.py +18 -4
  62. jaclang/vendor/mypy/stubgen.py +80 -1
  63. jaclang/vendor/mypy/stubgenc.py +47 -5
  64. jaclang/vendor/mypy/stubtest.py +53 -3
  65. jaclang/vendor/mypy/stubutil.py +9 -9
  66. jaclang/vendor/mypy/test/testipc.py +16 -7
  67. jaclang/vendor/mypy/test/teststubtest.py +20 -2
  68. jaclang/vendor/mypy/types.py +1 -1
  69. jaclang/vendor/mypyc/irbuild/prebuildvisitor.py +2 -1
  70. jaclang/vendor/mypyc/test/test_run.py +2 -4
  71. jaclang/vendor/pluggy/py.typed +0 -0
  72. {jaclang-0.2.4.dist-info → jaclang-0.3.0.dist-info}/METADATA +1 -1
  73. {jaclang-0.2.4.dist-info → jaclang-0.3.0.dist-info}/RECORD +77 -71
  74. {jaclang-0.2.4.dist-info → jaclang-0.3.0.dist-info}/WHEEL +1 -1
  75. {jaclang-0.2.4.dist-info → jaclang-0.3.0.dist-info}/entry_points.txt +3 -0
  76. jaclang/core/arch_impl.jac +0 -131
  77. jaclang/core/element_impl.jac +0 -109
  78. jaclang/core/exec_ctx_impl.jac +0 -14
  79. jaclang/core/memory_impl.jac +0 -57
  80. jaclang/jac/tests/fixtures/__jac_gen__/hello_world.py +0 -5
  81. jaclang/{jac/tests/fixtures → core}/__jac_gen__/__init__.py +0 -0
  82. {jaclang-0.2.4.dist-info → jaclang-0.3.0.dist-info}/top_level.txt +0 -0
@@ -238,29 +238,29 @@ def infer_constraints_for_callable(
238
238
  callee.arg_names[i],
239
239
  callee.arg_kinds[i],
240
240
  )
241
- if (
242
- param_spec
243
- and callee.arg_kinds[i] in (ARG_STAR, ARG_STAR2)
244
- and not incomplete_star_mapping
245
- ):
241
+ if param_spec and callee.arg_kinds[i] in (ARG_STAR, ARG_STAR2):
246
242
  # If actual arguments are mapped to ParamSpec type, we can't infer individual
247
243
  # constraints, instead store them and infer single constraint at the end.
248
244
  # It is impossible to map actual kind to formal kind, so use some heuristic.
249
245
  # This inference is used as a fallback, so relying on heuristic should be OK.
250
- param_spec_arg_types.append(
251
- mapper.expand_actual_type(
252
- actual_arg_type, arg_kinds[actual], None, arg_kinds[actual]
246
+ if not incomplete_star_mapping:
247
+ param_spec_arg_types.append(
248
+ mapper.expand_actual_type(
249
+ actual_arg_type,
250
+ arg_kinds[actual],
251
+ None,
252
+ arg_kinds[actual],
253
+ )
254
+ )
255
+ actual_kind = arg_kinds[actual]
256
+ param_spec_arg_kinds.append(
257
+ ARG_POS
258
+ if actual_kind not in (ARG_STAR, ARG_STAR2)
259
+ else actual_kind
260
+ )
261
+ param_spec_arg_names.append(
262
+ arg_names[actual] if arg_names else None
253
263
  )
254
- )
255
- actual_kind = arg_kinds[actual]
256
- param_spec_arg_kinds.append(
257
- ARG_POS
258
- if actual_kind not in (ARG_STAR, ARG_STAR2)
259
- else actual_kind
260
- )
261
- param_spec_arg_names.append(
262
- arg_names[actual] if arg_names else None
263
- )
264
264
  else:
265
265
  c = infer_constraints(
266
266
  callee.arg_types[i], actual_type, SUPERTYPE_OF
@@ -285,6 +285,9 @@ def infer_constraints_for_callable(
285
285
  ),
286
286
  )
287
287
  )
288
+ if any(isinstance(v, ParamSpecType) for v in callee.variables):
289
+ # As a perf optimization filter imprecise constraints only when we can have them.
290
+ constraints = filter_imprecise_kinds(constraints)
288
291
  return constraints
289
292
 
290
293
 
@@ -1197,32 +1200,23 @@ class ConstraintBuilderVisitor(TypeVisitor[List[Constraint]]):
1197
1200
  )
1198
1201
 
1199
1202
  param_spec_target: Type | None = None
1200
- skip_imprecise = (
1201
- any(c.type_var == param_spec.id for c in res)
1202
- and cactual.imprecise_arg_kinds
1203
- )
1204
1203
  if not cactual_ps:
1205
1204
  max_prefix_len = len(
1206
1205
  [k for k in cactual.arg_kinds if k in (ARG_POS, ARG_OPT)]
1207
1206
  )
1208
1207
  prefix_len = min(prefix_len, max_prefix_len)
1209
- # This logic matches top-level callable constraint exception, if we managed
1210
- # to get other constraints for ParamSpec, don't infer one with imprecise kinds
1211
- if not skip_imprecise:
1212
- param_spec_target = Parameters(
1213
- arg_types=cactual.arg_types[prefix_len:],
1214
- arg_kinds=cactual.arg_kinds[prefix_len:],
1215
- arg_names=cactual.arg_names[prefix_len:],
1216
- variables=cactual.variables
1217
- if not type_state.infer_polymorphic
1218
- else [],
1219
- imprecise_arg_kinds=cactual.imprecise_arg_kinds,
1220
- )
1208
+ param_spec_target = Parameters(
1209
+ arg_types=cactual.arg_types[prefix_len:],
1210
+ arg_kinds=cactual.arg_kinds[prefix_len:],
1211
+ arg_names=cactual.arg_names[prefix_len:],
1212
+ variables=cactual.variables
1213
+ if not type_state.infer_polymorphic
1214
+ else [],
1215
+ imprecise_arg_kinds=cactual.imprecise_arg_kinds,
1216
+ )
1221
1217
  else:
1222
- if (
1223
- len(param_spec.prefix.arg_types)
1224
- <= len(cactual_ps.prefix.arg_types)
1225
- and not skip_imprecise
1218
+ if len(param_spec.prefix.arg_types) <= len(
1219
+ cactual_ps.prefix.arg_types
1226
1220
  ):
1227
1221
  param_spec_target = cactual_ps.copy_modified(
1228
1222
  prefix=Parameters(
@@ -1770,3 +1764,27 @@ def infer_callable_arguments_constraints(
1770
1764
  )
1771
1765
  )
1772
1766
  return res
1767
+
1768
+
1769
+ def filter_imprecise_kinds(cs: list[Constraint]) -> list[Constraint]:
1770
+ """For each ParamSpec remove all imprecise constraints, if at least one precise available."""
1771
+ have_precise = set()
1772
+ for c in cs:
1773
+ if not isinstance(c.origin_type_var, ParamSpecType):
1774
+ continue
1775
+ if (
1776
+ isinstance(c.target, ParamSpecType)
1777
+ or isinstance(c.target, Parameters)
1778
+ and not c.target.imprecise_arg_kinds
1779
+ ):
1780
+ have_precise.add(c.type_var)
1781
+ new_cs = []
1782
+ for c in cs:
1783
+ if (
1784
+ not isinstance(c.origin_type_var, ParamSpecType)
1785
+ or c.type_var not in have_precise
1786
+ ):
1787
+ new_cs.append(c)
1788
+ if not isinstance(c.target, Parameters) or not c.target.imprecise_arg_kinds:
1789
+ new_cs.append(c)
1790
+ return new_cs
@@ -256,6 +256,7 @@ class ExpandTypeVisitor(TrivialSyntheticTypeTranslator):
256
256
  t.prefix.arg_kinds + repl.arg_kinds,
257
257
  t.prefix.arg_names + repl.arg_names,
258
258
  variables=[*t.prefix.variables, *repl.variables],
259
+ imprecise_arg_kinds=repl.imprecise_arg_kinds,
259
260
  )
260
261
  else:
261
262
  # We could encode Any as trivial parameters etc., but it would be too verbose.
@@ -268,6 +268,7 @@ def is_overlapping_types(
268
268
  ignore_promotions: bool = False,
269
269
  prohibit_none_typevar_overlap: bool = False,
270
270
  ignore_uninhabited: bool = False,
271
+ seen_types: set[tuple[Type, Type]] | None = None,
271
272
  ) -> bool:
272
273
  """Can a value of type 'left' also be of type 'right' or vice-versa?
273
274
 
@@ -281,12 +282,19 @@ def is_overlapping_types(
281
282
  # A type guard forces the new type even if it doesn't overlap the old.
282
283
  return True
283
284
 
285
+ if seen_types is None:
286
+ seen_types = set()
287
+ if (left, right) in seen_types:
288
+ return True
289
+ if isinstance(left, TypeAliasType) and isinstance(right, TypeAliasType):
290
+ seen_types.add((left, right))
291
+
284
292
  left, right = get_proper_types((left, right))
285
293
 
286
294
  def _is_overlapping_types(left: Type, right: Type) -> bool:
287
295
  """Encode the kind of overlapping check to perform.
288
296
 
289
- This function mostly exists so we don't have to repeat keyword arguments everywhere.
297
+ This function mostly exists, so we don't have to repeat keyword arguments everywhere.
290
298
  """
291
299
  return is_overlapping_types(
292
300
  left,
@@ -294,6 +302,7 @@ def is_overlapping_types(
294
302
  ignore_promotions=ignore_promotions,
295
303
  prohibit_none_typevar_overlap=prohibit_none_typevar_overlap,
296
304
  ignore_uninhabited=ignore_uninhabited,
305
+ seen_types=seen_types.copy(),
297
306
  )
298
307
 
299
308
  # We should never encounter this type.
@@ -1052,10 +1052,22 @@ class MessageBuilder:
1052
1052
  context,
1053
1053
  )
1054
1054
 
1055
+ def unexpected_keyword_argument_for_function(
1056
+ self,
1057
+ for_func: str,
1058
+ name: str,
1059
+ context: Context,
1060
+ *,
1061
+ matches: list[str] | None = None,
1062
+ ) -> None:
1063
+ msg = f'Unexpected keyword argument "{name}"' + for_func
1064
+ if matches:
1065
+ msg += f"; did you mean {pretty_seq(matches, 'or')}?"
1066
+ self.fail(msg, context, code=codes.CALL_ARG)
1067
+
1055
1068
  def unexpected_keyword_argument(
1056
1069
  self, callee: CallableType, name: str, arg_type: Type, context: Context
1057
1070
  ) -> None:
1058
- msg = f'Unexpected keyword argument "{name}"' + for_function(callee)
1059
1071
  # Suggest intended keyword, look for type match else fallback on any match.
1060
1072
  matching_type_args = []
1061
1073
  not_matching_type_args = []
@@ -1069,9 +1081,9 @@ class MessageBuilder:
1069
1081
  matches = best_matches(name, matching_type_args, n=3)
1070
1082
  if not matches:
1071
1083
  matches = best_matches(name, not_matching_type_args, n=3)
1072
- if matches:
1073
- msg += f"; did you mean {pretty_seq(matches, 'or')}?"
1074
- self.fail(msg, context, code=codes.CALL_ARG)
1084
+ self.unexpected_keyword_argument_for_function(
1085
+ for_function(callee), name, context, matches=matches
1086
+ )
1075
1087
  module = find_defining_module(self.modules, callee)
1076
1088
  if module:
1077
1089
  assert callee.definition is not None
@@ -8,7 +8,7 @@ import os
8
8
  import pkgutil
9
9
  import queue
10
10
  import sys
11
- from multiprocessing import Process, Queue
11
+ from multiprocessing import Queue, get_context
12
12
  from types import ModuleType
13
13
 
14
14
 
@@ -133,9 +133,15 @@ class ModuleInspect:
133
133
  self._start()
134
134
 
135
135
  def _start(self) -> None:
136
- self.tasks: Queue[str] = Queue()
137
- self.results: Queue[ModuleProperties | str] = Queue()
138
- self.proc = Process(target=worker, args=(self.tasks, self.results, sys.path))
136
+ if sys.platform == "linux":
137
+ ctx = get_context("forkserver")
138
+ else:
139
+ ctx = get_context("spawn")
140
+ self.tasks: Queue[str] = ctx.Queue()
141
+ self.results: Queue[ModuleProperties | str] = ctx.Queue()
142
+ self.proc = ctx.Process(
143
+ target=worker, args=(self.tasks, self.results, sys.path)
144
+ )
139
145
  self.proc.start()
140
146
  self.counter = 0 # Number of successful roundtrips
141
147
 
@@ -0,0 +1 @@
1
+ # Marker file for PEP 561. The mypy package uses inline types.
@@ -811,7 +811,7 @@ class SemanticAnalyzer(
811
811
  self.globals = file_node.names
812
812
  self.tvar_scope = TypeVarLikeScope()
813
813
 
814
- self.named_tuple_analyzer = NamedTupleAnalyzer(options, self)
814
+ self.named_tuple_analyzer = NamedTupleAnalyzer(options, self, self.msg)
815
815
  self.typed_dict_analyzer = TypedDictAnalyzer(options, self, self.msg)
816
816
  self.enum_call_analyzer = EnumCallAnalyzer(options, self)
817
817
  self.newtype_analyzer = NewTypeAnalyzer(options, self, self.msg)
@@ -3035,22 +3035,23 @@ class SemanticAnalyzer(
3035
3035
  if self.check_and_set_up_type_alias(s):
3036
3036
  s.is_alias_def = True
3037
3037
  special_form = True
3038
- # * type variable definition
3039
- elif self.process_typevar_declaration(s):
3040
- special_form = True
3041
- elif self.process_paramspec_declaration(s):
3042
- special_form = True
3043
- elif self.process_typevartuple_declaration(s):
3044
- special_form = True
3045
- # * type constructors
3046
- elif self.analyze_namedtuple_assign(s):
3047
- special_form = True
3048
- elif self.analyze_typeddict_assign(s):
3049
- special_form = True
3050
- elif self.newtype_analyzer.process_newtype_declaration(s):
3051
- special_form = True
3052
- elif self.analyze_enum_assign(s):
3053
- special_form = True
3038
+ elif isinstance(s.rvalue, CallExpr):
3039
+ # * type variable definition
3040
+ if self.process_typevar_declaration(s):
3041
+ special_form = True
3042
+ elif self.process_paramspec_declaration(s):
3043
+ special_form = True
3044
+ elif self.process_typevartuple_declaration(s):
3045
+ special_form = True
3046
+ # * type constructors
3047
+ elif self.analyze_namedtuple_assign(s):
3048
+ special_form = True
3049
+ elif self.analyze_typeddict_assign(s):
3050
+ special_form = True
3051
+ elif self.newtype_analyzer.process_newtype_declaration(s):
3052
+ special_form = True
3053
+ elif self.analyze_enum_assign(s):
3054
+ special_form = True
3054
3055
 
3055
3056
  if special_form:
3056
3057
  self.record_special_form_lvalue(s)
@@ -106,16 +106,19 @@ class EnumCallAnalyzer:
106
106
  items, values, ok = self.parse_enum_call_args(call, fullname.split(".")[-1])
107
107
  if not ok:
108
108
  # Error. Construct dummy return value.
109
- info = self.build_enum_call_typeinfo(var_name, [], fullname, node.line)
109
+ name = var_name
110
+ if is_func_scope:
111
+ name += "@" + str(call.line)
112
+ info = self.build_enum_call_typeinfo(name, [], fullname, node.line)
110
113
  else:
111
114
  name = cast(StrExpr, call.args[0]).value
112
115
  if name != var_name or is_func_scope:
113
116
  # Give it a unique name derived from the line number.
114
117
  name += "@" + str(call.line)
115
118
  info = self.build_enum_call_typeinfo(name, items, fullname, call.line)
116
- # Store generated TypeInfo under both names, see semanal_namedtuple for more details.
117
- if name != var_name or is_func_scope:
118
- self.api.add_symbol_skip_local(name, info)
119
+ # Store generated TypeInfo under both names, see semanal_namedtuple for more details.
120
+ if name != var_name or is_func_scope:
121
+ self.api.add_symbol_skip_local(name, info)
119
122
  call.analyzed = EnumCallExpr(info, items, values)
120
123
  call.analyzed.set_line(call)
121
124
  info.line = node.line
@@ -9,6 +9,7 @@ from contextlib import contextmanager
9
9
  from typing import Final, Iterator, List, Mapping, cast
10
10
 
11
11
  from mypy.exprtotype import TypeTranslationError, expr_to_unanalyzed_type
12
+ from mypy.messages import MessageBuilder
12
13
  from mypy.nodes import (
13
14
  ARG_NAMED_OPT,
14
15
  ARG_OPT,
@@ -92,9 +93,12 @@ SELF_TVAR_NAME: Final = "_NT"
92
93
 
93
94
 
94
95
  class NamedTupleAnalyzer:
95
- def __init__(self, options: Options, api: SemanticAnalyzerInterface) -> None:
96
+ def __init__(
97
+ self, options: Options, api: SemanticAnalyzerInterface, msg: MessageBuilder
98
+ ) -> None:
96
99
  self.options = options
97
100
  self.api = api
101
+ self.msg = msg
98
102
 
99
103
  def analyze_namedtuple_classdef(
100
104
  self, defn: ClassDef, is_stub_file: bool, is_func_scope: bool
@@ -212,6 +216,12 @@ class NamedTupleAnalyzer:
212
216
  )
213
217
  else:
214
218
  default_items[name] = stmt.rvalue
219
+ if defn.keywords:
220
+ for_function = ' for "__init_subclass__" of "NamedTuple"'
221
+ for key in defn.keywords:
222
+ self.msg.unexpected_keyword_argument_for_function(
223
+ for_function, key, defn
224
+ )
215
225
  return items, types, default_items, statements
216
226
 
217
227
  def check_namedtuple(
@@ -349,6 +349,14 @@ class TypedDictAnalyzer:
349
349
  total = require_bool_literal_argument(
350
350
  self.api, defn.keywords["total"], "total", True
351
351
  )
352
+ if defn.keywords and defn.keywords.keys() != {"total"}:
353
+ for_function = ' for "__init_subclass__" of "TypedDict"'
354
+ for key in defn.keywords:
355
+ if key == "total":
356
+ continue
357
+ self.msg.unexpected_keyword_argument_for_function(
358
+ for_function, key, defn
359
+ )
352
360
  required_keys = {
353
361
  field
354
362
  for (field, t) in zip(fields, types)
@@ -420,6 +428,23 @@ class TypedDictAnalyzer:
420
428
  types = [ # unwrap Required[T] to just T
421
429
  t.item if isinstance(t, RequiredType) else t for t in types
422
430
  ]
431
+
432
+ # Perform various validations after unwrapping.
433
+ for t in types:
434
+ check_for_explicit_any(
435
+ t,
436
+ self.options,
437
+ self.api.is_typeshed_stub_file,
438
+ self.msg,
439
+ context=call,
440
+ )
441
+ if self.options.disallow_any_unimported:
442
+ for t in types:
443
+ if has_any_from_unimported_type(t):
444
+ self.msg.unimported_type_becomes_any(
445
+ "Type of a TypedDict key", t, call
446
+ )
447
+
423
448
  existing_info = None
424
449
  if isinstance(node.analyzed, TypedDictExpr):
425
450
  existing_info = node.analyzed.info
@@ -478,17 +503,6 @@ class TypedDictAnalyzer:
478
503
  # One of the types is not ready, defer.
479
504
  return None
480
505
  items, types, ok = res
481
- for t in types:
482
- check_for_explicit_any(
483
- t, self.options, self.api.is_typeshed_stub_file, self.msg, context=call
484
- )
485
-
486
- if self.options.disallow_any_unimported:
487
- for t in types:
488
- if has_any_from_unimported_type(t):
489
- self.msg.unimported_type_becomes_any(
490
- "Type of a TypedDict key", t, dictexpr
491
- )
492
506
  assert total is not None
493
507
  return args[0].value, items, types, total, tvar_defs, ok
494
508
 
@@ -44,11 +44,19 @@ def is_valid_type(s: str) -> bool:
44
44
  class ArgSig:
45
45
  """Signature info for a single argument."""
46
46
 
47
- def __init__(self, name: str, type: str | None = None, default: bool = False):
47
+ def __init__(
48
+ self,
49
+ name: str,
50
+ type: str | None = None,
51
+ *,
52
+ default: bool = False,
53
+ default_value: str = "...",
54
+ ) -> None:
48
55
  self.name = name
49
56
  self.type = type
50
57
  # Does this argument have a default value?
51
58
  self.default = default
59
+ self.default_value = default_value
52
60
 
53
61
  def is_star_arg(self) -> bool:
54
62
  return self.name.startswith("*") and not self.name.startswith("**")
@@ -67,6 +75,7 @@ class ArgSig:
67
75
  self.name == other.name
68
76
  and self.type == other.type
69
77
  and self.default == other.default
78
+ and self.default_value == other.default_value
70
79
  )
71
80
  return False
72
81
 
@@ -127,10 +136,10 @@ class FunctionSig(NamedTuple):
127
136
  if arg_type:
128
137
  arg_def += ": " + arg_type
129
138
  if arg.default:
130
- arg_def += " = ..."
139
+ arg_def += f" = {arg.default_value}"
131
140
 
132
141
  elif arg.default:
133
- arg_def += "=..."
142
+ arg_def += f"={arg.default_value}"
134
143
 
135
144
  args.append(arg_def)
136
145
 
@@ -400,7 +409,12 @@ def infer_ret_type_sig_from_docstring(docstr: str, name: str) -> str | None:
400
409
 
401
410
  def infer_ret_type_sig_from_anon_docstring(docstr: str) -> str | None:
402
411
  """Convert signature in form of "(self: TestClass, arg0) -> int" to their return type."""
403
- return infer_ret_type_sig_from_docstring("stub" + docstr.strip(), "stub")
412
+ lines = [
413
+ "stub" + line.strip()
414
+ for line in docstr.splitlines()
415
+ if line.strip().startswith("(")
416
+ ]
417
+ return infer_ret_type_sig_from_docstring("".join(lines), "stub")
404
418
 
405
419
 
406
420
  def parse_signature(sig: str) -> tuple[str, list[str], list[str]] | None:
@@ -99,6 +99,7 @@ from mypy.nodes import (
99
99
  NameExpr,
100
100
  OpExpr,
101
101
  OverloadedFuncDef,
102
+ SetExpr,
102
103
  Statement,
103
104
  StrExpr,
104
105
  TempNode,
@@ -499,15 +500,28 @@ class ASTStubGenerator(BaseStubGenerator, mypy.traverser.TraverserVisitor):
499
500
  if kind.is_named() and not any(arg.name.startswith("*") for arg in args):
500
501
  args.append(ArgSig("*"))
501
502
 
503
+ default = "..."
502
504
  if arg_.initializer:
503
505
  if not typename:
504
506
  typename = self.get_str_type_of_node(arg_.initializer, True, False)
507
+ potential_default, valid = self.get_str_default_of_node(
508
+ arg_.initializer
509
+ )
510
+ if valid and len(potential_default) <= 200:
511
+ default = potential_default
505
512
  elif kind == ARG_STAR:
506
513
  name = f"*{name}"
507
514
  elif kind == ARG_STAR2:
508
515
  name = f"**{name}"
509
516
 
510
- args.append(ArgSig(name, typename, default=bool(arg_.initializer)))
517
+ args.append(
518
+ ArgSig(
519
+ name,
520
+ typename,
521
+ default=bool(arg_.initializer),
522
+ default_value=default,
523
+ )
524
+ )
511
525
 
512
526
  if ctx.class_info is not None and all(
513
527
  arg.type is None and arg.default is False for arg in args
@@ -1282,6 +1296,70 @@ class ASTStubGenerator(BaseStubGenerator, mypy.traverser.TraverserVisitor):
1282
1296
  # This is some other unary expr, we cannot do anything with it (yet?).
1283
1297
  return expr
1284
1298
 
1299
+ def get_str_default_of_node(self, rvalue: Expression) -> tuple[str, bool]:
1300
+ """Get a string representation of the default value of a node.
1301
+
1302
+ Returns a 2-tuple of the default and whether or not it is valid.
1303
+ """
1304
+ if isinstance(rvalue, NameExpr):
1305
+ if rvalue.name in ("None", "True", "False"):
1306
+ return rvalue.name, True
1307
+ elif isinstance(rvalue, (IntExpr, FloatExpr)):
1308
+ return f"{rvalue.value}", True
1309
+ elif isinstance(rvalue, UnaryExpr):
1310
+ if isinstance(rvalue.expr, (IntExpr, FloatExpr)):
1311
+ return f"{rvalue.op}{rvalue.expr.value}", True
1312
+ elif isinstance(rvalue, StrExpr):
1313
+ return repr(rvalue.value), True
1314
+ elif isinstance(rvalue, BytesExpr):
1315
+ return "b" + repr(rvalue.value).replace("\\\\", "\\"), True
1316
+ elif isinstance(rvalue, TupleExpr):
1317
+ items_defaults = []
1318
+ for e in rvalue.items:
1319
+ e_default, valid = self.get_str_default_of_node(e)
1320
+ if not valid:
1321
+ break
1322
+ items_defaults.append(e_default)
1323
+ else:
1324
+ closing = ",)" if len(items_defaults) == 1 else ")"
1325
+ default = "(" + ", ".join(items_defaults) + closing
1326
+ return default, True
1327
+ elif isinstance(rvalue, ListExpr):
1328
+ items_defaults = []
1329
+ for e in rvalue.items:
1330
+ e_default, valid = self.get_str_default_of_node(e)
1331
+ if not valid:
1332
+ break
1333
+ items_defaults.append(e_default)
1334
+ else:
1335
+ default = "[" + ", ".join(items_defaults) + "]"
1336
+ return default, True
1337
+ elif isinstance(rvalue, SetExpr):
1338
+ items_defaults = []
1339
+ for e in rvalue.items:
1340
+ e_default, valid = self.get_str_default_of_node(e)
1341
+ if not valid:
1342
+ break
1343
+ items_defaults.append(e_default)
1344
+ else:
1345
+ if items_defaults:
1346
+ default = "{" + ", ".join(items_defaults) + "}"
1347
+ return default, True
1348
+ elif isinstance(rvalue, DictExpr):
1349
+ items_defaults = []
1350
+ for k, v in rvalue.items:
1351
+ if k is None:
1352
+ break
1353
+ k_default, k_valid = self.get_str_default_of_node(k)
1354
+ v_default, v_valid = self.get_str_default_of_node(v)
1355
+ if not (k_valid and v_valid):
1356
+ break
1357
+ items_defaults.append(f"{k_default}: {v_default}")
1358
+ else:
1359
+ default = "{" + ", ".join(items_defaults) + "}"
1360
+ return default, True
1361
+ return "...", False
1362
+
1285
1363
  def should_reexport(self, name: str, full_module: str, name_is_alias: bool) -> bool:
1286
1364
  is_private = self.is_private_name(name, full_module + "." + name)
1287
1365
  if (
@@ -1702,6 +1780,7 @@ def generate_stubs(options: Options) -> None:
1702
1780
  doc_dir=options.doc_dir,
1703
1781
  include_private=options.include_private,
1704
1782
  export_less=options.export_less,
1783
+ include_docstrings=options.include_docstrings,
1705
1784
  )
1706
1785
  num_modules = len(all_modules)
1707
1786
  if not options.quiet and num_modules > 0:
@@ -136,10 +136,12 @@ class DocstringSignatureGenerator(SignatureGenerator):
136
136
  """Infer property type from docstring or docstring signature."""
137
137
  if ctx.docstring is not None:
138
138
  inferred = infer_ret_type_sig_from_anon_docstring(ctx.docstring)
139
- if not inferred:
140
- inferred = infer_ret_type_sig_from_docstring(ctx.docstring, ctx.name)
141
- if not inferred:
142
- inferred = infer_prop_type_from_docstring(ctx.docstring)
139
+ if inferred:
140
+ return inferred
141
+ inferred = infer_ret_type_sig_from_docstring(ctx.docstring, ctx.name)
142
+ if inferred:
143
+ return inferred
144
+ inferred = infer_prop_type_from_docstring(ctx.docstring)
143
145
  return inferred
144
146
  else:
145
147
  return None
@@ -249,6 +251,26 @@ class InspectionStubGenerator(BaseStubGenerator):
249
251
  self.resort_members = self.is_c_module
250
252
  super().__init__(_all_, include_private, export_less, include_docstrings)
251
253
  self.module_name = module_name
254
+ if self.is_c_module:
255
+ # Add additional implicit imports.
256
+ # C-extensions are given more lattitude since they do not import the typing module.
257
+ self.known_imports.update(
258
+ {
259
+ "typing": [
260
+ "Any",
261
+ "Callable",
262
+ "ClassVar",
263
+ "Dict",
264
+ "Iterable",
265
+ "Iterator",
266
+ "List",
267
+ "NamedTuple",
268
+ "Optional",
269
+ "Tuple",
270
+ "Union",
271
+ ]
272
+ }
273
+ )
252
274
 
253
275
  def get_default_function_sig(
254
276
  self, func: object, ctx: FunctionContext
@@ -617,11 +639,31 @@ class InspectionStubGenerator(BaseStubGenerator):
617
639
  if inferred[0].args and inferred[0].args[0].name == "cls":
618
640
  decorators.append("@classmethod")
619
641
 
642
+ if docstring:
643
+ docstring = self._indent_docstring(docstring)
620
644
  output.extend(
621
645
  self.format_func_def(inferred, decorators=decorators, docstring=docstring)
622
646
  )
623
647
  self._fix_iter(ctx, inferred, output)
624
648
 
649
+ def _indent_docstring(self, docstring: str) -> str:
650
+ """Fix indentation of docstring extracted from pybind11 or other binding generators."""
651
+ lines = docstring.splitlines(keepends=True)
652
+ indent = self._indent + " "
653
+ if len(lines) > 1:
654
+ if not all(line.startswith(indent) or not line.strip() for line in lines):
655
+ # if the docstring is not indented, then indent all but the first line
656
+ for i, line in enumerate(lines[1:]):
657
+ if line.strip():
658
+ lines[i + 1] = indent + line
659
+ # if there's a trailing newline, add a final line to visually indent the quoted docstring
660
+ if lines[-1].endswith("\n"):
661
+ if len(lines) > 1:
662
+ lines.append(indent)
663
+ else:
664
+ lines[-1] = lines[-1][:-1]
665
+ return "".join(lines)
666
+
625
667
  def _fix_iter(
626
668
  self, ctx: FunctionContext, inferred: list[FunctionSig], output: list[str]
627
669
  ) -> None:
@@ -671,7 +713,7 @@ class InspectionStubGenerator(BaseStubGenerator):
671
713
  if fget:
672
714
  alt_docstr = getattr(fget, "__doc__", None)
673
715
  if alt_docstr and docstring:
674
- docstring += alt_docstr
716
+ docstring += "\n" + alt_docstr
675
717
  elif alt_docstr:
676
718
  docstring = alt_docstr
677
719