writer.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. # Copyright (c) 2008-2010, 2013-2014 LOGILAB S.A. (Paris, FRANCE) <contact@logilab.fr>
  2. # Copyright (c) 2014 Arun Persaud <arun@nubati.net>
  3. # Copyright (c) 2015-2018, 2020 Claudiu Popa <pcmanticore@gmail.com>
  4. # Copyright (c) 2015 Mike Frysinger <vapier@gentoo.org>
  5. # Copyright (c) 2015 Florian Bruhin <me@the-compiler.org>
  6. # Copyright (c) 2015 Ionel Cristian Maries <contact@ionelmc.ro>
  7. # Copyright (c) 2018, 2020 Anthony Sottile <asottile@umich.edu>
  8. # Copyright (c) 2018 ssolanki <sushobhitsolanki@gmail.com>
  9. # Copyright (c) 2019-2021 Pierre Sassoulas <pierre.sassoulas@gmail.com>
  10. # Copyright (c) 2019 Kylian <development@goudcode.nl>
  11. # Copyright (c) 2021 Marc Mueller <30130371+cdce8p@users.noreply.github.com>
  12. # Copyright (c) 2021 Andreas Finkler <andi.finkler@gmail.com>
  13. # Copyright (c) 2021 Daniël van Noord <13665637+DanielNoord@users.noreply.github.com>
  14. # Copyright (c) 2021 Mark Byrne <31762852+mbyrnepr2@users.noreply.github.com>
  15. # Licensed under the GPL: https://www.gnu.org/licenses/old-licenses/gpl-2.0.html
  16. # For details: https://github.com/PyCQA/pylint/blob/main/LICENSE
  17. """Utilities for creating VCG and Dot diagrams"""
  18. import itertools
  19. import os
  20. from astroid import modutils, nodes
  21. from pylint.pyreverse.diagrams import (
  22. ClassDiagram,
  23. ClassEntity,
  24. DiagramEntity,
  25. PackageDiagram,
  26. PackageEntity,
  27. )
  28. from pylint.pyreverse.printer import EdgeType, NodeProperties, NodeType
  29. from pylint.pyreverse.printer_factory import get_printer_for_filetype
  30. from pylint.pyreverse.utils import is_exception
  31. class DiagramWriter:
  32. """base class for writing project diagrams"""
  33. def __init__(self, config):
  34. self.config = config
  35. self.printer_class = get_printer_for_filetype(self.config.output_format)
  36. self.printer = None # defined in set_printer
  37. self.file_name = "" # defined in set_printer
  38. self.depth = self.config.max_color_depth
  39. self.available_colors = itertools.cycle(
  40. [
  41. "aliceblue",
  42. "antiquewhite",
  43. "aquamarine",
  44. "burlywood",
  45. "cadetblue",
  46. "chartreuse",
  47. "chocolate",
  48. "coral",
  49. "cornflowerblue",
  50. "cyan",
  51. "darkgoldenrod",
  52. "darkseagreen",
  53. "dodgerblue",
  54. "forestgreen",
  55. "gold",
  56. "hotpink",
  57. "mediumspringgreen",
  58. ]
  59. )
  60. self.used_colors = {}
  61. def write(self, diadefs):
  62. """write files for <project> according to <diadefs>"""
  63. for diagram in diadefs:
  64. basename = diagram.title.strip().replace(" ", "_")
  65. file_name = f"{basename}.{self.config.output_format}"
  66. if os.path.exists(self.config.output_directory):
  67. file_name = os.path.join(self.config.output_directory, file_name)
  68. self.set_printer(file_name, basename)
  69. if diagram.TYPE == "class":
  70. self.write_classes(diagram)
  71. else:
  72. self.write_packages(diagram)
  73. self.save()
  74. def write_packages(self, diagram: PackageDiagram) -> None:
  75. """write a package diagram"""
  76. # sorted to get predictable (hence testable) results
  77. for module in sorted(diagram.modules(), key=lambda x: x.title):
  78. module.fig_id = module.node.qname()
  79. self.printer.emit_node(
  80. module.fig_id,
  81. type_=NodeType.PACKAGE,
  82. properties=self.get_package_properties(module),
  83. )
  84. # package dependencies
  85. for rel in diagram.get_relationships("depends"):
  86. self.printer.emit_edge(
  87. rel.from_object.fig_id,
  88. rel.to_object.fig_id,
  89. type_=EdgeType.USES,
  90. )
  91. def write_classes(self, diagram: ClassDiagram) -> None:
  92. """write a class diagram"""
  93. # sorted to get predictable (hence testable) results
  94. for obj in sorted(diagram.objects, key=lambda x: x.title):
  95. obj.fig_id = obj.node.qname()
  96. type_ = NodeType.INTERFACE if obj.shape == "interface" else NodeType.CLASS
  97. self.printer.emit_node(
  98. obj.fig_id, type_=type_, properties=self.get_class_properties(obj)
  99. )
  100. # inheritance links
  101. for rel in diagram.get_relationships("specialization"):
  102. self.printer.emit_edge(
  103. rel.from_object.fig_id,
  104. rel.to_object.fig_id,
  105. type_=EdgeType.INHERITS,
  106. )
  107. # implementation links
  108. for rel in diagram.get_relationships("implements"):
  109. self.printer.emit_edge(
  110. rel.from_object.fig_id,
  111. rel.to_object.fig_id,
  112. type_=EdgeType.IMPLEMENTS,
  113. )
  114. # generate associations
  115. for rel in diagram.get_relationships("association"):
  116. self.printer.emit_edge(
  117. rel.from_object.fig_id,
  118. rel.to_object.fig_id,
  119. label=rel.name,
  120. type_=EdgeType.ASSOCIATION,
  121. )
  122. def set_printer(self, file_name: str, basename: str) -> None:
  123. """set printer"""
  124. self.printer = self.printer_class(basename)
  125. self.file_name = file_name
  126. def get_package_properties(self, obj: PackageEntity) -> NodeProperties:
  127. """get label and shape for packages."""
  128. return NodeProperties(
  129. label=obj.title,
  130. color=self.get_shape_color(obj) if self.config.colorized else "black",
  131. )
  132. def get_class_properties(self, obj: ClassEntity) -> NodeProperties:
  133. """get label and shape for classes."""
  134. properties = NodeProperties(
  135. label=obj.title,
  136. attrs=obj.attrs if not self.config.only_classnames else None,
  137. methods=obj.methods if not self.config.only_classnames else None,
  138. fontcolor="red" if is_exception(obj.node) else "black",
  139. color=self.get_shape_color(obj) if self.config.colorized else "black",
  140. )
  141. return properties
  142. def get_shape_color(self, obj: DiagramEntity) -> str:
  143. """get shape color"""
  144. qualified_name = obj.node.qname()
  145. if modutils.is_standard_module(qualified_name.split(".", maxsplit=1)[0]):
  146. return "grey"
  147. if isinstance(obj.node, nodes.ClassDef):
  148. package = qualified_name.rsplit(".", maxsplit=2)[0]
  149. elif obj.node.package:
  150. package = qualified_name
  151. else:
  152. package = qualified_name.rsplit(".", maxsplit=1)[0]
  153. base_name = ".".join(package.split(".", self.depth)[: self.depth])
  154. if base_name not in self.used_colors:
  155. self.used_colors[base_name] = next(self.available_colors)
  156. return self.used_colors[base_name]
  157. def save(self) -> None:
  158. """write to disk"""
  159. self.printer.generate(self.file_name)