# output.py (26 KB)
  1. import copy
  2. import itertools
  3. from functools import partial
  4. from typing import Any, Iterable, List, Optional, Set, Tuple, Type
  5. from isort.format import format_simplified
  6. from . import parse, sorting, wrap
  7. from .comments import add_to_line as with_comments
  8. from .identify import STATEMENT_DECLARATIONS
  9. from .settings import DEFAULT_CONFIG, Config
def sorted_imports(
    parsed: parse.ParsedContent,
    config: Config = DEFAULT_CONFIG,
    extension: str = "py",
    import_type: str = "import",
) -> str:
    """Return the file content with its imports re-inserted, sorted and grouped.

    The imports held in ``parsed.imports`` are rendered section by section and
    spliced back into a copy of ``parsed.lines_without_imports`` at
    ``parsed.import_index``.

    Args:
        parsed: Parsed representation of the file. ``parsed.imports`` and
            ``parsed.place_imports`` are written to by this function.
        config: Settings that drive sorting, spacing, headings and placement.
        extension: File extension; forwarded to ``config.formatting_function``
            and compared against ``"pyi"`` when choosing the spacing after imports.
        import_type: Keyword forwarded to the statement builders
            (``_with_straight_imports`` / ``_with_from_imports``).

    Returns:
        The resulting lines joined with ``parsed.line_separator``.
    """
    # No import was found: hand back the remaining lines as they are.
    # Note the list is passed directly here, not copied.
    if parsed.import_index == -1:
        return _output_as_string(parsed.lines_without_imports, parsed.line_separator)

    formatted_output: List[str] = parsed.lines_without_imports.copy()
    remove_imports = [format_simplified(removal) for removal in config.remove_imports]

    sections: Iterable[str] = itertools.chain(parsed.sections, config.forced_separate)

    if config.no_sections:
        # Merge every section except FUTURE into one synthetic "no_sections" bucket;
        # FUTURE (when present) stays a separate section placed first.
        parsed.imports["no_sections"] = {"straight": {}, "from": {}}
        base_sections: Tuple[str, ...] = ()
        for section in sections:
            if section == "FUTURE":
                base_sections = ("FUTURE",)
                continue
            parsed.imports["no_sections"]["straight"].update(
                parsed.imports[section].get("straight", {})
            )
            parsed.imports["no_sections"]["from"].update(parsed.imports[section].get("from", {}))
        sections = base_sections + ("no_sections",)

    output: List[str] = []
    # Heading/footer titles already emitted; only populated when config.dedup_headings is set.
    seen_headings: Set[str] = set()
    # True when an empty section that wanted separator lines was skipped, so the
    # next non-empty section still gets them.
    pending_lines_before = False
    for section in sections:
        straight_modules = parsed.imports[section]["straight"]
        if not config.only_sections:
            straight_modules = sorting.sort(
                config,
                straight_modules,
                key=lambda key: sorting.module_key(
                    key, config, section_name=section, straight_import=True
                ),
                reverse=config.reverse_sort,
            )

        from_modules = parsed.imports[section]["from"]
        if not config.only_sections:
            from_modules = sorting.sort(
                config,
                from_modules,
                key=lambda key: sorting.module_key(key, config, section_name=section),
                reverse=config.reverse_sort,
            )

            if config.star_first:
                # Stable partition: modules with a "*" import first, the rest after.
                star_modules = []
                other_modules = []
                for module in from_modules:
                    if "*" in parsed.imports[section]["from"][module]:
                        star_modules.append(module)
                    else:
                        other_modules.append(module)
                from_modules = star_modules + other_modules

        straight_imports = _with_straight_imports(
            parsed, config, straight_modules, section, remove_imports, import_type
        )
        from_imports = _with_from_imports(
            parsed, config, from_modules, section, remove_imports, import_type
        )

        # Blank lines between the two statement kinds, only when both kinds exist.
        lines_between = [""] * (
            config.lines_between_types if from_modules and straight_modules else 0
        )
        if config.from_first:
            section_output = from_imports + lines_between + straight_imports
        else:
            section_output = straight_imports + lines_between + from_imports

        if config.force_sort_within_sections:
            # collapse comments: attach each run of "#" lines to the statement that
            # follows it so the pair moves together while sorting; blank lines are dropped
            comments_above = []
            new_section_output: List[str] = []
            for line in section_output:
                if not line:
                    continue
                if line.startswith("#"):
                    comments_above.append(line)
                elif comments_above:
                    new_section_output.append(_LineWithComments(line, comments_above))
                    comments_above = []
                else:
                    new_section_output.append(line)
            # only_sections options is not imposed if force_sort_within_sections is True
            new_section_output = sorting.sort(
                config,
                new_section_output,
                key=partial(sorting.section_key, config=config),
                reverse=config.reverse_sort,
            )

            # uncollapse comments: re-emit the attached comment lines before their statement
            section_output = []
            for line in new_section_output:
                comments = getattr(line, "comments", ())
                if comments:
                    section_output.extend(comments)
                section_output.append(str(line))

        section_name = section
        no_lines_before = section_name in config.no_lines_before

        if section_output:
            if section_name in parsed.place_imports:
                # This section is placed elsewhere in the file (handled at the end);
                # store its lines and keep it out of the main import block.
                parsed.place_imports[section_name] = section_output
                continue

            section_title = config.import_headings.get(section_name.lower(), "")
            if section_title and section_title not in seen_headings:
                if config.dedup_headings:
                    seen_headings.add(section_title)
                section_comment = f"# {section_title}"
                # Skip the heading if it is already the first remaining line of the file.
                if section_comment not in parsed.lines_without_imports[0:1]:  # pragma: no branch
                    section_output.insert(0, section_comment)

            section_footer = config.import_footers.get(section_name.lower(), "")
            # Footers share the same "seen" set as headings.
            if section_footer and section_footer not in seen_headings:
                if config.dedup_headings:
                    seen_headings.add(section_footer)
                section_comment_end = f"# {section_footer}"
                if (
                    section_comment_end not in parsed.lines_without_imports[-1:]
                ):  # pragma: no branch
                    section_output.append("")  # Empty line for black compatibility
                    section_output.append(section_comment_end)

            if pending_lines_before or not no_lines_before:
                output += [""] * config.lines_between_sections

            output += section_output

            pending_lines_before = False
        else:
            pending_lines_before = pending_lines_before or not no_lines_before

    if config.ensure_newline_before_comments:
        output = _ensure_newline_before_comment(output)

    # Trim blank lines from both ends of the rendered import block.
    while output and output[-1].strip() == "":
        output.pop()  # pragma: no cover
    while output and output[0].strip() == "":
        output.pop(0)

    if config.formatting_function:
        # Let the configured formatter post-process the whole block as one string.
        output = config.formatting_function(
            parsed.line_separator.join(output), extension, config
        ).splitlines()

    # Insert at the recorded import index, or at the top when that index is not
    # below the original line count.
    output_at = 0
    if parsed.import_index < parsed.original_line_count:
        output_at = parsed.import_index
    formatted_output[output_at:0] = output

    if output:
        imports_tail = output_at + len(output)
        # Remove every blank line directly after the import block; the desired
        # number of blank lines is re-inserted below.
        while [
            character.strip() for character in formatted_output[imports_tail : imports_tail + 1]
        ] == [""]:
            formatted_output.pop(imports_tail)

        if len(formatted_output) > imports_tail:
            # Find the first construct after the imports to decide the spacing.
            next_construct = ""
            tail = formatted_output[imports_tail:]

            for index, line in enumerate(tail):  # pragma: no branch
                should_skip, in_quote, *_ = parse.skip_line(
                    line,
                    in_quote="",
                    index=len(formatted_output),
                    section_comments=config.section_comments,
                    needs_import=False,
                )
                if not should_skip and line.strip():
                    # A comment directly followed by a non-blank line is skipped so
                    # the line after it is examined instead.
                    if (
                        line.strip().startswith("#")
                        and len(tail) > (index + 1)
                        and tail[index + 1].strip()
                    ):
                        continue
                    next_construct = line
                    break
                if in_quote:  # pragma: no branch
                    next_construct = line
                    break

            if config.lines_after_imports != -1:
                # An explicit setting wins over the heuristics below.
                formatted_output[imports_tail:0] = [
                    "" for line in range(config.lines_after_imports)
                ]
            elif extension != "pyi" and next_construct.startswith(STATEMENT_DECLARATIONS):
                formatted_output[imports_tail:0] = ["", ""]
            else:
                formatted_output[imports_tail:0] = [""]

            if config.lines_before_imports != -1:
                # Note: these blank lines are prepended at the very start of the output.
                formatted_output[:0] = ["" for line in range(config.lines_before_imports)]

    if parsed.place_imports:
        # Expand placement markers: after each marker line, insert the lines stored
        # for its section, followed by a blank line unless one already follows.
        new_out_lines = []
        for index, line in enumerate(formatted_output):
            new_out_lines.append(line)
            if line in parsed.import_placements:
                new_out_lines.extend(parsed.place_imports[parsed.import_placements[line]])
                if (
                    len(formatted_output) <= (index + 1)
                    or formatted_output[index + 1].strip() != ""
                ):
                    new_out_lines.append("")
        formatted_output = new_out_lines

    return _output_as_string(formatted_output, parsed.line_separator)
def _with_from_imports(
    parsed: parse.ParsedContent,
    config: Config,
    from_modules: Iterable[str],
    section: str,
    remove_imports: List[str],
    import_type: str,
) -> List[str]:
    """Render the ``from <module> <import_type> ...`` statements of one section.

    For every module in *from_modules* that is not listed in *remove_imports*,
    the imported names are optionally sorted, filtered against *remove_imports*,
    paired with their ``as`` aliases, and emitted in one of three shapes: a
    combined star import, one name per statement, or a multi-name statement
    that is wrapped when needed.

    Args:
        parsed: Parsed file content; supplies ``imports``, ``as_map``,
            ``categorized_comments`` and ``line_separator``. Comment entries
            are popped from ``categorized_comments`` as they are used.
        config: Settings controlling sorting, combining and wrapping.
        from_modules: Module names to render, in output order.
        section: Key into ``parsed.imports`` for the section being rendered.
        remove_imports: Names to drop, either a whole module or
            ``"<module>.<name>"``.
        import_type: Keyword placed between the module and the imported names.

    Returns:
        The rendered statements and comment lines, in order. A wrapped
        statement is a single entry that may contain ``parsed.line_separator``.
    """
    output: List[str] = []
    for module in from_modules:
        if module in remove_imports:
            continue

        import_start = f"from {module} {import_type} "
        from_imports = list(parsed.imports[section]["from"][module])

        # Sort the imported names unless inline sorting is disabled; forced
        # single-line output re-enables sorting, and only_sections disables it.
        if (
            not config.no_inline_sort
            or (config.force_single_line and module not in config.single_line_exclusions)
        ) and not config.only_sections:
            from_imports = sorting.sort(
                config,
                from_imports,
                key=lambda key: sorting.module_key(
                    key,
                    config,
                    True,
                    config.force_alphabetical_sort_within_sections,
                    section_name=section,
                ),
                reverse=config.reverse_sort,
            )
        if remove_imports:
            from_imports = [
                line for line in from_imports if f"{module}.{line}" not in remove_imports
            ]

        # Map each imported name that has aliases to its "name as alias" strings.
        sub_modules = [f"{module}.{from_import}" for from_import in from_imports]
        as_imports = {
            from_import: [
                f"{from_import} as {as_module}" for as_module in parsed.as_map["from"][sub_module]
            ]
            for from_import, sub_module in zip(from_imports, sub_modules)
            if sub_module in parsed.as_map["from"]
        }
        if config.combine_as_imports and not ("*" in from_imports and config.combine_star):
            if not config.no_inline_sort:
                for as_import in as_imports:
                    if not config.only_sections:
                        as_imports[as_import] = sorting.sort(config, as_imports[as_import])
            # Fold the alias strings into the name list: after the plain name when
            # its entry in parsed.imports is truthy, otherwise in place of it.
            # Folded entries are removed from as_imports.
            for from_import in copy.copy(from_imports):
                if from_import in as_imports:
                    idx = from_imports.index(from_import)
                    if parsed.imports[section]["from"][module][from_import]:
                        from_imports[(idx + 1) : (idx + 1)] = as_imports.pop(from_import)
                    else:
                        from_imports[idx : (idx + 1)] = as_imports.pop(from_import)

        only_show_as_imports = False
        comments = parsed.categorized_comments["from"].pop(module, ())
        above_comments = parsed.categorized_comments["above"]["from"].pop(module, None)
        # Each pass consumes part of from_imports and emits at most one import_statement.
        while from_imports:
            if above_comments:
                # Emit the comments above this module once, before its first statement.
                output.extend(above_comments)
                above_comments = None

            if "*" in from_imports and config.combine_star:
                # A star import replaces every plain name; only aliased names remain
                # to be rendered on the following passes.
                import_statement = wrap.line(
                    with_comments(
                        _with_star_comments(parsed, module, list(comments or ())),
                        f"{import_start}*",
                        removed=config.ignore_comments,
                        comment_prefix=config.comment_prefix,
                    ),
                    parsed.line_separator,
                    config,
                )
                from_imports = [
                    from_import for from_import in from_imports if from_import in as_imports
                ]
                only_show_as_imports = True
            elif config.force_single_line and module not in config.single_line_exclusions:
                # One statement per imported name; everything is appended to output
                # directly, so no combined import_statement is produced.
                import_statement = ""
                while from_imports:
                    from_import = from_imports.pop(0)
                    single_import_line = with_comments(
                        comments,
                        import_start + from_import,
                        removed=config.ignore_comments,
                        comment_prefix=config.comment_prefix,
                    )
                    comment = (
                        parsed.categorized_comments["nested"].get(module, {}).pop(from_import, None)
                    )
                    if comment:
                        # Join with ";" when module comments are present,
                        # otherwise start a new comment with the prefix.
                        single_import_line += (
                            f"{comments and ';' or config.comment_prefix} " f"{comment}"
                        )
                    if from_import in as_imports:
                        # Emit the plain form only when its entry in parsed.imports
                        # is truthy and a star import has not replaced plain names.
                        if (
                            parsed.imports[section]["from"][module][from_import]
                            and not only_show_as_imports
                        ):
                            output.append(
                                wrap.line(single_import_line, parsed.line_separator, config)
                            )
                        from_comments = parsed.categorized_comments["straight"].get(
                            f"{module}.{from_import}"
                        )

                        if not config.only_sections:
                            output.extend(
                                with_comments(
                                    from_comments,
                                    wrap.line(
                                        import_start + as_import, parsed.line_separator, config
                                    ),
                                    removed=config.ignore_comments,
                                    comment_prefix=config.comment_prefix,
                                )
                                for as_import in sorting.sort(config, as_imports[from_import])
                            )

                        else:
                            # only_sections: keep the aliases in their existing order.
                            output.extend(
                                with_comments(
                                    from_comments,
                                    wrap.line(
                                        import_start + as_import, parsed.line_separator, config
                                    ),
                                    removed=config.ignore_comments,
                                    comment_prefix=config.comment_prefix,
                                )
                                for as_import in as_imports[from_import]
                            )
                    else:
                        output.append(wrap.line(single_import_line, parsed.line_separator, config))
                    # Module-level comments are attached to the first statement only.
                    comments = None
            else:
                # Leading aliased names each get their own statement(s).
                while from_imports and from_imports[0] in as_imports:
                    from_import = from_imports.pop(0)

                    if not config.only_sections:
                        as_imports[from_import] = sorting.sort(config, as_imports[from_import])
                    from_comments = (
                        parsed.categorized_comments["straight"].get(f"{module}.{from_import}") or []
                    )
                    if (
                        parsed.imports[section]["from"][module][from_import]
                        and not only_show_as_imports
                    ):
                        specific_comment = (
                            parsed.categorized_comments["nested"]
                            .get(module, {})
                            .pop(from_import, None)
                        )
                        if specific_comment:
                            from_comments.append(specific_comment)
                        output.append(
                            wrap.line(
                                with_comments(
                                    from_comments,
                                    import_start + from_import,
                                    removed=config.ignore_comments,
                                    comment_prefix=config.comment_prefix,
                                ),
                                parsed.line_separator,
                                config,
                            )
                        )
                        # Comments were consumed by the plain statement.
                        from_comments = []

                    for as_import in as_imports[from_import]:
                        specific_comment = (
                            parsed.categorized_comments["nested"]
                            .get(module, {})
                            .pop(as_import, None)
                        )
                        if specific_comment:
                            from_comments.append(specific_comment)

                        output.append(
                            wrap.line(
                                with_comments(
                                    from_comments,
                                    import_start + as_import,
                                    removed=config.ignore_comments,
                                    comment_prefix=config.comment_prefix,
                                ),
                                parsed.line_separator,
                                config,
                            )
                        )

                        from_comments = []

                if "*" in from_imports:
                    # Star import without combine_star: emitted as its own statement.
                    output.append(
                        with_comments(
                            _with_star_comments(parsed, module, []),
                            f"{import_start}*",
                            removed=config.ignore_comments,
                            comment_prefix=config.comment_prefix,
                        )
                    )
                    from_imports.remove("*")

                # Names carrying their own nested comment are split out onto
                # individual statements so the comment stays with the name.
                for from_import in copy.copy(from_imports):
                    comment = (
                        parsed.categorized_comments["nested"].get(module, {}).pop(from_import, None)
                    )
                    if comment:
                        from_imports.remove(from_import)
                        # The module-level comments go on the last remaining name only.
                        if from_imports:
                            use_comments = []
                        else:
                            use_comments = comments
                            comments = None
                        single_import_line = with_comments(
                            use_comments,
                            import_start + from_import,
                            removed=config.ignore_comments,
                            comment_prefix=config.comment_prefix,
                        )
                        single_import_line += (
                            f"{use_comments and ';' or config.comment_prefix} " f"{comment}"
                        )
                        output.append(wrap.line(single_import_line, parsed.line_separator, config))

                # Collect the run of names that will share one statement.
                # NOTE(review): the condition below reads `from_import`, i.e. whatever
                # the preceding loops last assigned, not `from_imports[0]` — confirm
                # this is intended.
                from_import_section = []
                while from_imports and (
                    from_imports[0] not in as_imports
                    or (
                        config.combine_as_imports
                        and parsed.imports[section]["from"][module][from_import]
                    )
                ):
                    from_import_section.append(from_imports.pop(0))
                if config.combine_as_imports:
                    comments = (comments or []) + list(
                        parsed.categorized_comments["from"].pop(f"{module}.__combined_as__", ())
                    )
                import_statement = with_comments(
                    comments,
                    import_start + (", ").join(from_import_section),
                    removed=config.ignore_comments,
                    comment_prefix=config.comment_prefix,
                )
                if not from_import_section:
                    import_statement = ""

                do_multiline_reformat = False

                force_grid_wrap = config.force_grid_wrap
                if force_grid_wrap and len(from_import_section) >= force_grid_wrap:
                    do_multiline_reformat = True

                if len(import_statement) > config.line_length and len(from_import_section) > 1:
                    do_multiline_reformat = True

                # If line too long AND have imports AND we are
                # NOT using GRID or VERTICAL wrap modes
                if (
                    len(import_statement) > config.line_length
                    and len(from_import_section) > 0
                    and config.multi_line_output
                    not in (wrap.Modes.GRID, wrap.Modes.VERTICAL)  # type: ignore
                ):
                    do_multiline_reformat = True

                if do_multiline_reformat:
                    import_statement = wrap.import_statement(
                        import_start=import_start,
                        from_imports=from_import_section,
                        comments=comments,
                        line_separator=parsed.line_separator,
                        config=config,
                    )
                    if config.multi_line_output == wrap.Modes.GRID:  # type: ignore
                        # In GRID mode, fall back to VERTICAL_GRID when any wrapped
                        # line still exceeds the line length.
                        other_import_statement = wrap.import_statement(
                            import_start=import_start,
                            from_imports=from_import_section,
                            comments=comments,
                            line_separator=parsed.line_separator,
                            config=config,
                            multi_line_output=wrap.Modes.VERTICAL_GRID,  # type: ignore
                        )
                        if (
                            max(
                                len(import_line)
                                for import_line in import_statement.split(parsed.line_separator)
                            )
                            > config.line_length
                        ):
                            import_statement = other_import_statement
                if not do_multiline_reformat and len(import_statement) > config.line_length:
                    import_statement = wrap.line(import_statement, parsed.line_separator, config)

            if import_statement:
                output.append(import_statement)
    return output
  484. def _with_straight_imports(
  485. parsed: parse.ParsedContent,
  486. config: Config,
  487. straight_modules: Iterable[str],
  488. section: str,
  489. remove_imports: List[str],
  490. import_type: str,
  491. ) -> List[str]:
  492. output: List[str] = []
  493. as_imports = any((module in parsed.as_map["straight"] for module in straight_modules))
  494. # combine_straight_imports only works for bare imports, 'as' imports not included
  495. if config.combine_straight_imports and not as_imports:
  496. if not straight_modules:
  497. return []
  498. above_comments: List[str] = []
  499. inline_comments: List[str] = []
  500. for module in straight_modules:
  501. if module in parsed.categorized_comments["above"]["straight"]:
  502. above_comments.extend(parsed.categorized_comments["above"]["straight"].pop(module))
  503. if module in parsed.categorized_comments["straight"]:
  504. inline_comments.extend(parsed.categorized_comments["straight"][module])
  505. combined_straight_imports = ", ".join(straight_modules)
  506. if inline_comments:
  507. combined_inline_comments = " ".join(inline_comments)
  508. else:
  509. combined_inline_comments = ""
  510. output.extend(above_comments)
  511. if combined_inline_comments:
  512. output.append(
  513. f"{import_type} {combined_straight_imports} # {combined_inline_comments}"
  514. )
  515. else:
  516. output.append(f"{import_type} {combined_straight_imports}")
  517. return output
  518. for module in straight_modules:
  519. if module in remove_imports:
  520. continue
  521. import_definition = []
  522. if module in parsed.as_map["straight"]:
  523. if parsed.imports[section]["straight"][module]:
  524. import_definition.append((f"{import_type} {module}", module))
  525. import_definition.extend(
  526. (f"{import_type} {module} as {as_import}", f"{module} as {as_import}")
  527. for as_import in parsed.as_map["straight"][module]
  528. )
  529. else:
  530. import_definition.append((f"{import_type} {module}", module))
  531. comments_above = parsed.categorized_comments["above"]["straight"].pop(module, None)
  532. if comments_above:
  533. output.extend(comments_above)
  534. output.extend(
  535. with_comments(
  536. parsed.categorized_comments["straight"].get(imodule),
  537. idef,
  538. removed=config.ignore_comments,
  539. comment_prefix=config.comment_prefix,
  540. )
  541. for idef, imodule in import_definition
  542. )
  543. return output
  544. def _output_as_string(lines: List[str], line_separator: str) -> str:
  545. return line_separator.join(_normalize_empty_lines(lines))
  546. def _normalize_empty_lines(lines: List[str]) -> List[str]:
  547. while lines and lines[-1].strip() == "":
  548. lines.pop(-1)
  549. lines.append("")
  550. return lines
  551. class _LineWithComments(str):
  552. comments: List[str]
  553. def __new__(
  554. cls: Type["_LineWithComments"], value: Any, comments: List[str]
  555. ) -> "_LineWithComments":
  556. instance = super().__new__(cls, value)
  557. instance.comments = comments
  558. return instance
  559. def _ensure_newline_before_comment(output: List[str]) -> List[str]:
  560. new_output: List[str] = []
  561. def is_comment(line: Optional[str]) -> bool:
  562. return line.startswith("#") if line else False
  563. for line, prev_line in zip(output, [None] + output): # type: ignore
  564. if is_comment(line) and prev_line != "" and not is_comment(prev_line):
  565. new_output.append("")
  566. new_output.append(line)
  567. return new_output
  568. def _with_star_comments(parsed: parse.ParsedContent, module: str, comments: List[str]) -> List[str]:
  569. star_comment = parsed.categorized_comments["nested"].get(module, {}).pop("*", None)
  570. if star_comment:
  571. return comments + [star_comment]
  572. return comments