diff --git a/CHANGELOG.md b/CHANGELOG.md index e429af4..72eb0e7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ ## Unreleased +### Features + +- Add a `dot` table format for exporting query results as Graphviz DOT. + ### Bug Fixes - Expand `~` in configured log file paths before opening the log. diff --git a/litecli/liteclirc b/litecli/liteclirc index ec162d3..cdb7244 100644 --- a/litecli/liteclirc +++ b/litecli/liteclirc @@ -39,7 +39,7 @@ log_level = INFO # Table format. Possible values: # ascii, double, github, psql, plain, simple, grid, fancy_grid, pipe, orgtbl, # rst, mediawiki, html, latex, latex_booktabs, textile, moinmoin, jira, -# vertical, tsv, csv. +# vertical, tsv, csv, dot. # Recommended: ascii table_format = ascii diff --git a/litecli/main.py b/litecli/main.py index fa732c3..9bae466 100644 --- a/litecli/main.py +++ b/litecli/main.py @@ -40,6 +40,7 @@ from .key_bindings import cli_bindings from .lexer import LiteCliLexer from .packages import special +from .packages.dot_output import format_dot_output from .packages.filepaths import dir_path_exists from .packages.prompt_utils import confirm, confirm_destructive_query from .packages.special.main import NO_QUERY @@ -60,6 +61,7 @@ def _load_sqlite3() -> Any: _sqlite3 = _load_sqlite3() OperationalError = _sqlite3.OperationalError sqlite_version = _sqlite3.sqlite_version +LOCAL_OUTPUT_FORMATS = ("dot",) # Query tuples are used for maintaining history Query = namedtuple("Query", ["query", "successful", "mutating"]) @@ -89,7 +91,11 @@ def __init__( self.multi_line = c["main"].as_bool("multi_line") self.key_bindings = c["main"]["key_bindings"] special.set_favorite_queries(self.config) - self.formatter = TabularOutputFormatter(format_name=c["main"]["table_format"]) + self.local_format_name: str | None = None + config_table_format = c["main"]["table_format"] + self.formatter = TabularOutputFormatter(format_name="ascii" if config_table_format in LOCAL_OUTPUT_FORMATS else config_table_format) + if config_table_format in LOCAL_OUTPUT_FORMATS: + self.local_format_name = config_table_format # self.formatter.litecli = self, ty raises unresolved-attribute, hence use dynamic assignment setattr(self.formatter, "litecli", self) self.syntax_style = c["main"]["syntax_style"] @@ -137,7 +143,7 @@ def __init__( # Initialize completer. self.completer = SQLCompleter( - supported_formats=self.formatter.supported_formats, + supported_formats=self.supported_table_formats(), keyword_casing=keyword_casing, ) self._completer_lock = threading.Lock() @@ -188,13 +194,31 @@ def register_special_commands(self) -> None: case_sensitive=True, ) + def supported_table_formats(self) -> list[str]: + supported_formats = list(self.formatter.supported_formats) + for format_name in LOCAL_OUTPUT_FORMATS: + if format_name not in supported_formats: + supported_formats.append(format_name) + return supported_formats + + def current_table_format(self) -> str: + return self.local_format_name or self.formatter.format_name + + def set_table_format(self, format_name: str) -> None: + if format_name in LOCAL_OUTPUT_FORMATS: + self.local_format_name = format_name + return + + self.formatter.format_name = format_name + self.local_format_name = None + def change_table_format(self, arg: str, **_: Any) -> Generator[tuple[None, None, None, str], None, None]: try: - self.formatter.format_name = arg + self.set_table_format(arg) yield (None, None, None, "Changed table format to {}".format(arg)) except ValueError: msg = "Table format {} not recognized. Allowed formats:".format(arg) - for table_type in self.formatter.supported_formats: + for table_type in self.supported_table_formats(): msg += "\n\t{}".format(table_type) yield (None, None, None, msg) @@ -839,7 +863,8 @@ def run_query(self, query: str, new_line: bool = True) -> None: click.echo(line, nl=new_line) def format_output(self, title: Any, cur: Any, headers: Any, expanded: bool = False, max_width: int | None = None) -> Iterable[str]: - expanded = expanded or self.formatter.format_name == "vertical" + format_name = self.current_table_format() + expanded = expanded or format_name == "vertical" output_iter: Iterable[str] = [] output_kwargs = { @@ -854,6 +879,9 @@ def format_output(self, title: Any, cur: Any, headers: Any, expanded: bool = Fal output_iter = itertools.chain(output_iter, [title]) if cur: + if format_name == "dot": + return itertools.chain(output_iter, format_dot_output(cur, headers or [])) + column_types = None if hasattr(cur, "description"): column_types = [str(col) for col in cur.description] @@ -972,9 +1000,9 @@ def cli( if execute: try: if csv: - litecli.formatter.format_name = "csv" + litecli.set_table_format("csv") elif not table: - litecli.formatter.format_name = "tsv" + litecli.set_table_format("tsv") litecli.run_query(execute) exit(0) @@ -999,9 +1027,9 @@ def cli( new_line = True if csv: - litecli.formatter.format_name = "csv" + litecli.set_table_format("csv") elif not table: - litecli.formatter.format_name = "tsv" + litecli.set_table_format("tsv") litecli.run_query(stdin_text, new_line=new_line) exit(0) diff --git a/litecli/packages/dot_output.py b/litecli/packages/dot_output.py new file mode 100644 index 0000000..fb8785e --- /dev/null +++ b/litecli/packages/dot_output.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from typing import Any, Iterable + + +def _dot_value(value: Any) -> str: + if value is None: + return "NULL" + return str(value) + + +def _dot_quote(value: Any) -> str: + text = _dot_value(value) + return '"' + text.replace("\\", "\\\\").replace('"', '\\"').replace("\n", "\\n").replace("\r", "\\r") + '"' + + +def format_dot_output(rows: Iterable[Iterable[Any]], headers: Iterable[str]) -> Iterable[str]: + """Format one-column results as nodes and multi-column results as edges.""" + header_names = list(headers) + + yield "digraph result {" + + if header_names: + yield " // Columns: {}".format(", ".join(header_names)) + + for row in rows: + row_values = list(row) + if not row_values: + continue + + if len(row_values) == 1: + yield " {};".format(_dot_quote(row_values[0])) + continue + + label = "" + if len(row_values) > 2: + label_values = [] + for index, value in enumerate(row_values[2:], start=2): + column_name = header_names[index] if index < len(header_names) else "column{}".format(index + 1) + label_values.append("{}={}".format(column_name, _dot_value(value))) + label = " [label={}]".format(_dot_quote(", ".join(label_values))) + + yield " {} -> {}{};".format(_dot_quote(row_values[0]), _dot_quote(row_values[1]), label) + + yield "}" diff --git a/tests/liteclirc b/tests/liteclirc index 91f7df9..a91b92d 100644 --- a/tests/liteclirc +++ b/tests/liteclirc @@ -32,7 +32,7 @@ log_level = INFO # Table format. Possible values: # ascii, double, github, psql, plain, simple, grid, fancy_grid, pipe, orgtbl, # rst, mediawiki, html, latex, latex_booktabs, textile, moinmoin, jira, -# vertical, tsv, csv. +# vertical, tsv, csv, dot. # Recommended: ascii table_format = ascii diff --git a/tests/test_dot_output.py b/tests/test_dot_output.py new file mode 100644 index 0000000..8e0c681 --- /dev/null +++ b/tests/test_dot_output.py @@ -0,0 +1,39 @@ +from litecli.packages.dot_output import format_dot_output + + +def test_dot_output_formats_edges(): + rows = [("orders", "customers"), ("line_items", "orders")] + headers = ["child", "parent"] + + assert list(format_dot_output(rows, headers)) == [ + "digraph result {", + " // Columns: child, parent", + ' "orders" -> "customers";', + ' "line_items" -> "orders";', + "}", + ] + + +def test_dot_output_formats_nodes_and_escapes_values(): + rows = [('a"b',), ("line\nbreak",), (None,)] + + assert list(format_dot_output(rows, ["name"])) == [ + "digraph result {", + " // Columns: name", + ' "a\\"b";', + ' "line\\nbreak";', + ' "NULL";', + "}", + ] + + +def test_dot_output_uses_extra_columns_as_edge_label(): + rows = [("a", "b", "foreign key")] + headers = ["source", "target", "relation"] + + assert list(format_dot_output(rows, headers)) == [ + "digraph result {", + " // Columns: source, target, relation", + ' "a" -> "b" [label="relation=foreign key"];', + "}", + ] diff --git a/tests/test_main.py b/tests/test_main.py index 21ad5b6..ac08fcd 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -134,6 +134,19 @@ def test_batch_mode_csv(executor): assert expected in "".join(result.output) +def test_dot_table_format_is_supported(): + m = LiteCli(liteclirc=default_config_file) + + assert "dot" in m.supported_table_formats() + assert list(m.change_table_format("dot")) == [(None, None, None, "Changed table format to dot")] + assert list(m.format_output(None, [("orders", "customers")], ["source", "target"])) == [ + "digraph result {", + " // Columns: source, target", + ' "orders" -> "customers";', + "}", + ] + + def test_help_strings_end_with_periods(): """Make sure click options have help text that end with a period.""" for param in cli.params: