scripts/dtrx

branch:      trunk
changeset:   28:4d88f2231d33
parent:      27:5711a4714e47
child:       29:5fad99c17221
--- a/scripts/dtrx	Sun Apr 29 15:30:01 2007 -0400
+++ b/scripts/dtrx	Fri Oct 19 22:46:20 2007 -0400
@@ -5,7 +5,7 @@
 #
 # This program is free software; you can redistribute it and/or modify it
 # under the terms of the GNU General Public License as published by the
-# Free Software Foundation; either version 2 of the License, or (at your
+# Free Software Foundation; either version 3 of the License, or (at your
 # option) any later version.
 #
 # This program is distributed in the hope that it will be useful, but
@@ -27,16 +27,17 @@
 import sys
 import tempfile
 import textwrap
+import traceback
 
 from cStringIO import StringIO
 
-VERSION = "4.0"
+VERSION = "5.0"
 VERSION_BANNER = """dtrx version %s
 Copyright (c) 2006, 2007 Brett Smith <brettcsmith@brettcsmith.org>
 
 This program is free software; you can redistribute it and/or modify it
 under the terms of the GNU General Public License as published by the
-Free Software Foundation; either version 2 of the License, or (at your
+Free Software Foundation; either version 3 of the License, or (at your
 option) any later version.
 
 This program is distributed in the hope that it will be useful, but
@@ -107,50 +108,21 @@
     pass
 
 
-class ProcessStreamer(object):
-    def __init__(self, command, stdin, description="checking contents",
-                 stderr=None):
-        self.process = subprocess.Popen(command, bufsize=1, stdin=stdin,
-                                        stdout=subprocess.PIPE, stderr=stderr)
-        self.command = ' '.join(command)
-        self.description = description
-
-    def __iter__(self):
-        return self
-
-    def next(self):
-        line = self.process.stdout.readline()
-        if line:
-            return line.rstrip('\n')
-        else:
-            raise StopIteration
-
-    def stop(self):
-        while self.process.stdout.readline():
-            pass
-        self.process.stdout.close()
-        status = self.process.wait()
-        if status != 0:
-            raise ExtractorError("%s error: '%s' returned status code %s" %
-                                 (self.description, self.command, status))
-        try:
-            self.process.stderr.close()
-        except AttributeError:
-            pass
-    
-
 class BaseExtractor(object):
     decoders = {'bzip2': 'bzcat', 'gzip': 'zcat', 'compress': 'zcat'}
 
     name_checker = DirectoryChecker
 
-    def __init__(self, filename, mimetype, encoding):
+    def __init__(self, filename, encoding):
         if encoding and (not self.decoders.has_key(encoding)):
             raise ValueError("unrecognized encoding %s" % (encoding,))
         self.filename = os.path.realpath(filename)
-        self.mimetype = mimetype
         self.encoding = encoding
         self.included_archives = []
+        self.target = None
+        self.content_type = None
+        self.content_name = None
+        self.pipes = []
         try:
             self.archive = open(filename, 'r')
         except (IOError, OSError), error:
@@ -160,69 +132,65 @@
             self.pipe([self.decoders[encoding]], "decoding")
         self.prepare()
 
-    def run(self, command, description="extraction", stdout=None, stderr=None,
-            stdin=None):
-        error = run_command(command, description, stdout, stderr, stdin)
-        if error:
-            raise ExtractorError(error)
+    def pipe(self, command, description="extraction"):
+        self.pipes.append((command, description))
 
-    def pipe(self, command, description, stderr=None):
-        output = tempfile.TemporaryFile()
-        self.run(command, description, output, stderr, self.archive)
+    def run_pipes(self, final_stdout=None):
+        if final_stdout is None:
+            # FIXME: Buffering this might be dumb.
+            final_stdout = tempfile.TemporaryFile()
+        if not self.pipes:
+            return
+        num_pipes = len(self.pipes)
+        last_pipe = num_pipes - 1
+        processes = []
+        for index, command in enumerate([pipe[0] for pipe in self.pipes]):
+            if index == 0:
+                stdin = self.archive
+            else:
+                stdin = processes[-1].stdout
+            if index == last_pipe:
+                stdout = final_stdout
+            else:
+                stdout = subprocess.PIPE
+            processes.append(subprocess.Popen(command, stdin=stdin,
+                                              stdout=stdout,
+                                              stderr=subprocess.PIPE))
+        exit_codes = [pipe.wait() for pipe in processes]
         self.archive.close()
-        self.archive = output
-        self.archive.flush()
+        for index in range(last_pipe):
+            processes[index].stdout.close()
+            processes[index].stderr.close()
+        for index, status in enumerate(exit_codes):
+            if status != 0:
+                raise ExtractorError("%s error: '%s' returned status code %s" %
+                                     (self.pipes[index][1],
+                                      ' '.join(self.pipes[index][0]), status))
+        self.archive = final_stdout
     
     def prepare(self):
         pass
 
-    def check_included_archive(self, filename):
-        if extractor_map.has_key(mimetypes.guess_type(filename)[0]):
-            self.included_archives.append(filename)
-
-    def check_first_filename(self, filenames):
-        try:
-            first_filename = filenames.next()
-        except StopIteration:
-            filenames.stop()
-            return (None, None)
-        self.check_included_archive(first_filename)
-        parts = first_filename.split('/')
-        first_part = [parts[0]]
-        if parts[0] == '.':
-            first_part.append(parts[1])
-        return (first_filename, '/'.join(first_part + ['']))
+    def check_included_archives(self, filenames):
+        for filename in filenames:
+            if extractor_map.has_key(mimetypes.guess_type(filename)[0]):
+                self.included_archives.append(filename)
 
-    def check_second_filename(self, filenames, first_part, first_filename):
-        try:
-            filename = filenames.next()
-        except StopIteration:
-            return ONE_ENTRY, first_filename
-        self.check_included_archive(filename)
-        if not filename.startswith(first_part):
-            return BOMB, None
-        return None, first_part
-        
     def check_contents(self):
-        filenames = self.get_filenames()
-        first_filename, first_part = self.check_first_filename(filenames)
-        if first_filename is None:
-            return (EMPTY, None)
-        archive_type, type_info = self.check_second_filename(filenames,
-                                                             first_part,
-                                                             first_filename)
-        for filename in filenames:
-            self.check_included_archive(filename)
-            if (archive_type != BOMB) and (not filename.startswith(first_part)):
-                archive_type = BOMB
-                type_info = None
-        filenames.stop()
-        if archive_type is None:
-            if self.basename() == first_part[:-1]:
-                archive_type = MATCHING_DIRECTORY
+        filenames = os.listdir('.')
+        if not filenames:
+            self.content_type = EMPTY
+        elif len(filenames) == 1:
+            if self.basename() == filenames[0]:
+                self.content_type = MATCHING_DIRECTORY
             else:
-                archive_type = ONE_ENTRY
-        return archive_type, type_info
+                self.content_type = ONE_ENTRY
+            self.content_name = filenames[0]
+            if os.path.isdir(filenames[0]):
+                self.content_name += '/'
+        else:
+            self.content_type = BOMB
+        self.check_included_archives(filenames)
 
     def basename(self):
         pieces = os.path.basename(self.filename).split('.')
@@ -236,49 +204,59 @@
             pieces.pop()
         return '.'.join(pieces)
 
-    def extract(self, path):
+    def extract(self):
+        self.target = tempfile.mkdtemp(prefix='.dtrx-', dir='.')
         old_path = os.path.realpath(os.curdir)
-        os.chdir(path)
+        os.chdir(self.target)
         self.archive.seek(0, 0)
         self.extract_archive()
+        self.check_contents()
         os.chdir(old_path)
+
+    def get_filenames(self):
+        self.run_pipes()
+        self.archive.seek(0, 0)
+        while True:
+            line = self.archive.readline()
+            if not line:
+                self.archive.close()
+                return
+            yield line.rstrip('\n')
     
 
 class TarExtractor(BaseExtractor):
     def get_filenames(self):
-        self.archive.seek(0, 0)
-        return ProcessStreamer(['tar', '-t'], self.archive)
+        self.pipe(['tar', '-t'], "listing")
+        return BaseExtractor.get_filenames(self)
 
     def extract_archive(self): 
-        self.run(['tar', '-x'], stdin=self.archive)
+        self.pipe(['tar', '-x'])
+        self.run_pipes()
         
         
 class ZipExtractor(BaseExtractor):
-    def __init__(self, filename, mimetype, encoding):
+    def __init__(self, filename, encoding):
+        BaseExtractor.__init__(self, '/dev/null', None)
         self.filename = os.path.realpath(filename)
-        self.mimetype = mimetype
-        self.encoding = encoding
-        self.included_archives = []
-        self.archive = StringIO()
 
     def get_filenames(self):
-        self.archive.seek(0, 0)
-        return ProcessStreamer(['zipinfo', '-1', self.filename], None)
+        self.pipe(['zipinfo', '-1', self.filename], "listing")
+        return BaseExtractor.get_filenames(self)
 
     def extract_archive(self):
-        self.run(['unzip', '-q', self.filename])
+        self.pipe(['unzip', '-q', self.filename])
+        self.run_pipes()
 
 
 class CpioExtractor(BaseExtractor):
     def get_filenames(self):
-        self.archive.seek(0, 0)
-        return ProcessStreamer(['cpio', '-t'], self.archive,
-                               stderr=subprocess.PIPE)
+        self.pipe(['cpio', '-t'], "listing")
+        return BaseExtractor.get_filenames(self)
 
     def extract_archive(self):
-        self.run(['cpio', '-i', '--make-directories',
-                  '--no-absolute-filenames'],
-                 stderr=subprocess.PIPE, stdin=self.archive)
+        self.pipe(['cpio', '-i', '--make-directories',
+                   '--no-absolute-filenames'])
+        self.run_pipes()
 
 
 class RPMExtractor(CpioExtractor):
@@ -299,15 +277,14 @@
         return '.'.join(pieces)
 
     def check_contents(self):
-        CpioExtractor.check_contents(self)
-        return (BOMB, None)
+        self.check_included_archives(os.listdir('.'))
+        self.content_type = BOMB
 
 
 class DebExtractor(TarExtractor):
     def prepare(self):
         self.pipe(['ar', 'p', self.filename, 'data.tar.gz'],
                   "data.tar.gz extraction")
-        self.archive.seek(0, 0)
         self.pipe(['zcat'], "data.tar.gz decompression")
 
     def basename(self):
@@ -320,9 +297,9 @@
         return '_'.join(pieces)
 
     def check_contents(self):
-        TarExtractor.check_contents(self)
-        return (BOMB, None)
-        
+        self.check_included_archives(os.listdir('.'))
+        self.content_type = BOMB
+
 
 class CompressionExtractor(BaseExtractor):
     name_checker = FilenameChecker
@@ -337,42 +314,32 @@
     def get_filenames(self):
         yield self.basename()
 
-    def check_contents(self):
-        return (ONE_ENTRY_KNOWN, self.basename())
-
-    def extract(self, path):
-        output = open(path, 'w')
-        self.archive.seek(0, 0)
-        self.run(['cat'], "output write", stdin=self.archive, stdout=output)
-        output.close()
+    def extract(self):
+        self.content_type = ONE_ENTRY_KNOWN
+        self.content_name = self.basename()
+        output_fd, self.target = tempfile.mkstemp(prefix='.dtrx-', dir='.')
+        self.run_pipes(output_fd)
+        os.close(output_fd)
         
 
 class BaseHandler(object):
-    def __init__(self, extractor, contents, content_name, options):
+    def __init__(self, extractor, options):
         self.logger = logging.getLogger('dtrx-log')
         self.extractor = extractor
-        self.contents = contents
-        self.content_name = content_name
         self.options = options
         self.target = None
 
-    def extract(self):
-        try:
-            self.extractor.extract(self.target)
-        except (ExtractorError, IOError, OSError), error:
-            return str(error)
-        
-    def cleanup(self):
-        if self.target is None:
-            return
+    def handle(self):
         command = 'find'
-        status = subprocess.call(['find', self.target, '-type', 'd',
+        status = subprocess.call(['find', self.extractor.target, '-type', 'd',
                                   '-exec', 'chmod', 'u+rwx', '{}', ';'])
         if status == 0:
             command = 'chmod'
-            status = subprocess.call(['chmod', '-R', 'u+rw', self.target])
+            status = subprocess.call(['chmod', '-R', 'u+rwX',
+                                      self.extractor.target])
         if status != 0:
             return "%s returned with exit status %s" % (command, status)
+        return self.organize()
 
 
 # The "where to extract" table, with options and archive types.
@@ -389,17 +356,22 @@
                 (options.overwrite and (contents == MATCHING_DIRECTORY)))
     can_handle = staticmethod(can_handle)
 
-    def __init__(self, extractor, contents, content_name, options):
-        BaseHandler.__init__(self, extractor, contents, content_name, options)
+    def organize(self):
         self.target = '.'
-
-    def cleanup(self):
-        for filename in self.extractor.get_filenames():
-            stat_info = os.stat(filename)
-            perms = stat.S_IRUSR | stat.S_IWUSR
-            if stat.S_ISDIR(stat_info.st_mode):
-                perms |= stat.S_IXUSR
-            os.chmod(filename, stat_info.st_mode | perms)
+        for curdir, dirs, filenames in os.walk(self.extractor.target,
+                                               topdown=False):
+            path_parts = curdir.split(os.sep)
+            if path_parts[0] == '.':
+                path_parts.pop(1)
+            else:
+                path_parts.pop(0)
+            newdir = os.path.join(*path_parts)
+            if not os.path.isdir(newdir):
+                os.makedirs(newdir)
+            for filename in filenames:
+                os.rename(os.path.join(curdir, filename),
+                          os.path.join(newdir, filename))
+            os.rmdir(curdir)
 
 
 class OverwriteHandler(BaseHandler):
@@ -408,9 +380,13 @@
                 (options.overwrite and (contents != MATCHING_DIRECTORY)))
     can_handle = staticmethod(can_handle)
 
-    def __init__(self, extractor, contents, content_name, options):
-        BaseHandler.__init__(self, extractor, contents, content_name, options)
+    def organize(self):
         self.target = self.extractor.basename()
+        result = run_command(['rm', '-rf', self.target],
+                             "removing %s to overwrite" % (self.target,))
+        if result is None:
+            os.rename(self.extractor.target, self.target)
+        return result
         
 
 class MatchHandler(BaseHandler):
@@ -420,22 +396,19 @@
                  options.one_entry_policy.ok_for_match()))
     can_handle = staticmethod(can_handle)
 
-    def extract(self):
-        if self.contents == MATCHING_DIRECTORY:
-            basename = destination = self.extractor.basename()
-        elif self.options.one_entry_policy == EXTRACT_HERE:
-            basename = destination = self.content_name.rstrip('/')
+    def organize(self):
+        if self.options.one_entry_policy == EXTRACT_HERE:
+            destination = self.extractor.content_name.rstrip('/')
         else:
-            basename = self.content_name.rstrip('/')
             destination = self.extractor.basename()
-        self.target = tempdir = tempfile.mkdtemp(dir='.')
-        result = BaseHandler.extract(self)
-        if result is None:
-            checker = self.extractor.name_checker(destination)
-            self.target = checker.check()
-            os.rename(os.path.join(tempdir, basename), self.target)
-            os.rmdir(tempdir)
-        return result
+        self.target = self.extractor.name_checker(destination).check()
+        if os.path.isdir(self.extractor.target):
+            os.rename(os.path.join(self.extractor.target,
+                                   os.listdir(self.extractor.target)[0]),
+                      self.target)
+            os.rmdir(self.extractor.target)
+        else:
+            os.rename(self.extractor.target, self.target)
 
 
 class EmptyHandler(object):
@@ -443,9 +416,8 @@
         return contents == EMPTY
     can_handle = staticmethod(can_handle)
 
-    def __init__(self, extractor, contents, content_name, options): pass
-    def extract(self): pass
-    def cleanup(self): pass
+    def __init__(self, extractor, options): pass
+    def handle(self): pass
 
 
 class BombHandler(BaseHandler):
@@ -453,10 +425,10 @@
         return True
     can_handle = staticmethod(can_handle)
 
-    def __init__(self, extractor, contents, content_name, options):
-        BaseHandler.__init__(self, extractor, contents, content_name, options)
-        checker = self.extractor.name_checker(self.extractor.basename())
-        self.target = checker.check()
+    def organize(self):
+        basename = self.extractor.basename()
+        self.target = self.extractor.name_checker(basename).check()
+        os.rename(self.extractor.target, self.target)
 
         
 class BasePolicy(object):
@@ -608,25 +580,22 @@
             else:
                 return "not a known archive type"
         try:
-            self.current_extractor = extractor(self.current_filename, mimetype,
-                                               encoding)
+            self.current_extractor = extractor(self.current_filename, encoding)
         except ExtractorError, error:
             return str(error)
 
     def get_handler(self):
-        try:
-            content, content_name = self.current_extractor.check_contents()
-            if content == ONE_ENTRY:
-                self.options.one_entry_policy.prep(self.current_filename,
-                                                   content_name)
-            for handler in handlers:
-                if handler.can_handle(content, self.options):
-                    self.current_handler = handler(self.current_extractor,
-                                                   content, content_name,
-                                                   self.options)
-                    break
-        except ExtractorError, error:
-            return str(error)
+        for var_name in ('type', 'name'):
+            exec('content_%s = self.current_extractor.content_%s' %
+                 (var_name, var_name))
+        if content_type == ONE_ENTRY:
+            self.options.one_entry_policy.prep(self.current_filename,
+                                               content_name)
+        for handler in handlers:
+            if handler.can_handle(content_type, self.options):
+                self.current_handler = handler(self.current_extractor,
+                                               self.options)
+                break
 
     def recurse(self):
         archives = self.current_extractor.included_archives
@@ -643,6 +612,7 @@
             error = function(*args)
         except (ExtractorError, IOError, OSError), exception:
             error = str(exception)
+            self.logger.debug(traceback.format_exception(*sys.exc_info()))
         if error:
             self.logger.error("%s: %s", self.current_filename, error)
             return False
@@ -661,11 +631,10 @@
                 os.chdir(self.current_directory)
                 self.current_filename = filename
                 success = (self.report(self.get_extractor) and
-                           self.report(self.get_handler))
+                           self.report(self.current_extractor.extract) and
+                           self.report(self.get_handler) and
+                           self.report(self.current_handler.handle))
                 if success:
-                    for name in 'extract', 'cleanup':
-                        success = (self.report(getattr(self.current_handler,
-                                                       name)) and success)
                     self.recurse()
                 self.record_status(success)
             self.options.one_entry_policy.permanent_policy = EXTRACT_WRAP

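Note: the heart of this changeset is the removal of ProcessStreamer and the per-step temporary files in favor of a deferred pipeline: extractors register (command, description) pairs with pipe(), and run_pipes() later chains them through subprocess, checking every exit status once the whole chain has finished. The standalone sketch below is not part of the patch; run_pipeline and its arguments are illustrative names, but the chaining logic mirrors what run_pipes() does above.

    import subprocess
    import tempfile

    def run_pipeline(commands, source, final_stdout=None):
        # Feed `source` (an open file object) through each command in turn.
        # The last command writes to `final_stdout`, created here if the
        # caller did not supply one -- the same default the patch uses.
        if final_stdout is None:
            final_stdout = tempfile.TemporaryFile()
        processes = []
        last = len(commands) - 1
        for index, command in enumerate(commands):
            stdin = source if index == 0 else processes[-1].stdout
            stdout = final_stdout if index == last else subprocess.PIPE
            processes.append(subprocess.Popen(command, stdin=stdin,
                                              stdout=stdout))
        # Wait for every stage, then close the intermediate pipe ends this
        # process still holds, as the patched run_pipes() does.
        statuses = [process.wait() for process in processes]
        for process in processes[:-1]:
            process.stdout.close()
        for command, status in zip(commands, statuses):
            if status != 0:
                raise RuntimeError("'%s' returned status code %s" %
                                   (' '.join(command), status))
        return final_stdout

    # Example, using the same commands TarExtractor registers:
    # listing = run_pipeline([['zcat'], ['tar', '-t']],
    #                        open('archive.tar.gz', 'rb'))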