progress.py 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036
  1. from abc import ABC, abstractmethod
  2. from collections import deque
  3. from collections.abc import Sized
  4. from dataclasses import dataclass, field
  5. from datetime import timedelta
  6. from math import ceil
  7. from threading import Event, RLock, Thread
  8. from types import TracebackType
  9. from typing import (
  10. Any,
  11. Callable,
  12. Deque,
  13. Dict,
  14. Iterable,
  15. List,
  16. NamedTuple,
  17. NewType,
  18. Optional,
  19. Sequence,
  20. Tuple,
  21. Type,
  22. TypeVar,
  23. Union,
  24. )
  25. from . import filesize, get_console
  26. from .console import Console, JustifyMethod, RenderableType, Group
  27. from .highlighter import Highlighter
  28. from .jupyter import JupyterMixin
  29. from .live import Live
  30. from .progress_bar import ProgressBar
  31. from .spinner import Spinner
  32. from .style import StyleType
  33. from .table import Column, Table
  34. from .text import Text, TextType
  # Generic element type yielded by the track() helpers.
  ProgressType = TypeVar("ProgressType")
  # Zero-argument callable returning the current time in seconds.
  GetTimeCallable = Callable[[], float]
  # Distinct int subtype used to identify tasks inside a Progress instance.
  TaskID = NewType("TaskID", int)
  class _TrackThread(Thread):
      """Worker thread that feeds an iteration counter into a Progress task.

      The iterating code increments ``completed``; every ``update_period``
      seconds this thread advances the task by the amount counted since its
      previous wake-up. Once stopped, it sets the task's completed count to the
      final value and forces a refresh.
      """

      def __init__(self, progress: "Progress", task_id: "TaskID", update_period: float):
          self.progress = progress
          self.task_id = task_id
          self.update_period = update_period
          # Set by __exit__ to end the polling loop in run().
          self.done = Event()
          # Incremented by the iterating code and polled by run().
          self.completed = 0
          super().__init__()

      def run(self) -> None:
          """Poll the counter until ``done`` is set, then push the final count."""
          tid, period = self.task_id, self.update_period
          push = self.progress.advance
          should_stop = self.done.wait
          reported = 0
          while True:
              # Event.wait returns True once the event is set, False on timeout.
              if should_stop(period):
                  break
              seen = self.completed
              if seen == reported:
                  continue
              push(tid, seen - reported)
              reported = seen
          self.progress.update(self.task_id, completed=self.completed, refresh=True)

      def __enter__(self) -> "_TrackThread":
          """Start the worker and return it."""
          self.start()
          return self

      def __exit__(
          self,
          exc_type: Optional[Type[BaseException]],
          exc_val: Optional[BaseException],
          exc_tb: Optional[TracebackType],
      ) -> None:
          """Signal the worker to stop and wait for it to finish."""
          self.done.set()
          self.join()
  def track(
      sequence: Union[Sequence[ProgressType], Iterable[ProgressType]],
      description: str = "Working...",
      total: Optional[float] = None,
      auto_refresh: bool = True,
      console: Optional[Console] = None,
      transient: bool = False,
      get_time: Optional[Callable[[], float]] = None,
      refresh_per_second: float = 10,
      style: StyleType = "bar.back",
      complete_style: StyleType = "bar.complete",
      finished_style: StyleType = "bar.finished",
      pulse_style: StyleType = "bar.pulse",
      update_period: float = 0.1,
      disable: bool = False,
  ) -> Iterable[ProgressType]:
      """Track progress by iterating over a sequence.

      Builds a temporary :class:`Progress` with a description, bar, percentage
      and time-remaining column, then yields each value of ``sequence`` while
      advancing a single task. The display is started when iteration begins
      and stopped when it ends.

      Args:
          sequence (Iterable[ProgressType]): Values to iterate over. Must support
              ``len()`` unless ``total`` is given.
          description (str, optional): Description of task shown next to the progress
              bar; an empty string omits the description column. Defaults to "Working...".
          total (float, optional): Total number of steps. Defaults to ``len(sequence)``.
          auto_refresh (bool, optional): Automatic refresh; disable to force a refresh
              after each iteration. Defaults to True.
          console (Console, optional): Console to write to. Defaults to None, which uses
              the console returned by ``get_console()``.
          transient (bool, optional): Clear the progress on exit. Defaults to False.
          get_time (Callable[[], float], optional): Callable returning the current time,
              or None to use the console's ``get_time``. Defaults to None.
          refresh_per_second (float, optional): Number of times per second to refresh the
              progress information; a falsy value falls back to 10. Defaults to 10.
          style (StyleType, optional): Style for the bar background. Defaults to "bar.back".
          complete_style (StyleType, optional): Style for the completed bar. Defaults to "bar.complete".
          finished_style (StyleType, optional): Style for a finished bar. Defaults to "bar.finished".
          pulse_style (StyleType, optional): Style for pulsing bars. Defaults to "bar.pulse".
          update_period (float, optional): Interval (in seconds) at which progress is pushed
              to the display when auto refresh is enabled. Defaults to 0.1.
          disable (bool, optional): Disable display of progress. Defaults to False.

      Raises:
          ValueError: If ``total`` is None and ``sequence`` has no length (raised when
              iteration begins).

      Returns:
          Iterable[ProgressType]: An iterable of the values in the sequence.
      """
      columns: List["ProgressColumn"] = (
          [TextColumn("[progress.description]{task.description}")] if description else []
      )
      columns.extend(
          (
              BarColumn(
                  style=style,
                  complete_style=complete_style,
                  finished_style=finished_style,
                  pulse_style=pulse_style,
              ),
              TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
              TimeRemainingColumn(),
          )
      )
      progress = Progress(
          *columns,
          auto_refresh=auto_refresh,
          console=console,
          transient=transient,
          get_time=get_time,
          # Progress rejects non-positive rates, so 0/None fall back to the default.
          refresh_per_second=refresh_per_second or 10,
          disable=disable,
      )

      with progress:
          yield from progress.track(
              sequence, total=total, description=description, update_period=update_period
          )
  class ProgressColumn(ABC):
      """Base class for a widget to use in progress display.

      Subclasses implement :meth:`render`. Setting the class attribute
      ``max_refresh`` (seconds) makes :meth:`__call__` reuse the previous
      renderable for a still-running task until that interval has passed.
      """

      # Minimum seconds between re-renders of the same task, or None to render on every call.
      max_refresh: Optional[float] = None

      def __init__(self, table_column: Optional[Column] = None) -> None:
          self._table_column = table_column
          # Per-task cache of (render time, renderable), consulted when max_refresh is set.
          self._renderable_cache: Dict[TaskID, Tuple[float, RenderableType]] = {}
          self._update_time: Optional[float] = None

      def get_table_column(self) -> Column:
          """Get a table column, used to build tasks table."""
          return self._table_column or Column()

      def __call__(self, task: "Task") -> RenderableType:
          """Called by the Progress object to return a renderable for the given task.

          When ``max_refresh`` is set and the task has not finished, a renderable
          produced less than ``max_refresh`` seconds ago is returned from the
          cache instead of calling :meth:`render` again. Finished tasks are
          always re-rendered so their final state appears immediately.

          Args:
              task (Task): An object containing information regarding the task.

          Returns:
              RenderableType: Anything renderable (including str).
          """
          current_time = task.get_time()
          # Previously gated on ``not task.completed``, which only cached while zero
          # steps were done and so made max_refresh a no-op for running tasks.
          if self.max_refresh is not None and not task.finished:
              cached = self._renderable_cache.get(task.id)
              if cached is not None:
                  timestamp, renderable = cached
                  if timestamp + self.max_refresh > current_time:
                      return renderable

          renderable = self.render(task)
          self._renderable_cache[task.id] = (current_time, renderable)
          return renderable

      @abstractmethod
      def render(self, task: "Task") -> RenderableType:
          """Should return a renderable object."""
  class RenderableColumn(ProgressColumn):
      """A column that shows the same fixed renderable for every task.

      Args:
          renderable (RenderableType, optional): Any renderable. Defaults to empty string.
          table_column (Column, optional): Table column used for layout. Defaults to None.
      """

      def __init__(
          self, renderable: RenderableType = "", *, table_column: Optional[Column] = None
      ):
          super().__init__(table_column=table_column)
          self.renderable = renderable

      def render(self, task: "Task") -> RenderableType:
          """Return the stored renderable, ignoring the task."""
          fixed = self.renderable
          return fixed
  class SpinnerColumn(ProgressColumn):
      """A column showing a spinner animation, replaced by fixed text once the task finishes.

      Args:
          spinner_name (str, optional): Name of spinner animation. Defaults to "dots".
          style (StyleType, optional): Style of spinner. Defaults to "progress.spinner".
          speed (float, optional): Speed factor of spinner. Defaults to 1.0.
          finished_text (TextType, optional): Text used when task is finished; a plain
              str is parsed as console markup. Defaults to " ".
          table_column (Column, optional): Table column used for layout. Defaults to None.
      """

      def __init__(
          self,
          spinner_name: str = "dots",
          style: Optional[StyleType] = "progress.spinner",
          speed: float = 1.0,
          finished_text: TextType = " ",
          table_column: Optional[Column] = None,
      ):
          self.spinner = Spinner(spinner_name, style=style, speed=speed)
          if isinstance(finished_text, str):
              self.finished_text = Text.from_markup(finished_text)
          else:
              self.finished_text = finished_text
          super().__init__(table_column=table_column)

      def set_spinner(
          self,
          spinner_name: str,
          spinner_style: Optional[StyleType] = "progress.spinner",
          speed: float = 1.0,
      ) -> None:
          """Replace the spinner animation.

          Args:
              spinner_name (str): Spinner name, see python -m rich.spinner.
              spinner_style (Optional[StyleType], optional): Spinner style. Defaults to "progress.spinner".
              speed (float, optional): Speed factor of spinner. Defaults to 1.0.
          """
          replacement = Spinner(spinner_name, style=spinner_style, speed=speed)
          self.spinner = replacement

      def render(self, task: "Task") -> RenderableType:
          """Return the finished text for a finished task, else the spinner at the task's current time."""
          if task.finished:
              return self.finished_text
          return self.spinner.render(task.get_time())
  class TextColumn(ProgressColumn):
      """A column containing text formatted from the task.

      ``text_format`` is expanded with :meth:`str.format`, with the task bound to
      the name ``task`` (e.g. ``"{task.description}"``).

      Args:
          text_format (str): Format string for the cell.
          style (StyleType, optional): Style applied to the text. Defaults to "none".
          justify (JustifyMethod, optional): Text justification. Defaults to "left".
          markup (bool, optional): Parse the formatted string as console markup. Defaults to True.
          highlighter (Highlighter, optional): Highlighter applied to the text. Defaults to None.
          table_column (Column, optional): Table column used for layout. Defaults to a
              no-wrap column.
      """

      def __init__(
          self,
          text_format: str,
          style: StyleType = "none",
          justify: JustifyMethod = "left",
          markup: bool = True,
          highlighter: Optional[Highlighter] = None,
          table_column: Optional[Column] = None,
      ) -> None:
          self.text_format = text_format
          self.justify: JustifyMethod = justify
          self.style = style
          self.markup = markup
          self.highlighter = highlighter
          super().__init__(table_column=table_column or Column(no_wrap=True))

      def render(self, task: "Task") -> Text:
          """Format the template for ``task`` and wrap it in a Text instance."""
          formatted = self.text_format.format(task=task)
          build = Text.from_markup if self.markup else Text
          text = build(formatted, style=self.style, justify=self.justify)
          if self.highlighter:
              self.highlighter.highlight(text)
          return text
  class BarColumn(ProgressColumn):
      """Renders a visual progress bar.

      Args:
          bar_width (Optional[int], optional): Width of bar or None for full width. Values
              below 1 are treated as 1. Defaults to 40.
          style (StyleType, optional): Style for the bar background. Defaults to "bar.back".
          complete_style (StyleType, optional): Style for the completed bar. Defaults to "bar.complete".
          finished_style (StyleType, optional): Style for a finished bar. Defaults to "bar.finished".
          pulse_style (StyleType, optional): Style for pulsing bars. Defaults to "bar.pulse".
          table_column (Column, optional): Table column used for layout. Defaults to None.
      """

      def __init__(
          self,
          bar_width: Optional[int] = 40,
          style: StyleType = "bar.back",
          complete_style: StyleType = "bar.complete",
          finished_style: StyleType = "bar.finished",
          pulse_style: StyleType = "bar.pulse",
          table_column: Optional[Column] = None,
      ) -> None:
          self.bar_width = bar_width
          self.style = style
          self.complete_style = complete_style
          self.finished_style = finished_style
          self.pulse_style = pulse_style
          super().__init__(table_column=table_column)

      def render(self, task: "Task") -> ProgressBar:
          """Gets a progress bar widget for a task.

          Negative totals and completed counts are clamped to 0, and a task that
          has not been started is rendered with ``pulse`` enabled.
          """
          return ProgressBar(
              total=max(0, task.total),
              completed=max(0, task.completed),
              width=None if self.bar_width is None else max(1, self.bar_width),
              pulse=not task.started,
              animation_time=task.get_time(),
              style=self.style,
              complete_style=self.complete_style,
              finished_style=self.finished_style,
              pulse_style=self.pulse_style,
          )
  class TimeElapsedColumn(ProgressColumn):
      """Renders time elapsed."""

      def render(self, task: "Task") -> Text:
          """Show time elapsed.

          A finished task shows the elapsed time captured when it finished
          (``task.finished_time``); a task that has not started shows ``-:--:--``.
          """
          elapsed = task.finished_time if task.finished else task.elapsed
          if elapsed is None:
              return Text("-:--:--", style="progress.elapsed")
          # Truncate to whole seconds so timedelta prints without a fractional part.
          delta = timedelta(seconds=int(elapsed))
          return Text(str(delta), style="progress.elapsed")
  class TimeRemainingColumn(ProgressColumn):
      """Renders estimated time remaining, or ``-:--:--`` when no estimate exists."""

      # Cache interval (seconds) for ProgressColumn.__call__; limits re-rendering to reduce jitter.
      max_refresh = 0.5

      def render(self, task: "Task") -> Text:
          """Show time remaining."""
          seconds_left = task.time_remaining
          label = (
              "-:--:--"
              if seconds_left is None
              else str(timedelta(seconds=int(seconds_left)))
          )
          return Text(label, style="progress.remaining")
  class FileSizeColumn(ProgressColumn):
      """Renders the completed amount as a decimal file size."""

      def render(self, task: "Task") -> Text:
          """Show data completed."""
          return Text(filesize.decimal(int(task.completed)), style="progress.filesize")
  class TotalFileSizeColumn(ProgressColumn):
      """Renders total filesize."""

      def render(self, task: "Task") -> Text:
          """Show the task's total as a decimal file size."""
          data_size = filesize.decimal(int(task.total))
          return Text(data_size, style="progress.filesize.total")
  class DownloadColumn(ProgressColumn):
      """Renders file size downloaded and total, e.g. '0.5/2.3 GB'.

      Both figures share the unit chosen for the total. Byte counts are shown
      as whole numbers; larger units use one decimal place.

      Args:
          binary_units (bool, optional): Use binary units, KiB, MiB etc. Defaults to False.
          table_column (Column, optional): Table column used for layout. Defaults to None.
      """

      def __init__(
          self, binary_units: bool = False, table_column: Optional[Column] = None
      ) -> None:
          self.binary_units = binary_units
          super().__init__(table_column=table_column)

      def render(self, task: "Task") -> Text:
          """Calculate common unit for completed and total."""
          completed = int(task.completed)
          total = int(task.total)
          if self.binary_units:
              suffixes = ["bytes", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB"]
              base = 1024
          else:
              suffixes = ["bytes", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"]
              base = 1000
          unit, suffix = filesize.pick_unit_and_suffix(total, suffixes, base)
          precision = 1 if unit != 1 else 0
          completed_str = f"{completed / unit:,.{precision}f}"
          total_str = f"{total / unit:,.{precision}f}"
          return Text(f"{completed_str}/{total_str} {suffix}", style="progress.download")
  class TransferSpeedColumn(ProgressColumn):
      """Renders human readable transfer speed (``?`` when unknown)."""

      def render(self, task: "Task") -> Text:
          """Show data transfer speed, preferring the speed frozen at completion."""
          speed = task.finished_speed or task.speed
          if speed is not None:
              return Text(f"{filesize.decimal(int(speed))}/s", style="progress.data.speed")
          return Text("?", style="progress.data.speed")
  class ProgressSample(NamedTuple):
      """One speed-estimation sample: progress recorded at a point in time."""

      # Time (from the Progress clock) at which the sample was recorded.
      timestamp: float
      # Number of steps completed by the update that produced this sample.
      completed: float
  @dataclass
  class Task:
      """Information regarding a progress task.

      This object should be considered read-only outside of the :class:`~Progress` class.
      """

      id: TaskID
      """Task ID associated with this task (used in Progress methods)."""

      description: str
      """str: Description of the task."""

      total: float
      """float: Total number of steps in this task."""

      completed: float
      """float: Number of steps completed."""

      _get_time: GetTimeCallable
      """Callable to get the current time."""

      finished_time: Optional[float] = None
      """Optional[float]: Value of ``elapsed`` captured when the task finished, or None if not finished."""

      visible: bool = True
      """bool: Indicates if this task is visible in the progress display."""

      fields: Dict[str, Any] = field(default_factory=dict)
      """dict: Arbitrary fields passed in via Progress.update."""

      start_time: Optional[float] = field(default=None, init=False, repr=False)
      """Optional[float]: Time this task was started, or None if not started."""

      stop_time: Optional[float] = field(default=None, init=False, repr=False)
      """Optional[float]: Time this task was stopped, or None if not stopped."""

      finished_speed: Optional[float] = None
      """Optional[float]: The last speed for a finished task."""

      # Recent speed samples, appended and trimmed by Progress.update/advance.
      _progress: Deque[ProgressSample] = field(
          default_factory=deque, init=False, repr=False
      )

      _lock: RLock = field(repr=False, default_factory=RLock)
      """Thread lock."""

      def get_time(self) -> float:
          """float: Get the current time, in seconds."""
          return self._get_time()

      @property
      def started(self) -> bool:
          """bool: Check if the task has started."""
          return self.start_time is not None

      @property
      def remaining(self) -> float:
          """float: Get the number of steps remaining (negative if completed exceeds total)."""
          return self.total - self.completed

      @property
      def elapsed(self) -> Optional[float]:
          """Optional[float]: Time elapsed since task was started, or ``None`` if the task hasn't started.

          Once the task has been stopped, this is frozen at ``stop_time - start_time``.
          """
          if self.start_time is None:
              return None
          if self.stop_time is not None:
              return self.stop_time - self.start_time
          return self.get_time() - self.start_time

      @property
      def finished(self) -> bool:
          """bool: Check if the task has finished (``finished_time`` has been recorded)."""
          return self.finished_time is not None

      @property
      def percentage(self) -> float:
          """float: Get progress of task as a percentage, clamped to 0-100 (0 when total is 0)."""
          if not self.total:
              return 0.0
          completed = (self.completed / self.total) * 100.0
          completed = min(100.0, max(0.0, completed))
          return completed

      @property
      def speed(self) -> Optional[float]:
          """Optional[float]: Get the estimated speed in steps per second.

          Computed as the steps of every sample after the first, divided by the
          time between the first and last sample. Returns None if the task hasn't
          started, there are no samples, or the samples span no time.
          """
          if self.start_time is None:
              return None
          with self._lock:
              progress = self._progress
              if not progress:
                  return None
              total_time = progress[-1].timestamp - progress[0].timestamp
              if total_time == 0:
                  return None
              # Skip the first sample: its steps accrued before the measured span began.
              iter_progress = iter(progress)
              next(iter_progress)
              total_completed = sum(sample.completed for sample in iter_progress)
              speed = total_completed / total_time
              return speed

      @property
      def time_remaining(self) -> Optional[float]:
          """Optional[float]: Get estimated seconds to completion (rounded up), or ``None`` if no data.

          Always 0.0 for a finished task.
          """
          if self.finished:
              return 0.0
          speed = self.speed
          if not speed:
              return None
          estimate = ceil(self.remaining / speed)
          return estimate

      def _reset(self) -> None:
          """Reset progress: clear speed samples and the finished time/speed."""
          self._progress.clear()
          self.finished_time = None
          self.finished_speed = None
  456. class Progress(JupyterMixin):
  457. """Renders an auto-updating progress bar(s).
  458. Args:
  459. console (Console, optional): Optional Console instance. Default will an internal Console instance writing to stdout.
  460. auto_refresh (bool, optional): Enable auto refresh. If disabled, you will need to call `refresh()`.
  461. refresh_per_second (Optional[float], optional): Number of times per second to refresh the progress information or None to use default (10). Defaults to None.
  462. speed_estimate_period: (float, optional): Period (in seconds) used to calculate the speed estimate. Defaults to 30.
  463. transient: (bool, optional): Clear the progress on exit. Defaults to False.
  464. redirect_stdout: (bool, optional): Enable redirection of stdout, so ``print`` may be used. Defaults to True.
  465. redirect_stderr: (bool, optional): Enable redirection of stderr. Defaults to True.
  466. get_time: (Callable, optional): A callable that gets the current time, or None to use Console.get_time. Defaults to None.
  467. disable (bool, optional): Disable progress display. Defaults to False
  468. expand (bool, optional): Expand tasks table to fit width. Defaults to False.
  469. """
  def __init__(
      self,
      *columns: Union[str, ProgressColumn],
      console: Optional[Console] = None,
      auto_refresh: bool = True,
      refresh_per_second: float = 10,
      speed_estimate_period: float = 30.0,
      transient: bool = False,
      redirect_stdout: bool = True,
      redirect_stderr: bool = True,
      get_time: Optional[GetTimeCallable] = None,
      disable: bool = False,
      expand: bool = False,
  ) -> None:
      """Set up the columns, the task registry and the underlying Live display.

      When no columns are given, a description, bar, percentage and
      time-remaining column are used.
      """
      assert (
          refresh_per_second is None or refresh_per_second > 0
      ), "refresh_per_second must be > 0"
      self._lock = RLock()
      if columns:
          self.columns = columns
      else:
          self.columns = (
              TextColumn("[progress.description]{task.description}"),
              BarColumn(),
              TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
              TimeRemainingColumn(),
          )
      self.speed_estimate_period = speed_estimate_period
      self.disable = disable
      self.expand = expand
      self._tasks: Dict[TaskID, Task] = {}
      self._task_index: TaskID = TaskID(0)
      live_console = console or get_console()
      self.live = Live(
          console=live_console,
          auto_refresh=auto_refresh,
          refresh_per_second=refresh_per_second,
          transient=transient,
          redirect_stdout=redirect_stdout,
          redirect_stderr=redirect_stderr,
          get_renderable=self.get_renderable,
      )
      # The console property reads through self.live, so these must come after it.
      self.get_time = get_time or self.console.get_time
      self.print = self.console.print
      self.log = self.console.log
  @property
  def console(self) -> Console:
      """Console owned by the live display; ``print`` and ``log`` are bound to it."""
      live = self.live
      return live.console
  @property
  def tasks(self) -> List[Task]:
      """Get a list of Task instances, copied while holding the lock."""
      with self._lock:
          return [*self._tasks.values()]
  @property
  def task_ids(self) -> List[TaskID]:
      """A list of task IDs, copied while holding the lock."""
      with self._lock:
          return list(self._tasks)
  @property
  def finished(self) -> bool:
      """Check if all tasks have been completed (True when there are no tasks)."""
      with self._lock:
          # all() over an empty iterable is True, which covers the "no tasks" case.
          return all(task.finished for task in self._tasks.values())
  def start(self) -> None:
      """Start the progress display (does nothing when ``disable`` is set)."""
      if self.disable:
          return
      self.live.start(refresh=True)
  def stop(self) -> None:
      """Stop the progress display.

      When the display is enabled and the console is not interactive, an
      empty line is printed after stopping. A disabled display rendered
      nothing, so no blank line is emitted for it either.
      """
      self.live.stop()
      # Previously the blank line was printed even with disable=True, leaking
      # output from a progress display that was supposed to be silent.
      if not self.disable and not self.console.is_interactive:
          self.console.print()
  def __enter__(self) -> "Progress":
      """Start the display and return this Progress for use in a ``with`` block."""
      self.start()
      return self
  def __exit__(
      self,
      exc_type: Optional[Type[BaseException]],
      exc_val: Optional[BaseException],
      exc_tb: Optional[TracebackType],
  ) -> None:
      """Stop the display when leaving the ``with`` block; exceptions propagate."""
      self.stop()
  def track(
      self,
      sequence: Union[Iterable[ProgressType], Sequence[ProgressType]],
      total: Optional[float] = None,
      task_id: Optional[TaskID] = None,
      description: str = "Working...",
      update_period: float = 0.1,
  ) -> Iterable[ProgressType]:
      """Yield values from ``sequence`` while advancing a progress task.

      Args:
          sequence (Iterable[ProgressType]): Values to iterate over.
          total (float, optional): Total number of steps. Defaults to ``len(sequence)``.
          task_id (TaskID, optional): Existing task to track; a new task is added when None.
          description (str, optional): Description for a newly added task. Defaults to "Working...".
          update_period (float, optional): Minimum time (in seconds) between calls to update(). Defaults to 0.1.

      Raises:
          ValueError: If ``total`` is None and ``sequence`` has no length (raised when
              iteration begins, since this is a generator).

      Returns:
          Iterable[ProgressType]: An iterable of values taken from the provided sequence.
      """
      if total is not None:
          task_total = total
      elif isinstance(sequence, Sized):
          task_total = float(len(sequence))
      else:
          raise ValueError(
              f"unable to get size of {sequence!r}, please specify 'total'"
          )

      if task_id is not None:
          self.update(task_id, total=task_total)
      else:
          task_id = self.add_task(description, total=task_total)

      if not self.live.auto_refresh:
          # No background refresh: advance and redraw synchronously after each item.
          step = self.advance
          redraw = self.refresh
          for item in sequence:
              yield item
              step(task_id, 1)
              redraw()
          return

      # Auto refresh: a helper thread pushes the running count every update_period.
      with _TrackThread(self, task_id, update_period) as worker:
          for item in sequence:
              yield item
              worker.completed += 1
  def start_task(self, task_id: TaskID) -> None:
      """Start a task.

      Records the start time used for elapsed-time calculations. You may need
      to call this manually if you called ``add_task`` with ``start=False``.
      Starting an already-started task leaves its start time unchanged.

      Args:
          task_id (TaskID): ID of task.
      """
      with self._lock:
          task = self._tasks[task_id]
          if task.start_time is not None:
              return
          task.start_time = self.get_time()
  def stop_task(self, task_id: TaskID) -> None:
      """Stop a task, freezing its elapsed time.

      A task that was never started is treated as starting and stopping at
      the same moment, so its elapsed time becomes zero.

      Args:
          task_id (TaskID): ID of task.
      """
      with self._lock:
          task = self._tasks[task_id]
          now = self.get_time()
          task.start_time = now if task.start_time is None else task.start_time
          task.stop_time = now
  def update(
      self,
      task_id: TaskID,
      *,
      total: Optional[float] = None,
      completed: Optional[float] = None,
      advance: Optional[float] = None,
      description: Optional[str] = None,
      visible: Optional[bool] = None,
      refresh: bool = False,
      **fields: Any,
  ) -> None:
      """Update information associated with a task.

      Args:
          task_id (TaskID): Task id (returned by add_task).
          total (float, optional): Updates task.total if not None. A changed total
              clears the task's speed samples and finished state.
          completed (float, optional): Updates task.completed if not None.
          advance (float, optional): Add a value to task.completed if not None.
              Applied before ``completed``, so ``completed`` wins when both are given.
          description (str, optional): Change task description if not None.
          visible (bool, optional): Set visible flag if not None.
          refresh (bool): Force a refresh of progress information. Default is False.
          **fields (Any): Additional data fields required for rendering; merged
              into ``task.fields``.
      """
      with self._lock:
          task = self._tasks[task_id]
          completed_start = task.completed

          if total is not None and total != task.total:
              task.total = total
              task._reset()
          if advance is not None:
              task.completed += advance
          if completed is not None:
              task.completed = completed
          if description is not None:
              task.description = description
          if visible is not None:
              task.visible = visible
          task.fields.update(fields)
          update_completed = task.completed - completed_start

          current_time = self.get_time()
          old_sample_time = current_time - self.speed_estimate_period
          _progress = task._progress

          # Drop samples outside the speed-estimate window, then cap the history size.
          popleft = _progress.popleft
          while _progress and _progress[0].timestamp < old_sample_time:
              popleft()
          while len(_progress) > 1000:
              popleft()
          # Only forward progress is recorded as a speed sample.
          if update_completed > 0:
              _progress.append(ProgressSample(current_time, update_completed))
          if task.completed >= task.total and task.finished_time is None:
              task.finished_time = task.elapsed
              # Freeze the speed at completion, as advance() does, so later calls
              # that prune the sample window cannot erase it.
              task.finished_speed = task.speed

      if refresh:
          self.refresh()
  def reset(
      self,
      task_id: TaskID,
      *,
      start: bool = True,
      total: Optional[float] = None,
      completed: int = 0,
      visible: Optional[bool] = None,
      description: Optional[str] = None,
      **fields: Any,
  ) -> None:
      """Reset a task so completed is 0 and the clock is reset.

      Speed samples, the finished time/speed and any stop time recorded by
      ``stop_task`` are cleared, so elapsed time restarts from the new start.

      Args:
          task_id (TaskID): ID of task.
          start (bool, optional): Start the task after reset. Defaults to True.
          total (float, optional): New total steps in task, or None to use current total. Defaults to None.
          completed (int, optional): Number of steps completed. Defaults to 0.
          visible (bool, optional): Set visible flag if not None. Defaults to None.
          description (str, optional): Change task description if not None. Defaults to None.
          **fields (Any): Additional data fields required for rendering; when any are
              given they replace ``task.fields`` entirely.
      """
      current_time = self.get_time()
      with self._lock:
          task = self._tasks[task_id]
          # Clears speed samples, finished_time and finished_speed.
          task._reset()
          task.start_time = current_time if start else None
          # A stale stop_time would otherwise make elapsed = old_stop - new_start
          # (negative) after stop_task() followed by reset().
          task.stop_time = None
          if total is not None:
              task.total = total
          task.completed = completed
          if visible is not None:
              task.visible = visible
          if fields:
              task.fields = fields
          if description is not None:
              task.description = description
      self.refresh()
  def advance(self, task_id: TaskID, advance: float = 1) -> None:
      """Advance task by a number of steps.

      Unlike update(), a speed sample is recorded on every call, including
      zero or negative advances.

      Args:
          task_id (TaskID): ID of task.
          advance (float): Number of steps to advance. Default is 1.
      """
      now = self.get_time()
      with self._lock:
          task = self._tasks[task_id]
          before = task.completed
          task.completed += advance
          delta = task.completed - before

          samples = task._progress
          cutoff = now - self.speed_estimate_period
          # Discard samples that have aged out of the estimate window...
          while samples and samples[0].timestamp < cutoff:
              samples.popleft()
          # ...and keep at most 1000 of the rest.
          while len(samples) > 1000:
              samples.popleft()
          samples.append(ProgressSample(now, delta))

          if task.completed >= task.total and task.finished_time is None:
              task.finished_time = task.elapsed
              task.finished_speed = task.speed
  def refresh(self) -> None:
      """Render the progress display now.

      Does nothing when the display is disabled or the live display has not
      been started.
      """
      if self.disable:
          return
      live = self.live
      if live.is_started:
          live.refresh()
  def get_renderable(self) -> RenderableType:
      """Return the whole progress display as one renderable.

      Everything yielded by ``get_renderables`` is combined into a ``Group``.
      """
      return Group(*self.get_renderables())
  def get_renderables(self) -> Iterable[RenderableType]:
      """Yield the renderables that make up the progress display.

      This implementation yields a single table built from ``self.tasks``.
      """
      yield self.make_tasks_table(self.tasks)
  def make_tasks_table(self, tasks: Iterable[Task]) -> Table:
      """Build the table used to render the Progress display.

      Every entry in ``self.columns`` becomes one table column. A plain format
      string gets a non-wrapping ``Column``; any other column object supplies a
      copy of its own table column. Each visible task becomes one row.

      Args:
          tasks (Iterable[Task]): Tasks to render, one row per visible task.

      Returns:
          Table: A grid table with ``(0, 1)`` padding that expands when
          ``self.expand`` is set.
      """

      def _header(column):
          # Column definition for one entry of self.columns.
          if isinstance(column, str):
              return Column(no_wrap=True)
          return column.get_table_column().copy()

      def _cell(column, task):
          # Cell content for one column of one task's row.
          if isinstance(column, str):
              return column.format(task=task)
          return column(task)

      headers = [_header(column) for column in self.columns]
      table = Table.grid(*headers, padding=(0, 1), expand=self.expand)
      for task in tasks:
          if not task.visible:
              continue
          cells = [_cell(column, task) for column in self.columns]
          table.add_row(*cells)
      return table
  def __rich__(self) -> RenderableType:
      """Make the Progress object itself renderable.

      The renderable is built while holding ``self._lock``.
      """
      with self._lock:
          renderable = self.get_renderable()
      return renderable
  def add_task(
      self,
      description: str,
      start: bool = True,
      total: float = 100.0,
      completed: int = 0,
      visible: bool = True,
      **fields: Any,
  ) -> TaskID:
      """Add a new task to the Progress display.

      Args:
          description (str): A description of the task.
          start (bool, optional): Start the task immediately so elapsed time is
              tracked. If False, call ``start_task`` manually later.
              Defaults to True.
          total (float, optional): Total number of steps in the task, if known.
              Defaults to 100.
          completed (int, optional): Number of steps already completed.
              Defaults to 0.
          visible (bool, optional): Whether the task is displayed.
              Defaults to True.
          **fields (str): Additional data fields required for rendering.

      Returns:
          TaskID: ID of the new task, for use with ``update`` and friends.
      """
      with self._lock:
          new_task_id = self._task_index
          self._tasks[new_task_id] = Task(
              new_task_id,
              description,
              total,
              completed,
              visible=visible,
              fields=fields,
              _get_time=self.get_time,
              _lock=self._lock,
          )
          if start:
              self.start_task(new_task_id)
          # IDs are handed out sequentially.
          self._task_index = TaskID(int(new_task_id) + 1)
      self.refresh()
      return new_task_id
  def remove_task(self, task_id: TaskID) -> None:
      """Remove a task from the Progress display.

      Args:
          task_id (TaskID): ID of the task to remove.

      Raises:
          KeyError: If there is no task with this ID.
      """
      with self._lock:
          tasks = self._tasks
          del tasks[task_id]
  if __name__ == "__main__":  # pragma: no coverage
      # Demo: renders progress bars while occasionally logging assorted
      # renderables above them.
      import random
      import time

      from .panel import Panel
      from .rule import Rule
      from .syntax import Syntax
      from .table import Table

      # Syntax-highlighted Python snippet, one of the example renderables below.
      syntax = Syntax(
          '''def loop_last(values: Iterable[T]) -> Iterable[Tuple[bool, T]]:
"""Iterate and generate a tuple with a flag for last value."""
iter_values = iter(values)
try:
previous_value = next(iter_values)
except StopIteration:
return
for value in iter_values:
yield False, previous_value
previous_value = value
yield True, previous_value''',
          "python",
          line_numbers=True,
      )
      table = Table("foo", "bar", "baz")
      table.add_row("1", "2", "3")

      # Renderables logged, in rotation, while the bars are running.
      progress_renderables = [
          "Text may be printed while the progress bars are rendering.",
          Panel("In fact, [i]any[/i] renderable will work"),
          "Such as [magenta]tables[/]...",
          table,
          "Pretty printed structures...",
          {"type": "example", "text": "Pretty printed"},
          "Syntax...",
          syntax,
          Rule("Give it a try!"),
      ]
      from itertools import cycle

      # Endless iterator over the example renderables.
      examples = cycle(progress_renderables)
      console = Console(record=True)
      with Progress(
          SpinnerColumn(),
          TextColumn("[progress.description]{task.description}"),
          BarColumn(),
          TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
          TimeRemainingColumn(),
          TimeElapsedColumn(),
          console=console,
          transient=True,
      ) as progress:
          task1 = progress.add_task("[red]Downloading", total=1000)
          task2 = progress.add_task("[green]Processing", total=1000)
          # Added unstarted (start=False) and never advanced in the loop below.
          task3 = progress.add_task("[yellow]Thinking", total=1000, start=False)
          # NOTE(review): task3 never progresses, so this loop may not end on its
          # own; confirm against how `Progress.finished` treats such tasks.
          while not progress.finished:
              progress.update(task1, advance=0.5)
              progress.update(task2, advance=0.3)
              time.sleep(0.01)
              # randint's bounds are inclusive, so this fires with probability 1/101.
              if random.randint(0, 100) < 1:
                  progress.log(next(examples))