literal.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. import ast
  2. from pprint import PrettyPrinter
  3. from typing import Any, Callable, Dict, List, Set, Tuple
  4. from isort.exceptions import (
  5. AssignmentsFormatMismatch,
  6. LiteralParsingFailure,
  7. LiteralSortTypeMismatch,
  8. )
  9. from isort.settings import DEFAULT_CONFIG, Config
  10. class ISortPrettyPrinter(PrettyPrinter):
  11. """an isort customized pretty printer for sorted literals"""
  12. def __init__(self, config: Config):
  13. super().__init__(width=config.line_length, compact=True)
  14. type_mapping: Dict[str, Tuple[type, Callable[[Any, ISortPrettyPrinter], str]]] = {}
  15. def assignments(code: str) -> str:
  16. values = {}
  17. for line in code.splitlines(keepends=True):
  18. if not line.strip():
  19. continue
  20. if " = " not in line:
  21. raise AssignmentsFormatMismatch(code)
  22. variable_name, value = line.split(" = ", 1)
  23. values[variable_name] = value
  24. return "".join(
  25. f"{variable_name} = {values[variable_name]}" for variable_name in sorted(values.keys())
  26. )
  27. def assignment(code: str, sort_type: str, extension: str, config: Config = DEFAULT_CONFIG) -> str:
  28. """Sorts the literal present within the provided code against the provided sort type,
  29. returning the sorted representation of the source code.
  30. """
  31. if sort_type == "assignments":
  32. return assignments(code)
  33. if sort_type not in type_mapping:
  34. raise ValueError(
  35. "Trying to sort using an undefined sort_type. "
  36. f"Defined sort types are {', '.join(type_mapping.keys())}."
  37. )
  38. variable_name, literal = code.split(" = ")
  39. variable_name = variable_name.lstrip()
  40. try:
  41. value = ast.literal_eval(literal)
  42. except Exception as error:
  43. raise LiteralParsingFailure(code, error)
  44. expected_type, sort_function = type_mapping[sort_type]
  45. if type(value) != expected_type:
  46. raise LiteralSortTypeMismatch(type(value), expected_type)
  47. printer = ISortPrettyPrinter(config)
  48. sorted_value_code = f"{variable_name} = {sort_function(value, printer)}"
  49. if config.formatting_function:
  50. sorted_value_code = config.formatting_function(
  51. sorted_value_code, extension, config
  52. ).rstrip()
  53. sorted_value_code += code[len(code.rstrip()) :]
  54. return sorted_value_code
  55. def register_type(
  56. name: str, kind: type
  57. ) -> Callable[[Callable[[Any, ISortPrettyPrinter], str]], Callable[[Any, ISortPrettyPrinter], str]]:
  58. """Registers a new literal sort type."""
  59. def wrap(
  60. function: Callable[[Any, ISortPrettyPrinter], str]
  61. ) -> Callable[[Any, ISortPrettyPrinter], str]:
  62. type_mapping[name] = (kind, function)
  63. return function
  64. return wrap
  65. @register_type("dict", dict)
  66. def _dict(value: Dict[Any, Any], printer: ISortPrettyPrinter) -> str:
  67. return printer.pformat(dict(sorted(value.items(), key=lambda item: item[1]))) # type: ignore
  68. @register_type("list", list)
  69. def _list(value: List[Any], printer: ISortPrettyPrinter) -> str:
  70. return printer.pformat(sorted(value))
  71. @register_type("unique-list", list)
  72. def _unique_list(value: List[Any], printer: ISortPrettyPrinter) -> str:
  73. return printer.pformat(list(sorted(set(value))))
  74. @register_type("set", set)
  75. def _set(value: Set[Any], printer: ISortPrettyPrinter) -> str:
  76. return "{" + printer.pformat(tuple(sorted(value)))[1:-1] + "}"
  77. @register_type("tuple", tuple)
  78. def _tuple(value: Tuple[Any, ...], printer: ISortPrettyPrinter) -> str:
  79. return printer.pformat(tuple(sorted(value)))
  80. @register_type("unique-tuple", tuple)
  81. def _unique_tuple(value: Tuple[Any, ...], printer: ISortPrettyPrinter) -> str:
  82. return printer.pformat(tuple(sorted(set(value))))