scripts/dtrx

author:      Brett Smith <brettcsmith@brettcsmith.org>
date:        Thu, 11 Dec 2008 21:19:26 -0500
branch:      trunk
changeset:   100:7353b443dc98
parent:      99:1ae3722ca219
child:       101:014efef1a48f
permissions: -rwxr-xr-x

Fix crasher bug when extracting empty archives.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# dtrx -- Intelligently extract various archive types.
# Copyright © 2006, 2007, 2008 Brett Smith <brettcsmith@brettcsmith.org>
# Copyright © 2008 Peter Kelemen <Peter.Kelemen@gmail.com>
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, see <http://www.gnu.org/licenses/>.

# Python 2.3 string methods: 'rfind', 'rindex', 'rjust', 'rstrip'

import errno
import logging
import mimetypes
import optparse
import os
import re
import shutil
import signal
import stat
import subprocess
import sys
import tempfile
import textwrap
import traceback

try:
    set
except NameError:
    from sets import Set as set

VERSION = "6.3"
VERSION_BANNER = """dtrx version %s
Copyright © 2006, 2007, 2008 Brett Smith <brettcsmith@brettcsmith.org>
Copyright © 2008 Peter Kelemen <Peter.Kelemen@gmail.com>

This program is free software; you can redistribute it and/or modify it
under the terms of the GNU General Public License as published by the
Free Software Foundation; either version 3 of the License, or (at your
option) any later version.

This program is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
Public License for more details.""" % (VERSION,)

MATCHING_DIRECTORY = 1
ONE_ENTRY_KNOWN = 2
BOMB = 3
EMPTY = 4
ONE_ENTRY_FILE = 'file'
ONE_ENTRY_DIRECTORY = 'directory'

ONE_ENTRY_UNKNOWN = [ONE_ENTRY_FILE, ONE_ENTRY_DIRECTORY]

EXTRACT_HERE = 1
EXTRACT_WRAP = 2
EXTRACT_RENAME = 3

RECURSE_ALWAYS = 1
RECURSE_ONCE = 2
RECURSE_NOT_NOW = 3
RECURSE_NEVER = 4
RECURSE_LIST = 5

mimetypes.encodings_map.setdefault('.bz2', 'bzip2')
mimetypes.encodings_map.setdefault('.lzma', 'lzma')
mimetypes.types_map.setdefault('.gem', 'application/x-ruby-gem')

logger = logging.getLogger('dtrx-log')

class FilenameChecker(object):
    free_func = os.open
    free_args = (os.O_CREAT | os.O_EXCL,)
    free_close = os.close

    def __init__(self, original_name):
        self.original_name = original_name

    def is_free(self, filename):
        try:
            result = self.free_func(filename, *self.free_args)
        except OSError, error:
            if error.errno == errno.EEXIST:
                return False
            raise
        if self.free_close:
            self.free_close(result)
        return True

    def create(self):
        fd, filename = tempfile.mkstemp(prefix=self.original_name + '.',
                                        dir='.')
        os.close(fd)
        return filename

    def check(self):
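        # Try the original name first, then name.1 through name.9; if all of
        # those are taken, fall back to a mkstemp/mkdtemp-generated name.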
        for suffix in [''] + ['.%s' % (x,) for x in range(1, 10)]:
            filename = '%s%s' % (self.original_name, suffix)
            if self.is_free(filename):
                return filename
        return self.create()


class DirectoryChecker(FilenameChecker):
    free_func = os.mkdir
    free_args = ()
    free_close = None

    def create(self):
        return tempfile.mkdtemp(prefix=self.original_name + '.', dir='.')


class ExtractorError(Exception):
    pass


class ExtractorUnusable(Exception):
    pass


EXTRACTION_ERRORS = (ExtractorError, ExtractorUnusable, OSError, IOError)

class BaseExtractor(object):
    decoders = {'bzip2': 'bzcat', 'gzip': 'zcat', 'compress': 'zcat',
                'lzma': 'lzcat'}
    name_checker = DirectoryChecker

    def __init__(self, filename, encoding):
        if encoding and (not self.decoders.has_key(encoding)):
            raise ValueError("unrecognized encoding %s" % (encoding,))
        self.filename = os.path.realpath(filename)
        self.encoding = encoding
        self.file_count = 0
        self.included_archives = []
        self.target = None
        self.content_type = None
        self.content_name = None
        self.pipes = []
        self.stderr = tempfile.TemporaryFile()
        self.exit_codes = []
        try:
            self.archive = open(filename, 'r')
        except (IOError, OSError), error:
            raise ExtractorError("could not open %s: %s" %
                                 (filename, error.strerror))
        if encoding:
            self.pipe([self.decoders[encoding]], "decoding")
        self.prepare()

    def pipe(self, command, description="extraction"):
        self.pipes.append((command, description))

    def first_bad_exit_code(self):
        for index, code in enumerate(self.exit_codes):
            if code != 0:
                return index
        return None

    def run_pipes(self, final_stdout=None):
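        # Wire the queued commands up like a shell pipeline: self.archive
        # feeds the first process, each stdout feeds the next stdin, and the
        # last process writes to final_stdout (a temporary file unless the
        # caller provides one).  Exit codes are kept for check_success().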
        if not self.pipes:
            return
        elif final_stdout is None:
            # FIXME: Buffering this might be dumb.
            final_stdout = tempfile.TemporaryFile()
        num_pipes = len(self.pipes)
        last_pipe = num_pipes - 1
        processes = []
        for index, command in enumerate([pipe[0] for pipe in self.pipes]):
            if index == 0:
                stdin = self.archive
            else:
                stdin = processes[-1].stdout
            if index == last_pipe:
                stdout = final_stdout
            else:
                stdout = subprocess.PIPE
            try:
                processes.append(subprocess.Popen(command, stdin=stdin,
                                                  stdout=stdout,
                                                  stderr=self.stderr))
            except OSError, error:
                if error.errno == errno.ENOENT:
                    raise ExtractorUnusable("could not run %s" % (command[0],))
                raise
        self.exit_codes = [pipe.wait() for pipe in processes]
        self.archive.close()
        for index in range(last_pipe):
            processes[index].stdout.close()
        self.archive = final_stdout

    def prepare(self):
        pass

    def check_included_archives(self):
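        # Walk the extracted tree, counting files and noting anything that
        # looks like another archive (by mimetype or extension) so the
        # recursion policy can offer to extract it too.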
        if (self.content_name is None) or (not self.content_name.endswith('/')):
            self.included_root = './'
        else:
            self.included_root = self.content_name
        start_index = len(self.included_root)
        for path, dirname, filenames in os.walk(self.included_root):
            self.file_count += len(filenames)
            path = path[start_index:]
            for filename in filenames:
                if (ExtractorBuilder.try_by_mimetype(filename) or
                    ExtractorBuilder.try_by_extension(filename)):
                    self.included_archives.append(os.path.join(path, filename))

    def check_contents(self):
        if not self.contents:
            self.content_type = EMPTY
        elif len(self.contents) == 1:
            if self.basename() == self.contents[0]:
                self.content_type = MATCHING_DIRECTORY
            elif os.path.isdir(self.contents[0]):
                self.content_type = ONE_ENTRY_DIRECTORY
            else:
                self.content_type = ONE_ENTRY_FILE
            self.content_name = self.contents[0]
            if os.path.isdir(self.contents[0]):
                self.content_name += '/'
        else:
            self.content_type = BOMB
        self.check_included_archives()

    def basename(self):
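        # Guess a name for the extracted content by stripping a recognized
        # encoding suffix (e.g. '.gz') and then a recognized archive suffix
        # (e.g. '.tar') from the filename.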
        pieces = os.path.basename(self.filename).split('.')
        extension = '.' + pieces[-1]
        if mimetypes.encodings_map.has_key(extension):
            pieces.pop()
            extension = '.' + pieces[-1]
        if (mimetypes.types_map.has_key(extension) or
            mimetypes.common_types.has_key(extension) or
            mimetypes.suffix_map.has_key(extension)):
            pieces.pop()
        return '.'.join(pieces)

    def get_stderr(self):
        self.stderr.seek(0, 0)
        errors = self.stderr.read(-1)
        self.stderr.close()
        return errors

    def check_success(self, got_output):
        error_index = self.first_bad_exit_code()
        if (not got_output) and (error_index is not None):
            command = ' '.join(self.pipes[error_index][0])
            raise ExtractorError("%s error: '%s' returned status code %s" %
                                 (self.pipes[error_index][1], command,
                                  self.exit_codes[error_index]))
        
    def extract_archive(self):
        self.pipe(self.extract_pipe)
        self.run_pipes()

    def extract(self):
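        # Extract into a hidden temporary directory under the current
        # directory; a handler decides afterwards where the contents should
        # really end up.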
        try:
            self.target = tempfile.mkdtemp(prefix='.dtrx-', dir='.')
        except (OSError, IOError), error:
            raise ExtractorError("cannot extract here: %s" % (error.strerror,))
        old_path = os.path.realpath(os.curdir)
        os.chdir(self.target)
        try:
            self.archive.seek(0, 0)
            self.extract_archive()
            self.contents = os.listdir('.')
            self.check_contents()
            self.check_success(self.content_type != EMPTY)
        except EXTRACTION_ERRORS:
            self.archive.close()
            os.chdir(old_path)
            shutil.rmtree(self.target, ignore_errors=True)
            raise
        self.archive.close()
        os.chdir(old_path)

    def get_filenames(self):
        self.pipe(self.list_pipe, "listing")
        self.run_pipes()
        self.check_success(False)
        self.archive.seek(0, 0)
        while True:
            line = self.archive.readline()
            if not line:
                self.archive.close()
                return
            yield line.rstrip('\n')
    

class CompressionExtractor(BaseExtractor):
    file_type = 'compressed file'
    name_checker = FilenameChecker

    def basename(self):
        pieces = os.path.basename(self.filename).split('.')
        extension = '.' + pieces[-1]
        if mimetypes.encodings_map.has_key(extension):
            pieces.pop()
        return '.'.join(pieces)

    def get_filenames(self):
        # This code used to just immediately yield the basename, under the
        # assumption that that would be the filename.  However, if that
        # happens, dtrx -l will report this as a valid result for files with
        # compression extensions, even if those files shouldn't actually be
        # handled this way.  So, we call out to the file command to do a quick
        # check and make sure this actually looks like a compressed file.
        if 'compress' not in [match[0] for match in
                              ExtractorBuilder.try_by_magic(self.filename)]:
            raise ExtractorError("doesn't look like a compressed file")
        yield self.basename()

    def extract(self):
        self.content_type = ONE_ENTRY_KNOWN
        self.content_name = self.basename()
        self.contents = None
        self.included_root = './'
        try:
            output_fd, self.target = tempfile.mkstemp(prefix='.dtrx-', dir='.')
        except (OSError, IOError), error:
            raise ExtractorError("cannot extract here: %s" % (error.strerror,))
        self.run_pipes(output_fd)
        os.close(output_fd)
        try:
            self.check_success(os.stat(self.target)[stat.ST_SIZE] > 0)
        except EXTRACTION_ERRORS:
            os.unlink(self.target)
            raise
            
class TarExtractor(BaseExtractor):
    file_type = 'tar file'
    extract_pipe = ['tar', '-x']
    list_pipe = ['tar', '-t']
        
        
class CpioExtractor(BaseExtractor):
    file_type = 'cpio file'
    extract_pipe = ['cpio', '-i', '--make-directories', '--quiet',
                   '--no-absolute-filenames']
    list_pipe = ['cpio', '-t', '--quiet']


class RPMExtractor(CpioExtractor):
    file_type = 'RPM'

    def prepare(self):
        self.pipe(['rpm2cpio', '-'], "rpm2cpio")

    def basename(self):
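        # Strip the '.rpm' suffix and then any short architecture component,
        # so e.g. foo-1.0-1.i386.rpm becomes foo-1.0-1.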
        pieces = os.path.basename(self.filename).split('.')
        if len(pieces) == 1:
            return pieces[0]
        elif pieces[-1] != 'rpm':
            return BaseExtractor.basename(self)
        pieces.pop()
        if len(pieces) == 1:
            return pieces[0]
        elif len(pieces[-1]) < 8:
            pieces.pop()
        return '.'.join(pieces)

    def check_contents(self):
        self.check_included_archives()
        self.content_type = BOMB


class DebExtractor(TarExtractor):
    file_type = 'Debian package'

    def prepare(self):
        self.pipe(['ar', 'p', self.filename, 'data.tar.gz'],
                  "data.tar.gz extraction")
        self.pipe(['zcat'], "data.tar.gz decompression")

    def basename(self):
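        # Strip the trailing 'architecture.deb' component, so e.g.
        # foo_1.0-1_i386.deb becomes foo_1.0-1.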
        pieces = os.path.basename(self.filename).split('_')
        if len(pieces) == 1:
            return pieces[0]
        last_piece = pieces.pop()
        if (len(last_piece) > 10) or (not last_piece.endswith('.deb')):
            return BaseExtractor.basename(self)
        return '_'.join(pieces)

    def check_contents(self):
        self.check_included_archives()
        self.content_type = BOMB


class DebMetadataExtractor(DebExtractor):
    def prepare(self):
        self.pipe(['ar', 'p', self.filename, 'control.tar.gz'],
                  "control.tar.gz extraction")
        self.pipe(['zcat'], "control.tar.gz decompression")


class GemExtractor(TarExtractor):
    file_type = 'Ruby gem'

    def prepare(self):
        self.pipe(['tar', '-xO', 'data.tar.gz'], "data.tar.gz extraction")
        self.pipe(['zcat'], "data.tar.gz decompression")

    def check_contents(self):
        self.check_included_archives()
        self.content_type = BOMB


class GemMetadataExtractor(CompressionExtractor):
    file_type = 'Ruby gem'

    def prepare(self):
        self.pipe(['tar', '-xO', 'metadata.gz'], "metadata.gz extraction")
        self.pipe(['zcat'], "metadata.gz decompression")

    def basename(self):
        return os.path.basename(self.filename) + '-metadata.txt'


class NoPipeExtractor(BaseExtractor):
    # Some extraction tools won't accept the archive from stdin.  With
    # these, the piping infrastructure we normally set up generally doesn't
    # work, at least at first.  We can still use most of it; we just don't
    # want to seed self.archive with the archive file, since that sucks up
    # memory.  So instead we seed it with /dev/null, and specify the
    # filename on the command line as necessary.  We also open the actual
    # file with os.open, to make sure we can actually do it (permissions
    # are good, etc.).  This class doesn't do anything by itself; it's just
    # meant to be a base class for extractors that rely on these dumb
    # tools.
    def __init__(self, filename, encoding):
        os.close(os.open(filename, os.O_RDONLY))
        BaseExtractor.__init__(self, '/dev/null', None)
        self.filename = os.path.realpath(filename)

    def extract_archive(self):
        self.extract_pipe = self.extract_command + [self.filename]
        BaseExtractor.extract_archive(self)

    def get_filenames(self):
        self.list_pipe = self.list_command + [self.filename]
        return BaseExtractor.get_filenames(self)


class ZipExtractor(NoPipeExtractor):
    file_type = 'Zip file'
    extract_command = ['unzip', '-q']
    list_command = ['zipinfo', '-1']


class SevenExtractor(NoPipeExtractor):
    file_type = '7z file'
    extract_command = ['7z', 'x']
    list_command = ['7z', 'l']
    border_re = re.compile('^[- ]+$')

    def get_filenames(self):
        fn_index = None
        for line in NoPipeExtractor.get_filenames(self):
            if self.border_re.match(line):
                if fn_index is not None:
                    break
                else:
                    fn_index = line.rindex(' ') + 1
            elif fn_index is not None:
                yield line[fn_index:]
        self.archive.close()
        

class CABExtractor(NoPipeExtractor):
    file_type = 'CAB archive'
    extract_command = ['cabextract', '-q']
    list_command = ['cabextract', '-l']
    border_re = re.compile(r'^[-\+]+$')

    def get_filenames(self):
        fn_index = None
        filenames = NoPipeExtractor.get_filenames(self)
        for line in filenames:
            if self.border_re.match(line):
                break
        for line in filenames:
            try:
                yield line.split(' | ', 2)[2]
            except IndexError:
                break
        self.archive.close()


class ShieldExtractor(NoPipeExtractor):
    file_type = 'InstallShield archive'
    extract_command = ['unshield', 'x']
    list_command = ['unshield', 'l']
    prefix_re = re.compile(r'^\s+\d+\s+')
    end_re = re.compile(r'^\s+-+\s+-+\s*$')

    def get_filenames(self):
        for line in NoPipeExtractor.get_filenames(self):
            if self.end_re.match(line):
                break
            else:
                match = self.prefix_re.match(line)
                if match:
                    yield line[match.end():]
        self.archive.close()

    def basename(self):
        result = NoPipeExtractor.basename(self)
        if result.endswith('.hdr'):
            result = result[:-4]
        return result


class RarExtractor(NoPipeExtractor):
    file_type = 'RAR archive'
    extract_command = ['unrar', 'x']
    list_command = ['unrar', 'l']
    border_re = re.compile('^-+$')

    def get_filenames(self):
        inside = False
        for line in NoPipeExtractor.get_filenames(self):
            if self.border_re.match(line):
                if inside:
                    break
                else:
                    inside = True
            elif inside:
                yield line.split(' ')[1]
        self.archive.close()


class BaseHandler(object):
    def __init__(self, extractor, options):
        self.extractor = extractor
        self.options = options
        self.target = None

    def handle(self):
        command = 'find'
        status = subprocess.call(['find', self.extractor.target, '-type', 'd',
                                  '-exec', 'chmod', 'u+rwx', '{}', ';'])
        if status == 0:
            command = 'chmod'
            status = subprocess.call(['chmod', '-R', 'u+rwX',
                                      self.extractor.target])
        if status != 0:
            return "%s returned with exit status %s" % (command, status)
        return self.organize()

    def set_target(self, target, checker):
        self.target = checker(target).check()
        if self.target != target:
            logger.warning("extracting %s to %s" %
                           (self.extractor.filename, self.target))


# The "where to extract" table, with options and archive types.
# This dictates the contents of each can_handle method.
#
#         Flat           Overwrite            None
# File    basename       basename             FilenameChecked
# Match   .              .                    tempdir + checked
# Bomb    .              basename             DirectoryChecked

class FlatHandler(BaseHandler):
    def can_handle(contents, options):
        return ((options.flat and (contents != ONE_ENTRY_KNOWN)) or
                (options.overwrite and (contents == MATCHING_DIRECTORY)))
    can_handle = staticmethod(can_handle)

    def organize(self):
        self.target = '.'
        for curdir, dirs, filenames in os.walk(self.extractor.target,
                                               topdown=False):
            path_parts = curdir.split(os.sep)
            if path_parts[0] == '.':
                del path_parts[1]
            else:
                del path_parts[0]
            newdir = os.path.join(*path_parts)
            if not os.path.isdir(newdir):
                os.makedirs(newdir)
            for filename in filenames:
                os.rename(os.path.join(curdir, filename),
                          os.path.join(newdir, filename))
            os.rmdir(curdir)


class OverwriteHandler(BaseHandler):
    def can_handle(contents, options):
        return ((options.flat and (contents == ONE_ENTRY_KNOWN)) or
                (options.overwrite and (contents != MATCHING_DIRECTORY)))
    can_handle = staticmethod(can_handle)

    def organize(self):
        self.target = self.extractor.basename()
        if os.path.isdir(self.target):
            shutil.rmtree(self.target)
        os.rename(self.extractor.target, self.target)
        

class MatchHandler(BaseHandler):
    def can_handle(contents, options):
        return ((contents == MATCHING_DIRECTORY) or
                ((contents in ONE_ENTRY_UNKNOWN) and
                 options.one_entry_policy.ok_for_match()))
    can_handle = staticmethod(can_handle)

    def organize(self):
        source = os.path.join(self.extractor.target,
                              os.listdir(self.extractor.target)[0])
        if os.path.isdir(source):
            checker = DirectoryChecker
        else:
            checker = FilenameChecker
        if self.options.one_entry_policy == EXTRACT_HERE:
            destination = self.extractor.content_name.rstrip('/')
        else:
            destination = self.extractor.basename()
        self.set_target(destination, checker)
        if os.path.isdir(self.extractor.target):
            os.rename(source, self.target)
            os.rmdir(self.extractor.target)
        else:
            os.rename(self.extractor.target, self.target)
        self.extractor.included_root = './'


class EmptyHandler(object):
    target = ''

    def can_handle(contents, options):
        return contents == EMPTY
    can_handle = staticmethod(can_handle)

    def __init__(self, extractor, options): pass
    def handle(self): pass


class BombHandler(BaseHandler):
    def can_handle(contents, options):
        return True
    can_handle = staticmethod(can_handle)

    def organize(self):
        basename = self.extractor.basename()
        self.set_target(basename, self.extractor.name_checker)
        os.rename(self.extractor.target, self.target)

        
class BasePolicy(object):
    try:
        width = int(os.environ['COLUMNS'])
    except (KeyError, ValueError):
        width = 80
    wrapper = textwrap.TextWrapper(width=width - 1)

    def __init__(self, options):
        self.current_policy = None
        if options.batch:
            self.permanent_policy = self.answers['']
        else:
            self.permanent_policy = None

    def wrap(self, question, filename):
        # Note: This function assumes the filename is the first thing in the
        # question text, and that's the only place it appears.
        if len(self.wrapper.wrap(filename + ' a')) > 1:
            return [filename] + self.wrapper.wrap(question[3:])
        return self.wrapper.wrap(question % (filename,))

    def ask_question(self, question):
        question = question + self.choices
        while True:
            print "\n".join(question)
            try:
                answer = raw_input(self.prompt)
            except EOFError:
                return self.answers['']
            try:
                return self.answers[answer.lower()]
            except KeyError:
                print

    def __cmp__(self, other):
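        # Policies compare equal to their current answer, so callers can
        # write tests like "options.one_entry_policy == EXTRACT_HERE".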
        return cmp(self.current_policy, other)
    

class OneEntryPolicy(BasePolicy):
    answers = {'h': EXTRACT_HERE, 'i': EXTRACT_WRAP, 'r': EXTRACT_RENAME,
               '': EXTRACT_WRAP}
    choices = ["You can:",
               " * extract it Inside another directory",
               " * extract it and Rename the directory",
               " * extract it Here"]
    prompt = "What do you want to do?  (I/r/h) "

    def __init__(self, options):
        BasePolicy.__init__(self, options)
        if options.flat:
            default = 'h'
        elif options.one_entry_default is not None:
            default = options.one_entry_default.lower()
        else:
            return
        if 'here'.startswith(default):
            self.permanent_policy = EXTRACT_HERE
        elif 'rename'.startswith(default):
            self.permanent_policy = EXTRACT_RENAME
        elif 'inside'.startswith(default):
            self.permanent_policy = EXTRACT_WRAP
        elif default is not None:
            raise ValueError("bad value %s for default policy" % (default,))

    def prep(self, archive_filename, extractor):
        question = self.wrap(("%%s contains one %s, but its name " +
                              "doesn't match.") %
                             (extractor.content_type,), archive_filename)
        question.append(" Expected: " + extractor.basename())
        question.append("   Actual: " + extractor.content_name)
        self.current_policy = (self.permanent_policy or
                               self.ask_question(question))

    def ok_for_match(self):
        return self.current_policy in (EXTRACT_RENAME, EXTRACT_HERE)


class RecursionPolicy(BasePolicy):
    answers = {'o': RECURSE_ONCE, 'a': RECURSE_ALWAYS, 'n': RECURSE_NOT_NOW,
               'v': RECURSE_NEVER, 'l': RECURSE_LIST, '': RECURSE_NOT_NOW}
    choices = ["You can:",
               " * Always extract included archives",
               " * extract included archives this Once",
               " * choose Not to extract included archives",
               " * neVer extract included archives",
               " * List included archives"]
    prompt = "What do you want to do?  (a/o/N/v/l) "

    def __init__(self, options):
        BasePolicy.__init__(self, options)
        if options.show_list:
            self.permanent_policy = RECURSE_NEVER
        elif options.recursive:
            self.permanent_policy = RECURSE_ALWAYS

    def prep(self, current_filename, target, extractor):
        archive_count = len(extractor.included_archives)
        if (self.permanent_policy is not None) or (archive_count == 0):
            self.current_policy = self.permanent_policy or RECURSE_NOT_NOW
            return
        question = self.wrap(("%%s contains %s other archive file(s), " +
                              "out of %s file(s) total.") %
                             (archive_count, extractor.file_count),
                             current_filename)
        if target == '.':
            target = ''
        included_root = extractor.included_root
        if included_root == './':
            included_root = ''
        while True:
            self.current_policy = self.ask_question(question)
            if self.current_policy != RECURSE_LIST:
                break
            print ("\n%s\n" %
                   '\n'.join([os.path.join(target, included_root, filename)
                              for filename in extractor.included_archives]))
        if self.current_policy in (RECURSE_ALWAYS, RECURSE_NEVER):
            self.permanent_policy = self.current_policy

    def ok_to_recurse(self):
        return self.current_policy in (RECURSE_ALWAYS, RECURSE_ONCE)
            

class ExtractorBuilder(object):
    extractor_map = {'tar': {'extractor': TarExtractor,
                             'mimetypes': ('x-tar',),
                             'extensions': ('tar',),
                             'magic': ('POSIX tar archive',)},
                     'zip': {'extractor': ZipExtractor,
                             'mimetypes': ('zip',),
                             'extensions': ('zip',),
                             'magic': ('(Zip|ZIP self-extracting) archive',)},
                     'rpm': {'extractor': RPMExtractor,
                             'mimetypes': ('x-redhat-package-manager', 'x-rpm'),
                             'extensions': ('rpm',),
                             'magic': ('RPM',)},
                     'deb': {'extractor': DebExtractor,
                             'metadata': DebMetadataExtractor,
                             'mimetypes': ('x-debian-package',),
                             'extensions': ('deb',),
                             'magic': ('Debian binary package',)},
                     'cpio': {'extractor': CpioExtractor,
                              'mimetypes': ('x-cpio',),
                              'extensions': ('cpio',),
                              'magic': ('cpio archive',)},
                     'gem': {'extractor': GemExtractor,
                             'metadata': GemMetadataExtractor,
                             'mimetypes': ('x-ruby-gem',),
                             'extensions': ('gem',)},
                     '7z': {'extractor': SevenExtractor,
                             'mimetypes': ('x-7z-compressed',),
                             'extensions': ('7z',),
                             'magic': ('7-zip archive',)},
                     'cab': {'extractor': CABExtractor,
                             'mimetypes': ('x-cab',),
                             'extensions': ('cab',),
                             'magic': ('Microsoft Cabinet Archive',)},
                     'rar': {'extractor': RarExtractor,
                             'mimetypes': ('rar',),
                             'extensions': ('rar',),
                             'magic': ('RAR archive',)},
                     'shield': {'extractor': ShieldExtractor,
                                'mimetypes': ('x-cab',),
                                'extensions': ('cab', 'hdr'),
                                'magic': ('InstallShield CAB',)},
                     'compress': {'extractor': CompressionExtractor}
                     }

    mimetype_map = {}
    magic_mime_map = {}
    extension_map = {}
    for ext_name, ext_info in extractor_map.items():
        for mimetype in ext_info.get('mimetypes', ()):
            if '/' not in mimetype:
                mimetype = 'application/' + mimetype
            mimetype_map[mimetype] = ext_name
        for magic_re in ext_info.get('magic', ()):
            magic_mime_map[re.compile(magic_re)] = ext_name
        for extension in ext_info.get('extensions', ()):
            extension_map.setdefault(extension, []).append((ext_name, None))

    for mapping in (('tar', 'bzip2', 'tar.bz2'),
                    ('tar', 'gzip', 'tar.gz', 'tgz'),
                    ('compress', 'gzip', 'Z', 'gz'),
                    ('compress', 'bzip2', 'bz2'),
                    ('compress', 'lzma', 'lzma')):
        for extension in mapping[2:]:
            extension_map.setdefault(extension, []).append(mapping[:2])

    magic_encoding_map = {}
    for mapping in (('bzip2', 'bzip2 compressed'),
                    ('gzip', 'gzip compressed'),
                    ('lzma', 'LZMA compressed')):
        for pattern in mapping[1:]:
            magic_encoding_map[re.compile(pattern)] = mapping[0]

    def __init__(self, filename, options):
        self.filename = filename
        self.options = options

    def build_extractor(self, archive_type, encoding):
        extractors = self.extractor_map[archive_type]
        if self.options.metadata and extractors.has_key('metadata'):
            extractor = extractors['metadata']
        else:
            extractor = extractors['extractor']
        return extractor(self.filename, encoding)

    def get_extractor(self):
        tried_types = set()
        # As smart as it is, the magic test can't go first, because at least
        # on my system it just recognizes gem files as tar files.  I guess
        # it's possible for the opposite problem to occur -- where the mimetype
        # or extension suggests something less than ideal -- but it seems less
        # likely so I'm sticking with this.
        for func_name in ('mimetype', 'extension', 'magic'):
            logger.debug("getting extractors by %s" % (func_name,))
            extractor_types = \
                            getattr(self, 'try_by_' + func_name)(self.filename)
            logger.debug("done getting extractors")
            for ext_args in extractor_types:
                if ext_args in tried_types:
                    continue
                tried_types.add(ext_args)
                logger.debug("trying %s extractor from %s" %
                             (ext_args, func_name))
                yield self.build_extractor(*ext_args)

    def try_by_mimetype(cls, filename):
        mimetype, encoding = mimetypes.guess_type(filename)
        try:
            return [(cls.mimetype_map[mimetype], encoding)]
        except KeyError:
            if encoding:
                return [('compress', encoding)]
        return []
    try_by_mimetype = classmethod(try_by_mimetype)

    def magic_map_matches(cls, output, magic_map):
        return [result for regexp, result in magic_map.items()
                if regexp.search(output)]
    magic_map_matches = classmethod(magic_map_matches)
        
    def try_by_magic(cls, filename):
        process = subprocess.Popen(['file', '-z', filename],
                                   stdout=subprocess.PIPE)
        status = process.wait()
        if status != 0:
            return []
        output = process.stdout.readline()
        process.stdout.close()
        if output.startswith('%s: ' % filename):
            output = output[len(filename) + 2:]
        mimes = cls.magic_map_matches(output, cls.magic_mime_map)
        encodings = cls.magic_map_matches(output, cls.magic_encoding_map)
        if mimes and not encodings:
            encodings = [None]
        elif encodings and not mimes:
            mimes = ['compress']
        return [(m, e) for m in mimes for e in encodings]
    try_by_magic = classmethod(try_by_magic)

    def try_by_extension(cls, filename):
        parts = filename.split('.')[-2:]
        results = []
        while parts:
            results.extend(cls.extension_map.get('.'.join(parts), []))
            del parts[0]
        return results
    try_by_extension = classmethod(try_by_extension)


class BaseAction(object):
    def __init__(self, options, filenames):
        self.options = options
        self.filenames = filenames
        self.target = None
        
    def report(self, function, *args):
        try:
            error = function(*args)
        except EXTRACTION_ERRORS, exception:
            error = str(exception)
            logger.debug(''.join(traceback.format_exception(*sys.exc_info())))
        return error


class ExtractionAction(BaseAction):
    handlers = [FlatHandler, OverwriteHandler, MatchHandler, EmptyHandler,
                BombHandler]

    def __init__(self, options, filenames):
        BaseAction.__init__(self, options, filenames)
        self.did_print = False

    def get_handler(self, extractor):
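        # Use the first handler whose can_handle() accepts this archive's
        # layout; BombHandler is last and accepts anything, so the loop
        # always finds one.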
        if extractor.content_type in ONE_ENTRY_UNKNOWN:
            self.options.one_entry_policy.prep(self.current_filename,
                                               extractor)
        for handler in self.handlers:
            if handler.can_handle(extractor.content_type, self.options):
                logger.debug("using %s handler" % (handler.__name__,))
                self.current_handler = handler(extractor, self.options)
                break

    def show_extraction(self, extractor):
        if self.options.log_level > logging.INFO:
            return
        elif self.did_print:
            print
        else:
            self.did_print = True
        print "%s:" % (self.current_filename,)
        if extractor.contents is None:
            print self.current_handler.target
            return
        def reverser(x, y):
            return cmp(y, x)
        if self.current_handler.target == '.':
            filenames = extractor.contents
            filenames.sort(reverser)
        else:
            filenames = [self.current_handler.target]
        pathjoin = os.path.join
        isdir = os.path.isdir
        while filenames:
            filename = filenames.pop()
            if isdir(filename):
                print "%s/" % (filename,)
                new_filenames = os.listdir(filename)
                new_filenames.sort(reverser)
                filenames.extend([pathjoin(filename, new_filename)
                                  for new_filename in new_filenames])
            else:
                print filename

    def run(self, filename, extractor):
        self.current_filename = filename
        error = (self.report(extractor.extract) or
                 self.report(self.get_handler, extractor) or
                 self.report(self.current_handler.handle) or
                 self.report(self.show_extraction, extractor))
        if not error:
            self.target = self.current_handler.target
        return error


class ListAction(BaseAction):
    def __init__(self, options, filenames):
        BaseAction.__init__(self, options, filenames)
        self.count = 0

    def get_list(self, extractor):
        # Note: The reason I'm getting all the filenames up front is
        # because if we run into trouble partway through the archive, we'll
        # try another extractor.  So before we display anything we have to
        # be sure this one is successful.  We maybe don't have to be quite
        # this conservative but this is the easy way out for now.
        self.filelist = list(extractor.get_filenames())

    def show_list(self, filename):
        self.count += 1
        if len(self.filenames) != 1:
            if self.count > 1:
                print
            print "%s:" % (filename,)
        print '\n'.join(self.filelist)

    def run(self, filename, extractor):
        return (self.report(self.get_list, extractor) or
                self.report(self.show_list, filename))


class ExtractorApplication(object):
    def __init__(self, arguments):
        for signal_num in (signal.SIGINT, signal.SIGTERM):
            signal.signal(signal_num, self.abort)
        signal.signal(signal.SIGPIPE, signal.SIG_DFL)
        self.parse_options(arguments)
        self.setup_logger()
        self.successes = []
        self.failures = []

    def clean_destination(self, dest_name):
        try:
            os.unlink(dest_name)
        except OSError, error:
            if error.errno == errno.EISDIR:
                shutil.rmtree(dest_name, ignore_errors=True)

    def abort(self, signal_num, frame):
        signal.signal(signal_num, signal.SIG_IGN)
        print
        logger.debug("traceback:\n" +
                     ''.join(traceback.format_stack(frame)).rstrip())
        logger.debug("got signal %s" % (signal_num,))
        try:
            basename = self.current_extractor.target
        except AttributeError:
            basename = None
        if basename is not None:
            logger.debug("cleaning up %s" % (basename,))
            clean_targets = set([os.path.realpath('.')])
            if hasattr(self, 'current_directory'):
                clean_targets.add(os.path.realpath(self.current_directory))
            for directory in clean_targets:
                self.clean_destination(os.path.join(directory, basename))
        sys.exit(1)

    def parse_options(self, arguments):
        parser = optparse.OptionParser(
            usage="%prog [options] archive [archive2 ...]",
            description="Intelligent archive extractor",
            version=VERSION_BANNER
            )
        parser.add_option('-l', '-t', '--list', '--table', dest='show_list',
                          action='store_true', default=False,
                          help="list contents of archives on standard output")
        parser.add_option('-m', '--metadata', dest='metadata',
                          action='store_true', default=False,
                          help="extract metadata from a .deb/.gem")
        parser.add_option('-r', '--recursive', dest='recursive',
                          action='store_true', default=False,
                          help="extract archives contained in the ones listed")
        parser.add_option('--one', '--one-entry', dest='one_entry_default',
                          default=None,
                          help=("specify extraction policy for one-entry " +
                                "archives: inside/rename/here"))
        parser.add_option('-n', '--noninteractive', dest='batch',
                          action='store_true', default=False,
                          help="don't ask how to handle special cases")
        parser.add_option('-o', '--overwrite', dest='overwrite',
                          action='store_true', default=False,
                          help="overwrite any existing target output")
        parser.add_option('-f', '--flat', '--no-directory', dest='flat',
                          action='store_true', default=False,
                          help="extract everything to the current directory")
        parser.add_option('-v', '--verbose', dest='verbose',
                          action='count', default=0,
                          help="be verbose/print debugging information")
        parser.add_option('-q', '--quiet', dest='quiet',
                          action='count', default=3,
                          help="suppress warning/error messages")
        self.options, filenames = parser.parse_args(arguments)
        if not filenames:
            parser.error("you did not list any archives")
        # This makes WARNING the default level.
        self.options.log_level = (10 * (self.options.quiet -
                                        self.options.verbose))
        try:
            self.options.one_entry_policy = OneEntryPolicy(self.options)
        except ValueError:
            parser.error("invalid value for --one-entry option")
        self.options.recursion_policy = RecursionPolicy(self.options)
        self.archives = {os.path.realpath(os.curdir): filenames}

    def setup_logger(self):
        logging.getLogger().setLevel(self.options.log_level)
        handler = logging.StreamHandler()
        handler.setLevel(self.options.log_level)
        formatter = logging.Formatter("dtrx: %(levelname)s: %(message)s")
        handler.setFormatter(formatter)
        logger.addHandler(handler)
        logger.debug("logger is set up")

    def recurse(self, filename, extractor, action):
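        # If the recursion policy allows it, queue every archive found inside
        # this one in self.archives, keyed by the directory it was extracted
        # into, so the main loop processes it on a later pass.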
        self.options.recursion_policy.prep(filename, action.target, extractor)
        if self.options.recursion_policy.ok_to_recurse():
            for filename in extractor.included_archives:
                logger.debug("recursing with %s archive" %
                             (extractor.content_type,))
                tail_path, basename = os.path.split(filename)
                path_args = [self.current_directory, extractor.included_root,
                             tail_path]
                logger.debug("included root: %s" % (extractor.included_root,))
                logger.debug("tail path: %s" % (tail_path,))
                if os.path.isdir(action.target):
                    logger.debug("action target: %s" % (action.target,))
                    path_args.insert(1, action.target)
                directory = os.path.join(*path_args)
                self.archives.setdefault(directory, []).append(basename)

    def check_file(self, filename):
        try:
            result = os.stat(filename)
        except OSError, error:
            return error.strerror
        if stat.S_ISDIR(result.st_mode):
            return "cannot work with a directory"

    def show_stderr(self, logger_func, stderr):
        if stderr:
            logger_func("Error output from this process:\n" +
                        stderr.rstrip('\n'))

    def try_extractors(self, filename, builder):
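        # Try each candidate extractor in turn: return None on the first
        # success, or log every failure and return True if none of them work.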
        errors = []
        for extractor in builder:
            self.current_extractor = extractor  # For the abort() method.
            error = self.action.run(filename, extractor)
            if error:
                errors.append((extractor.file_type, extractor.encoding, error,
                               extractor.get_stderr()))
                if extractor.target is not None:
                    self.clean_destination(extractor.target)
            else:
                self.show_stderr(logger.warn, extractor.get_stderr())
                self.recurse(filename, extractor, self.action)
                return
        logger.error("could not handle %s" % (filename,))
        if not errors:
            logger.error("not a known archive type")
            return True
        for file_type, encoding, error, stderr in errors:
            message = ["treating as", file_type, "failed:", error]
            if encoding:
                message.insert(1, "%s-encoded" % (encoding,))
            logger.error(' '.join(message))
            self.show_stderr(logger.error, stderr)
        return True
        
    def run(self):
        if self.options.show_list:
            action = ListAction
        else:
            action = ExtractionAction
        self.action = action(self.options, self.archives.values()[0])
        while self.archives:
            self.current_directory, self.filenames = self.archives.popitem()
            os.chdir(self.current_directory)
            for filename in self.filenames:
                builder = ExtractorBuilder(filename, self.options)
                error = (self.check_file(filename) or
                         self.try_extractors(filename, builder.get_extractor()))
                if error:
                    if error != True:
                        logger.error("%s: %s" % (filename, error))
                    self.failures.append(filename)
                else:
                    self.successes.append(filename)
            self.options.one_entry_policy.permanent_policy = EXTRACT_WRAP
        if self.failures:
            return 1
        return 0


if __name__ == '__main__':
    app = ExtractorApplication(sys.argv[1:])
    sys.exit(app.run())
