aboutsummaryrefslogtreecommitdiff
blob: 18b2f89d4766f25f75a9bb3712e201a5811a6d92 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
"""Generic support for running commandline tools."""

import io
import logging
import os
import sys
import traceback
from contextlib import nullcontext
from functools import partial
from signal import SIG_DFL, SIGINT, SIGPIPE, signal

from .. import formatters
from ..log import suppress_logging
from .exceptions import ExitException, find_user_exception


class Tool:
    """Abstraction for commandline tools.

    A tool wraps an argument parser. It sets up formatters for its output and
    error streams, parses arguments, runs the parser-provided ``main_func``
    and turns exceptions into exit statuses.
    """

    def __init__(self, parser, outfile=None, errfile=None):
        """Initialize the utility to run.

        :type parser: :obj:`ArgumentParser` subclass instance
        :param parser: argparser instance defining valid arguments for the given tool.
        :type outfile: file-like object
        :param outfile: File to use for stdout, defaults to C{sys.stdout}.
        :type errfile: file-like object
        :param errfile: File to use for stderr, defaults to C{sys.stderr}.
        :raises ValueError: if ``parser`` is None.
        """
        if parser is None:
            raise ValueError("invalid argparser")
        self.parser = parser
        self.options = None

        if outfile is None:
            if not sys.stdout.isatty() and sys.stdout is sys.__stdout__:
                # if redirecting/piping stdout use line buffering, skip if
                # stdout has been set to some non-standard object.
                # closefd=False: this second wrapper around fd 1 must not
                # close the descriptor when it is closed or garbage
                # collected, since sys.stdout still writes to it.
                outfile = os.fdopen(sys.stdout.fileno(), "w", 1, closefd=False)
            else:
                outfile = sys.stdout
        if errfile is None:
            errfile = sys.stderr

        out_fd = err_fd = None
        if hasattr(outfile, "fileno") and hasattr(errfile, "fileno"):
            # annoyingly, fileno can exist but be unsupported (e.g. io.StringIO
            # raises io.UnsupportedOperation)
            try:
                out_fd, err_fd = outfile.fileno(), errfile.fileno()
            except (io.UnsupportedOperation, IOError):
                pass

        if out_fd is not None and err_fd is not None:
            out_stat, err_stat = os.fstat(out_fd), os.fstat(err_fd)
            if (
                out_stat.st_dev == err_stat.st_dev
                and out_stat.st_ino == err_stat.st_ino
                and not errfile.isatty()
            ):
                # they're the same underlying file (e.g. `cmd >log 2>&1`),
                # thus point both handles at the same object so buffered
                # output from the two streams doesn't get intermixed.
                errfile = outfile

        self._outfile = outfile
        self._errfile = errfile
        self.out = self.parser.out = formatters.PlainTextFormatter(outfile)
        self.err = self.parser.err = formatters.PlainTextFormatter(errfile)
        self.out.verbosity = self.err.verbosity = getattr(self.parser, "verbosity", 0)

    def __call__(self, args=None):
        """Run the utility.

        :type args: sequence of strings
        :param args: arguments to parse, defaulting to C{sys.argv[1:]}.
        :return: the exit status produced by :meth:`main`, or the code carried
            by an ``ExitException`` (a string code is written to the error
            stream and replaced by 1).
        """
        self.args = args

        # run script and return exit status
        try:
            ret = self.main()
        except ExitException as e:
            if self.parser.debug:
                raise
            if isinstance(e.code, str):
                self.err.error(e.code)
                e.code = 1
            ret = e.code

        return ret

    def parse_args(self, args=None, namespace=None):
        """Parse the given arguments using argparse.

        Also rebuilds the output/error formatters (colored when enabled),
        applies the parsed verbosity and routes root logging through the
        error formatter.

        :type args: sequence of strings
        :param args: arguments to parse, defaulting to C{sys.argv[1:]}.
        :type namespace: argparse.Namespace object
        :param namespace: Namespace object to use for created attributes.
        :return: ``(options, main_func)`` where ``options`` is the value
            returned by :meth:`post_parse`.
        """
        try:
            self.pre_parse(args, namespace)
            options = self.parser.parse_args(args=args, namespace=namespace)
            main_func = options.pop("main_func", None)
            if main_func is None:
                raise RuntimeError("argparser missing main method")

            # reconfigure formatters for colored output if enabled
            if getattr(options, "color", True):
                formatter_factory = partial(
                    formatters.get_formatter,
                    force_color=getattr(options, "color", False),
                )
                self.out = formatter_factory(self._outfile)
                self.err = formatter_factory(self._errfile)

            # reconfigure formatters with properly parsed output verbosity
            self.out.verbosity = self.err.verbosity = getattr(options, "verbosity", 0)

            root = logging.root
            if root.handlers:
                # Remove the default handler (via the locked public API
                # rather than mutating the handler list directly).
                root.removeHandler(root.handlers[0])
            root.addHandler(FormattingHandler(self.err))

            options = self.post_parse(options)
            return options, main_func
        except Exception as e:
            # handle custom execution-related exceptions
            self.handle_exec_exception(e)

    def pre_parse(self, args, namespace):
        """Handle custom options before argparsing."""

    def post_parse(self, options):
        """Handle custom options after argparsing."""
        return options

    def handle_exec_exception(self, e):
        """Handle custom runtime exceptions.

        Must be called while ``e`` is being handled (from inside an
        ``except`` block), because the bare ``raise`` statements below
        re-raise the exception currently being handled.

        :raises SystemExit: with status 1 after writing a verbose user error
            message. Other user errors are passed to ``parser.error()``, which
            exits for an argparse parser. Anything else is re-raised.
        """
        if self.parser.debug:
            raise
        # output user error if one exists otherwise show debugging traceback
        exc = find_user_exception(e)
        if exc is not None:
            # allow exception attribute to override user verbosity level
            if getattr(exc, "_verbosity", None) is not None:
                verbosity = exc._verbosity
            else:
                verbosity = getattr(self.parser, "verbosity", 0)
            # output verbose error message if it exists
            if verbosity > 0:
                msg = exc.msg(verbosity).strip("\n")
                if msg:
                    self.err.write(msg)
                    # a bare SystemExit carries code None, which maps to a
                    # *successful* exit status; an error must exit non-zero
                    raise SystemExit(1)
            self.parser.error(exc)
        raise

    def main(self):
        """Execute the main script function.

        Parses arguments, calls the selected ``main_func`` with the parsed
        options and the output/error formatters, and sets the terminal title
        according to the outcome.

        :return: the main function's return value, or the ``SystemExit`` code
            if one was raised while parsing or running.
        """
        # placeholder status; only returned if no branch below replaces it
        # (the KeyboardInterrupt branch does not assign one)
        exitstatus = -10

        # restore default SIGPIPE handling so writing to a closed pipe ends
        # the process instead of raising BrokenPipeError
        signal(SIGPIPE, SIG_DFL)

        # suppress warning level log output and below in quiet mode
        if self.parser.verbosity >= 0 or self.parser.debug:
            suppress_warnings = nullcontext()
        else:
            suppress_warnings = suppress_logging(logging.WARNING)

        try:
            with suppress_warnings:
                self.options, func = self.parse_args(
                    args=self.args, namespace=self.options
                )
                exitstatus = func(self.options, self.out, self.err)
        except SystemExit as e:
            # handle argparse or other third party modules using sys.exit internally
            exitstatus = e.code
        except KeyboardInterrupt:
            self._errfile.write("keyboard interrupted- exiting")
            if self.parser.debug:
                self._errfile.write("\n")
                traceback.print_exc()
            # re-deliver SIGINT with the default handler to the whole
            # process group so the interruption is reported as a signal
            signal(SIGINT, SIG_DFL)
            os.killpg(os.getpgid(0), SIGINT)
        except Exception as e:
            # handle custom execution-related exceptions
            self.out.flush()
            self.err.flush()
            self.handle_exec_exception(e)

        if self.options is not None:
            # set terminal title on exit
            if exitstatus:
                self.out.title(f"{self.options.prog} failed")
            else:
                self.out.title(f"{self.options.prog} succeeded")

        return exitstatus


class FormattingHandler(logging.Handler):
    """Logging handler that writes records through a formatter object.

    The wrapped object supplies terminal styling (``fg()``, ``bold``,
    ``reset``) and keeps ``first_prefix``/``later_prefix`` stacks that
    decorate each line passed to its ``write()``.
    """

    def __init__(self, formatter):
        logging.Handler.__init__(self)
        # Kept as ``out``: ``formatter`` is already used by logging.Handler
        # for its own record formatter.
        self.out = formatter

    def emit(self, record):
        """Write *record* line by line behind a level/logger-name prefix.

        The level name is bolded and colored red for ERROR and above, yellow
        for WARNING, and cyan otherwise. Continuation lines get a blank
        prefix of the same width so they line up under the first one.
        """
        levelno = record.levelno
        color = (
            "red"
            if levelno >= logging.ERROR
            else "yellow"
            if levelno >= logging.WARNING
            else "cyan"
        )

        out = self.out
        levelname = record.levelname
        logger_name = record.name
        header = (
            out.fg(color),
            out.bold,
            levelname,
            out.reset,
            " ",
            logger_name,
            ": ",
        )
        padding = " " * (len(levelname) + len(logger_name)) + " : "

        out.first_prefix.extend(header)
        out.later_prefix.append(padding)
        try:
            lines = self.format(record).split("\n")
            for text in lines:
                out.write(text, wrap=True)
        except Exception:
            self.handleError(record)
        finally:
            # undo exactly what was pushed above, even if writing failed
            out.later_prefix.pop()
            for _ in header:
                out.first_prefix.pop()