scripts/dtrx

Sat, 21 Apr 2007 13:09:58 -0400

author
brett
date
Sat, 21 Apr 2007 13:09:58 -0400
branch
trunk
changeset 22
b240777ae53e
parent 20
69c93c3e6972
child 23
039dd321a7d0
permissions
-rwxr-xr-x

[svn] Improve the way we check archive contents. If all the entries look like
they're in ., they really shouldn't count as being in the same directory;
look at the next piece of the path. If the archive only has one
non-directory item, report that more clearly. You'll be able to tell by
whether or not there's a trailing slash in the prompt.

Improve the tests for doing straight decompression, and seek to the
beginning of the archive before we start writing to the file -- otherwise,
we write 0-byte files.

Lots of new ideas in the TODO. I think I'll do another release once
recursion is interactive.

#!/usr/bin/env python
#
# dtrx -- Intelligently extract various archive types.
# Copyright (c) 2006 Brett Smith <brettcsmith@brettcsmith.org>.
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, 5th Floor, Boston, MA, 02111.

import errno
import logging
import mimetypes
import optparse
import os
import stat
import subprocess
import sys
import tempfile
import textwrap

from cStringIO import StringIO

VERSION = "4.0"
VERSION_BANNER = """dtrx version %s
Copyright (c) 2006 Brett Smith <brettcsmith@brettcsmith.org>

This program is free software; you can redistribute it and/or modify it
under the terms of the GNU General Public License as published by the
Free Software Foundation; either version 2 of the License, or (at your
option) any later version.

This program is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
Public License for more details.""" % (VERSION,)

MATCHING_DIRECTORY = 1
ONE_ENTRY = 2
BOMB = 3
EMPTY = 4
ONE_ENTRY_KNOWN = 5

EXTRACT_HERE = 1
EXTRACT_WRAP = 2
EXTRACT_RENAME = 3

mimetypes.encodings_map.setdefault('.bz2', 'bzip2')
mimetypes.types_map['.exe'] = 'application/x-msdos-program'

def run_command(command, description, stdout=None, stderr=None, stdin=None):
    """Run an external command to completion and describe any failure.

    command is the argument list given to subprocess.Popen, and description
    names the step in the error message.  stdout, stderr and stdin are
    passed through to Popen unchanged.

    Returns None when the command exits with status 0; otherwise returns a
    string naming the step, the command line and the exit status.
    """
    process = subprocess.Popen(command, stdin=stdin, stdout=stdout,
                               stderr=stderr)
    # communicate() drains and closes any pipes Popen created for us.  A
    # bare wait() can block forever once the child has filled a pipe that
    # nobody is reading.
    process.communicate()
    status = process.returncode
    if status != 0:
        return ("%s error: '%s' returned status code %s" %
                (description, ' '.join(command), status))
    return None

class FilenameChecker(object):
    """Pick an unused name, starting from a preferred one.

    The preferred name is tried first, followed by the same name with
    '.1' through '.9' appended.
    """

    def __init__(self, original_name):
        self.original_name = original_name

    def is_free(self, filename):
        """Return True when nothing exists at filename."""
        return not os.path.exists(filename)

    def check(self):
        """Return the first candidate name that is_free() accepts.

        Raises ValueError when the preferred name and all nine numbered
        alternatives are rejected.
        """
        candidates = ['%s' % (self.original_name,)]
        candidates.extend(['%s.%s' % (self.original_name, number)
                           for number in range(1, 10)])
        for candidate in candidates:
            if self.is_free(candidate):
                return candidate
        raise ValueError("all alternatives for name %s taken" %
                         (self.original_name,))
        

class DirectoryChecker(FilenameChecker):
    """FilenameChecker variant that claims a name by creating a directory."""

    def is_free(self, filename):
        """Try to create filename as a directory.

        Returns True if the directory was created and False if something
        already exists under that name.  Any other OSError propagates.
        """
        try:
            os.mkdir(filename)
        except OSError, error:
            if error.errno != errno.EEXIST:
                raise
            return False
        else:
            return True


class ExtractorError(Exception):
    """Raised when an archive cannot be opened, inspected or unpacked."""


class ProcessStreamer(object):
    """Iterate over the lines a command writes to its standard output.

    The command is started immediately.  Callers pull lines with next()
    (or a for loop) and must call stop() when finished so the process is
    reaped and its exit status checked.
    """

    def __init__(self, command, stdin, description="checking contents",
                 stderr=None):
        """Start command with stdout captured in a pipe.

        stdin and stderr are passed through to subprocess.Popen;
        description names the step in any error raised by stop().
        """
        self.process = subprocess.Popen(command, bufsize=1, stdin=stdin,
                                        stdout=subprocess.PIPE, stderr=stderr)
        self.command = ' '.join(command)
        self.description = description

    def __iter__(self):
        return self

    def next(self):
        """Return the next output line without its trailing newline."""
        line = self.process.stdout.readline()
        if line:
            return line.rstrip('\n')
        else:
            raise StopIteration

    def stop(self):
        """Discard remaining output, reap the process and check its status.

        Raises ExtractorError if the command exited with a non-zero status.
        """
        try:
            # Drain whatever is left so the child is not blocked on a full
            # stdout pipe while we wait for it.
            while self.process.stdout.readline():
                pass
            self.process.stdout.close()
            status = self.process.wait()
        finally:
            # Close the stderr pipe (if there is one) on every path,
            # including the failure path below, so it is never leaked.
            if self.process.stderr is not None:
                self.process.stderr.close()
        if status != 0:
            raise ExtractorError("%s error: '%s' returned status code %s" %
                                 (self.description, self.command, status))
    

class BaseExtractor(object):
    """Inspect and unpack a single archive file.

    Subclasses provide get_filenames(), which returns an iterator over the
    archive's entry names offering next() and stop(), and
    extract_archive(), which unpacks into the current directory.  Entries
    that look like archives themselves are collected in included_archives
    while the contents are checked.
    """
    # Maps an encoding name from mimetypes to the command that decodes it.
    decoders = {'bzip2': 'bzcat', 'gzip': 'zcat', 'compress': 'zcat'}

    # Class used to pick an unused name for the extraction target.
    name_checker = DirectoryChecker

    def __init__(self, filename, mimetype, encoding):
        """Open filename and, if it is encoded, decode it to a temp file.

        Raises ValueError for an encoding with no known decoder, and
        ExtractorError if the file cannot be opened or a preparation
        command fails.
        """
        if encoding and (encoding not in self.decoders):
            raise ValueError("unrecognized encoding %s" % (encoding,))
        self.filename = os.path.realpath(filename)
        self.mimetype = mimetype
        self.encoding = encoding
        self.included_archives = []
        try:
            self.archive = open(filename, 'r')
        except (IOError, OSError), error:
            raise ExtractorError("could not open %s: %s" %
                                 (filename, error.strerror))
        if encoding:
            self.pipe([self.decoders[encoding]], "decoding")
        self.prepare()

    def run(self, command, description="extraction", stdout=None, stderr=None,
            stdin=None):
        """Run command via run_command(); raise ExtractorError on failure."""
        error = run_command(command, description, stdout, stderr, stdin)
        if error:
            raise ExtractorError(error)

    def pipe(self, command, description, stderr=None):
        """Feed the archive through command and keep the output instead.

        The command reads the current archive on stdin and writes to a new
        temporary file, which then replaces self.archive.
        """
        output = tempfile.TemporaryFile()
        self.run(command, description, output, stderr, self.archive)
        self.archive.close()
        self.archive = output
        self.archive.flush()

    def prepare(self):
        """Hook for subclasses that must transform the archive first."""
        pass

    def check_included_archive(self, filename):
        """Remember filename if its guessed type has a known extractor."""
        if mimetypes.guess_type(filename)[0] in extractor_map:
            self.included_archives.append(filename)

    def check_first_filename(self, filenames):
        """Read the first entry and work out its leading path component.

        Returns (None, None) for an empty listing.  Otherwise returns the
        first entry name and its first path component followed by '/'.
        When the entry starts with './', the component after it is
        included as well.
        """
        try:
            first_filename = filenames.next()
        except StopIteration:
            filenames.stop()
            return (None, None)
        self.check_included_archive(first_filename)
        parts = first_filename.split('/')
        first_part = [parts[0]]
        # A bare '.' entry has no second component to look at; indexing
        # parts[1] unconditionally raised IndexError for such listings.
        if (parts[0] == '.') and (len(parts) > 1):
            first_part.append(parts[1])
        return (first_filename, '/'.join(first_part + ['']))

    def check_second_filename(self, filenames, first_part, first_filename):
        """Classify the archive as far as the second entry allows.

        Returns (ONE_ENTRY, first_filename) if there is no second entry,
        (BOMB, None) if the second entry does not start with first_part,
        and (None, first_part) if the decision is still open.
        """
        try:
            filename = filenames.next()
        except StopIteration:
            return ONE_ENTRY, first_filename
        self.check_included_archive(filename)
        if not filename.startswith(first_part):
            return BOMB, None
        return None, first_part

    def check_contents(self):
        """Classify the archive by how its entries are laid out.

        Returns a (type, info) pair:
          (EMPTY, None)                     no entries at all
          (ONE_ENTRY, first entry name)     exactly one entry
          (BOMB, None)                      entries do not share a prefix
          (MATCHING_DIRECTORY, prefix)      shared prefix equals basename()
          (ONE_ENTRY, prefix)               any other shared prefix
        """
        filenames = self.get_filenames()
        first_filename, first_part = self.check_first_filename(filenames)
        if first_filename is None:
            return (EMPTY, None)
        archive_type, type_info = self.check_second_filename(filenames,
                                                             first_part,
                                                             first_filename)
        # Keep reading even after a BOMB verdict so that every entry gets
        # checked for included archives.
        for filename in filenames:
            self.check_included_archive(filename)
            if (archive_type != BOMB) and (not filename.startswith(first_part)):
                archive_type = BOMB
                type_info = None
        filenames.stop()
        if archive_type is None:
            if self.basename() == first_part[:-1]:
                archive_type = MATCHING_DIRECTORY
            else:
                archive_type = ONE_ENTRY
        return archive_type, type_info

    def basename(self):
        """Return the archive's file name without its known extensions.

        An encoding extension known to mimetypes is dropped first, then
        one extension found in mimetypes' type or suffix maps.
        """
        pieces = os.path.basename(self.filename).split('.')
        extension = '.' + pieces[-1]
        if extension in mimetypes.encodings_map:
            pieces.pop()
            extension = '.' + pieces[-1]
        if ((extension in mimetypes.types_map) or
            (extension in mimetypes.common_types) or
            (extension in mimetypes.suffix_map)):
            pieces.pop()
        return '.'.join(pieces)

    def extract(self, path):
        """Unpack the archive inside the directory path.

        The working directory is restored afterwards, including when
        extraction raises, so later steps see the paths they expect.
        """
        old_path = os.path.realpath(os.curdir)
        os.chdir(path)
        try:
            self.archive.seek(0, 0)
            self.extract_archive()
        finally:
            os.chdir(old_path)
    

class TarExtractor(BaseExtractor):
    """Extractor that feeds the archive to tar on standard input."""

    def get_filenames(self):
        """Return a ProcessStreamer over the output of 'tar -t'."""
        self.archive.seek(0, 0)
        listing = ProcessStreamer(['tar', '-t'], self.archive)
        return listing

    def extract_archive(self):
        """Unpack into the current directory with 'tar -x'."""
        command = ['tar', '-x']
        self.run(command, stdin=self.archive)
        
        
class ZipExtractor(BaseExtractor):
    """Extractor that hands the archive's path to zipinfo and unzip."""

    def __init__(self, filename, mimetype, encoding):
        """Record the archive details without opening the file.

        BaseExtractor.__init__ is not called.  self.archive is an empty
        StringIO so that the seek() calls made on it still work.
        """
        self.archive = StringIO()
        self.included_archives = []
        self.encoding = encoding
        self.mimetype = mimetype
        self.filename = os.path.realpath(filename)

    def get_filenames(self):
        """Return a ProcessStreamer over the output of 'zipinfo -1'."""
        self.archive.seek(0, 0)
        command = ['zipinfo', '-1', self.filename]
        return ProcessStreamer(command, None)

    def extract_archive(self):
        """Unpack into the current directory with 'unzip -q'."""
        command = ['unzip', '-q', self.filename]
        self.run(command)


class CpioExtractor(BaseExtractor):
    """Extractor that feeds the archive to cpio on standard input.

    Both commands run with stderr attached to a pipe, presumably to keep
    cpio's messages off the terminal.
    """

    def get_filenames(self):
        """Return a ProcessStreamer over the output of 'cpio -t'."""
        self.archive.seek(0, 0)
        listing = ProcessStreamer(['cpio', '-t'], self.archive,
                                  stderr=subprocess.PIPE)
        return listing

    def extract_archive(self):
        """Unpack into the current directory with 'cpio -i'."""
        command = ['cpio', '-i', '--make-directories',
                   '--no-absolute-filenames']
        self.run(command, stderr=subprocess.PIPE, stdin=self.archive)


class RPMExtractor(CpioExtractor):
    """Extractor for RPM packages, converted to cpio with rpm2cpio."""

    def prepare(self):
        """Replace the archive with the output of 'rpm2cpio -'."""
        self.pipe(['rpm2cpio', '-'], "rpm2cpio")

    def basename(self):
        """Return the package file name without its trailing pieces.

        The '.rpm' extension is dropped, and so is the dot-separated piece
        before it when that piece is shorter than eight characters
        (presumably the architecture tag).  Names that do not end in
        '.rpm' fall back to BaseExtractor.basename().
        """
        pieces = os.path.basename(self.filename).split('.')
        if len(pieces) == 1:
            return pieces[0]
        if pieces[-1] != 'rpm':
            return BaseExtractor.basename(self)
        del pieces[-1]
        if (len(pieces) > 1) and (len(pieces[-1]) < 8):
            del pieces[-1]
        return '.'.join(pieces)

    def check_contents(self):
        """Always report BOMB.

        The inherited check still runs first, which fills in
        included_archives as a side effect.
        """
        CpioExtractor.check_contents(self)
        return (BOMB, None)


class DebExtractor(TarExtractor):
    """Extractor for Debian packages, unpacking their data.tar.gz member."""

    def prepare(self):
        """Replace the archive with the decompressed data.tar.gz member."""
        member_command = ['ar', 'p', self.filename, 'data.tar.gz']
        self.pipe(member_command, "data.tar.gz extraction")
        # pipe() leaves the new temporary file positioned at its end, so
        # rewind before handing it to zcat.
        self.archive.seek(0, 0)
        self.pipe(['zcat'], "data.tar.gz decompression")

    def basename(self):
        """Return the package file name without its last '_' piece.

        The last piece is only dropped when it ends in '.deb' and is at
        most ten characters long; otherwise BaseExtractor.basename() is
        used.  A name without any underscore is returned unchanged.
        """
        name_parts = os.path.basename(self.filename).split('_')
        if len(name_parts) == 1:
            return name_parts[0]
        tail = name_parts[-1]
        if not (tail.endswith('.deb') and (len(tail) <= 10)):
            return BaseExtractor.basename(self)
        return '_'.join(name_parts[:-1])

    def check_contents(self):
        """Always report BOMB.

        The inherited check still runs first, which fills in
        included_archives as a side effect.
        """
        TarExtractor.check_contents(self)
        return (BOMB, None)
        

class CompressionExtractor(BaseExtractor):
    """Extractor for a single compressed file that is not an archive.

    BaseExtractor.__init__ has already decoded the data into self.archive,
    so extracting only means copying it to the output file.
    """
    # The result is one file, so a plain existence check picks the name.
    name_checker = FilenameChecker

    def basename(self):
        """Return the file name without a recognized encoding extension."""
        pieces = os.path.basename(self.filename).split('.')
        extension = '.' + pieces[-1]
        if extension in mimetypes.encodings_map:
            pieces.pop()
        return '.'.join(pieces)

    def get_filenames(self):
        """Yield the single output name."""
        yield self.basename()

    def check_contents(self):
        """Report one entry whose name is known without listing anything."""
        return (ONE_ENTRY_KNOWN, self.basename())

    def extract(self, path):
        """Write the decoded data to the file at path.

        Raises ExtractorError if the copy command fails.  The output file
        is closed on that path too.
        """
        output = open(path, 'w')
        try:
            # Decoding left the temporary file positioned at its end;
            # without the rewind the output would be a 0-byte file.
            self.archive.seek(0, 0)
            self.run(['cat'], "output write", stdin=self.archive,
                     stdout=output)
        finally:
            output.close()
        

class BaseHandler(object):
    """Decide where an archive is extracted and tidy up afterwards.

    Subclasses set self.target to the path handed to the extractor.
    """

    def __init__(self, extractor, contents, content_name, options):
        self.logger = logging.getLogger('dtrx-log')
        self.extractor = extractor
        self.contents = contents
        self.content_name = content_name
        self.options = options
        self.target = None

    def extract(self):
        """Extract into self.target.

        Returns None on success, or the error message as a string when the
        extractor raises ExtractorError, IOError or OSError.
        """
        try:
            self.extractor.extract(self.target)
        except (ExtractorError, IOError, OSError), error:
            return str(error)
        return None

    def cleanup(self):
        """Give the user access to everything under self.target.

        Directories get u+rwx through find, then everything gets u+rw
        through chmod -R.  Returns None on success or when there is no
        target, otherwise a message naming the command that failed.
        """
        if self.target is None:
            return
        steps = (
            ('find', ['find', self.target, '-type', 'd',
                      '-exec', 'chmod', 'u+rwx', '{}', ';']),
            ('chmod', ['chmod', '-R', 'u+rw', self.target]),
        )
        # Stop at the first failing step; later steps are not attempted.
        for command, arguments in steps:
            status = subprocess.call(arguments)
            if status != 0:
                return "%s returned with exit status %s" % (command, status)


# The "where to extract" table, with options and archive types.
# This dictates the contents of each can_handle method.
#
#         Flat           Overwrite            None
# File    basename       basename             FilenameChecked
# Match   .              .                    tempdir + checked
# Bomb    .              basename             DirectoryChecked

class FlatHandler(BaseHandler):
    """Extract straight into the current directory."""

    def can_handle(contents, options):
        """Accept flat extraction of anything but a known single file, and
        overwrite extraction of a matching directory."""
        return ((options.flat and (contents != ONE_ENTRY_KNOWN)) or
                (options.overwrite and (contents == MATCHING_DIRECTORY)))
    can_handle = staticmethod(can_handle)

    def __init__(self, extractor, contents, content_name, options):
        BaseHandler.__init__(self, extractor, contents, content_name, options)
        self.target = '.'

    def cleanup(self):
        """Give the user read/write access to each extracted entry.

        Directories also get the user execute bit.  Only entries named in
        the archive listing are touched, since the target is the current
        directory.
        """
        filenames = self.extractor.get_filenames()
        try:
            for filename in filenames:
                stat_info = os.stat(filename)
                perms = stat.S_IRUSR | stat.S_IWUSR
                if stat.S_ISDIR(stat_info.st_mode):
                    perms |= stat.S_IXUSR
                os.chmod(filename, stat_info.st_mode | perms)
        finally:
            # A listing backed by a child process must be stopped so its
            # pipe is closed and the process reaped.  Plain generators
            # have no stop() and need nothing.
            stop = getattr(filenames, 'stop', None)
            if stop is not None:
                stop()


class OverwriteHandler(BaseHandler):
    """Extract to the archive's basename without checking for a free name."""

    @staticmethod
    def can_handle(contents, options):
        """Accept flat extraction of a known single file, and overwrite
        extraction of anything but a matching directory."""
        flat_single_file = options.flat and (contents == ONE_ENTRY_KNOWN)
        return (flat_single_file or
                (options.overwrite and (contents != MATCHING_DIRECTORY)))

    def __init__(self, extractor, contents, content_name, options):
        BaseHandler.__init__(self, extractor, contents, content_name, options)
        target_name = self.extractor.basename()
        self.target = target_name
        

class MatchHandler(BaseHandler):
    """Extract into a temporary directory, then move the one top-level
    entry to an unused name in the current directory."""

    def can_handle(contents, options):
        """Accept matching directories, and single entries the user chose
        to extract here or to rename."""
        return ((contents == MATCHING_DIRECTORY) or
                ((contents == ONE_ENTRY) and
                 (options.onedir_policy in (EXTRACT_RENAME, EXTRACT_HERE))))
    can_handle = staticmethod(can_handle)

    def extract(self):
        """Extract and move the entry into place.

        Returns None on success or the extraction error message.  After a
        failed extraction self.target still names the temporary directory.
        """
        if self.contents == MATCHING_DIRECTORY:
            basename = destination = self.extractor.basename()
        elif self.options.onedir_policy == EXTRACT_HERE:
            basename = destination = self.content_name.rstrip('/')
        else:
            basename = self.content_name.rstrip('/')
            destination = self.extractor.basename()
        self.target = tempdir = tempfile.mkdtemp(dir='.')
        result = BaseHandler.extract(self)
        if result is None:
            source = os.path.join(tempdir, basename)
            if os.path.isdir(source) and not os.path.islink(source):
                checker = self.extractor.name_checker(destination)
            else:
                # DirectoryChecker reserves its name by creating a
                # directory, and renaming a non-directory onto a directory
                # fails.  A lone file or symlink only needs a free name.
                checker = FilenameChecker(destination)
            self.target = checker.check()
            os.rename(source, self.target)
            os.rmdir(tempdir)
        return result


class EmptyHandler(object):
    """Handler for archives with no entries; every step does nothing."""

    @staticmethod
    def can_handle(contents, options):
        """Accept only archives classified as EMPTY."""
        return contents == EMPTY

    def __init__(self, extractor, contents, content_name, options):
        pass

    def extract(self):
        pass

    def cleanup(self):
        pass


class BombHandler(BaseHandler):
    """Fallback handler: extract into an unused name based on the
    archive's basename."""

    @staticmethod
    def can_handle(contents, options):
        """Accept everything."""
        return True

    def __init__(self, extractor, contents, content_name, options):
        BaseHandler.__init__(self, extractor, contents, content_name, options)
        preferred_name = self.extractor.basename()
        self.target = self.extractor.name_checker(preferred_name).check()

        
# Maps a MIME type guessed by mimetypes to the extractor class for it.
extractor_map = {
    'application/x-tar': TarExtractor,
    'application/zip': ZipExtractor,
    'application/x-msdos-program': ZipExtractor,
    'application/x-debian-package': DebExtractor,
    'application/x-redhat-package-manager': RPMExtractor,
    'application/x-rpm': RPMExtractor,
    'application/x-cpio': CpioExtractor,
}

# Tried in order, and the first handler whose can_handle() accepts is
# used.  BombHandler accepts everything, so it has to stay last.
handlers = [
    FlatHandler,
    OverwriteHandler,
    MatchHandler,
    EmptyHandler,
    BombHandler,
]

class ExtractorApplication(object):
    policy_answers = {'h': EXTRACT_HERE, 'i': EXTRACT_WRAP,
                      'r': EXTRACT_RENAME, '': EXTRACT_WRAP}

    def __init__(self, arguments):
        self.parse_options(arguments)
        self.setup_logger()
        self.successes = []
        self.failures = []

    def parse_options(self, arguments):
        parser = optparse.OptionParser(
            usage="%prog [options] archive [archive2 ...]",
            description="Intelligent archive extractor",
            version=VERSION_BANNER
            )
        parser.add_option('-r', '--recursive', dest='recursive',
                          action='store_true', default=False,
                          help='extract archives contained in the ones listed')
        parser.add_option('-q', '--quiet', dest='quiet',
                          action='count', default=3,
                          help='suppress warning/error messages')
        parser.add_option('-v', '--verbose', dest='verbose',
                          action='count', default=0,
                          help='be verbose/print debugging information')
        parser.add_option('-o', '--overwrite', dest='overwrite',
                          action='store_true', default=False,
                          help='overwrite any existing target directory')
        parser.add_option('-f', '--flat', '--no-directory', dest='flat',
                          action='store_true', default=False,
                          help="don't put contents in their own directory")
        parser.add_option('-l', '-t', '--list', '--table', dest='show_list',
                          action='store_true', default=False,
                          help="list contents of archives on standard output")
        parser.add_option('-n', '--noninteractive', dest='batch',
                          action='store_true', default=False,
                          help="don't ask how to handle special cases")
        self.options, filenames = parser.parse_args(arguments)
        self.options.onedir_policy = self.policy_answers['']
        if not filenames:
            parser.error("you did not list any archives")
        self.archives = {os.path.realpath(os.curdir): filenames}

    def setup_logger(self):
        self.logger = logging.getLogger('dtrx-log')
        handler = logging.StreamHandler()
        # WARNING is the default.
        handler.setLevel(10 * (self.options.quiet - self.options.verbose))
        formatter = logging.Formatter("dtrx: %(levelname)s: %(message)s")
        handler.setFormatter(formatter)
        self.logger.addHandler(handler)

    def ask_question(self, question, answers):
        if self.options.batch:
            return answers['']
        last_line = question.pop()
        while True:
            print "\n".join(question)
            try:
                answer = raw_input(last_line)
            except EOFError:
                return answers['']
            try:
                return answers[answer.lower()]
            except KeyError:
                print

    def get_extractor(self):
        mimetype, encoding = mimetypes.guess_type(self.current_filename)
        try:
            extractor = extractor_map[mimetype]
        except KeyError:
            if encoding:
                extractor = CompressionExtractor
            else:
                return "not a known archive type"
        try:
            self.current_extractor = extractor(self.current_filename, mimetype,
                                               encoding)
        except ExtractorError, error:
            return str(error)

    def get_handler(self):
        try:
            content, content_name = self.current_extractor.check_contents()
            if content == ONE_ENTRY:
                question = textwrap.wrap("%s contains one entry: %s." %
                                         (self.current_filename, content_name))
                question.extend(["You can:",
                                 " * extract it Inside another directory",
                                 " * extract it and Rename the directory",
                                 " * extract it Here",
                                 "What do you want to do?  (I/r/h) "])
                self.options.onedir_policy = \
                    self.ask_question(question, self.policy_answers)
            for handler in handlers:
                if handler.can_handle(content, self.options):
                    self.current_handler = handler(self.current_extractor,
                                                   content, content_name,
                                                   self.options)
                    break
        except ExtractorError, error:
            return str(error)

    def recurse(self):
        if not self.options.recursive:
            return
        for filename in self.current_extractor.included_archives:
            tail_path, basename = os.path.split(filename)
            directory = os.path.join(self.current_directory,
                                     self.current_handler.target, tail_path)
            self.archives.setdefault(directory, []).append(basename)

    def report(self, function, *args):
        try:
            error = function(*args)
        except (ExtractorError, IOError, OSError), exception:
            error = str(exception)
        if error:
            self.logger.error("%s: %s", self.current_filename, error)
            return False
        return True

    def record_status(self, success):
        if success:
            self.successes.append(self.current_filename)
        else:
            self.failures.append(self.current_filename)

    def extract(self):
        while self.archives:
            self.current_directory, filenames = self.archives.popitem()
            for filename in filenames:
                os.chdir(self.current_directory)
                self.current_filename = filename
                success = (self.report(self.get_extractor) and
                           self.report(self.get_handler))
                if success:
                    for name in 'extract', 'cleanup':
                        success = (self.report(getattr(self.current_handler,
                                                       name)) and success)
                    self.recurse()
                self.record_status(success)

    def show_contents(self):
        for filename in self.current_extractor.get_filenames():
            print filename

    def show_list(self):
        filenames = self.archives.values()[0]
        if len(filenames) > 1:
            header = "%s:\n"
        else:
            header = None
        for filename in filenames:
            if header:
                print header % (filename,),
                header = "\n%s:\n"
            self.current_filename = filename
            success = (self.report(self.get_extractor) and
                       self.report(self.show_contents))
            self.record_status(success)

    def run(self):
        if self.options.show_list:
            self.show_list()
        else:
            self.extract()
        if self.failures:
            return 1
        return 0


if __name__ == '__main__':
    # Run on the command-line arguments and exit with the status that
    # run() returns.
    exit_status = ExtractorApplication(sys.argv[1:]).run()
    sys.exit(exit_status)

mercurial