[svn] Move policy-handling code into a dedicated set of classes. This makes trunk

Sun, 29 Apr 2007 15:12:02 -0400

author
brett
date
Sun, 29 Apr 2007 15:12:02 -0400
branch
trunk
changeset 25
ef62f2f55eb8
parent 24
60056f3e3e60
child 26
d660410455d9

[svn] Move policy-handling code into a dedicated set of classes. This makes
question construction at least moderately cleaner, and more importantly, it
gets it out of the main application class, where it was bugging me.

TODO file | annotate | diff | comparison | revisions
scripts/dtrx file | annotate | diff | comparison | revisions
--- a/TODO	Sun Apr 29 13:29:50 2007 -0400
+++ b/TODO	Sun Apr 29 15:12:02 2007 -0400
@@ -1,5 +1,4 @@
 Things which I have a use case/anti-use case for:
-* Some better way of dealing with options/policies.
 * Use file to detect the archive type.
 * Support lzma compression (http://tukaani.org/lzma/download)
 * Support pisi packages (http://paketler.pardus.org.tr/pardus-2007/)
--- a/scripts/dtrx	Sun Apr 29 13:29:50 2007 -0400
+++ b/scripts/dtrx	Sun Apr 29 15:12:02 2007 -0400
@@ -417,13 +417,13 @@
     def can_handle(contents, options):
         return ((contents == MATCHING_DIRECTORY) or
                 ((contents == ONE_ENTRY) and
-                 (options.onedir_policy in (EXTRACT_RENAME, EXTRACT_HERE))))
+                 options.one_entry_policy.ok_for_match()))
     can_handle = staticmethod(can_handle)
 
     def extract(self):
         if self.contents == MATCHING_DIRECTORY:
             basename = destination = self.extractor.basename()
-        elif self.options.onedir_policy == EXTRACT_HERE:
+        elif self.options.one_entry_policy == EXTRACT_HERE:
             basename = destination = self.content_name.rstrip('/')
         else:
             basename = self.content_name.rstrip('/')
@@ -459,6 +459,93 @@
         self.target = checker.check()
 
         
+class BasePolicy(object):
+    def __init__(self, options):
+        self.current_policy = None
+
+    def ask_question(self, question):
+        question = textwrap.wrap(question) + self.choices
+        while True:
+            print "\n".join(question)
+            try:
+                answer = raw_input(self.prompt)
+            except EOFError:
+                return self.answers['']
+            try:
+                return self.answers[answer.lower()]
+            except KeyError:
+                print
+
+    def prep(self, *args):
+        self.current_policy = (self.permanent_policy or
+                               self.ask_question(*args))
+
+    def __cmp__(self, other):
+        return cmp(self.current_policy, other)
+    
+
+class OneEntryPolicy(BasePolicy):
+    answers = {'h': EXTRACT_HERE, 'i': EXTRACT_WRAP, 'r': EXTRACT_RENAME,
+               '': EXTRACT_WRAP}
+    choices = ["You can:",
+               " * extract it Inside another directory",
+               " * extract it and Rename the directory",
+               " * extract it Here"]
+    prompt = "What do you want to do?  (I/r/h) "
+
+    def __init__(self, options):
+        BasePolicy.__init__(self, options)
+        if options.batch:
+            self.permanent_policy = self.answers['']
+        else:
+            self.permanent_policy = None
+
+    def ask_question(self, archive_filename, entry_name):
+        return BasePolicy.ask_question(self, ("%s contains one entry: %s." %
+                                              (archive_filename, entry_name)))
+
+    def ok_for_match(self):
+        return self.current_policy in (EXTRACT_RENAME, EXTRACT_HERE)
+
+
+class RecursionPolicy(BasePolicy):
+    answers = {'o': RECURSE_ONCE, 'a': RECURSE_ALWAYS, 'n': RECURSE_NOT_NOW,
+               'v': RECURSE_NEVER, '': RECURSE_NOT_NOW}
+    choices = ["You can:",
+               " * Always extract included archives",
+               " * extract included archives this Once",
+               " * choose Not to extract included archives",
+               " * neVer extract included archives"]
+    prompt = "What do you want to do?  (a/o/N/v) "
+
+    def __init__(self, options):
+        BasePolicy.__init__(self, options)
+        if options.recursive:
+            self.permanent_policy = RECURSE_ALWAYS
+        elif options.batch:
+            self.permanent_policy = self.answers['']
+        else:
+            self.permanent_policy = None
+
+    def prep(self, current_filename, included_archives):
+        archive_count = len(included_archives)
+        if (self.permanent_policy is not None) or (archive_count == 0):
+            self.current_policy = self.permanent_policy or RECURSE_NOT_NOW
+            return
+        elif archive_count > 1:
+            question = ("%s contains %s other archive files." %
+                        (current_filename, archive_count))
+        else:
+            question = ("%s contains another archive: %s." %
+                        (current_filename, included_archives[0]))
+        self.current_policy = self.ask_question(question)
+        if self.current_policy in (RECURSE_ALWAYS, RECURSE_NEVER):
+            self.permanent_policy = self.current_policy
+
+    def ok_to_recurse(self):
+        return self.current_policy in (RECURSE_ALWAYS, RECURSE_ONCE)
+            
+
 extractor_map = {'application/x-tar': TarExtractor,
                  'application/zip': ZipExtractor,
                  'application/x-msdos-program': ZipExtractor,
@@ -471,12 +558,6 @@
             BombHandler]
 
 class ExtractorApplication(object):
-    policy_answers = {'h': EXTRACT_HERE, 'i': EXTRACT_WRAP,
-                      'r': EXTRACT_RENAME, '': EXTRACT_WRAP}
-    recursive_answers = {'o': RECURSE_ONCE, 'a': RECURSE_ALWAYS,
-                         'n': RECURSE_NOT_NOW, 'v': RECURSE_NEVER,
-                         '': RECURSE_NOT_NOW}
-
     def __init__(self, arguments):
         self.parse_options(arguments)
         self.setup_logger()
@@ -513,6 +594,8 @@
         self.options, filenames = parser.parse_args(arguments)
         if not filenames:
             parser.error("you did not list any archives")
+        self.options.one_entry_policy = OneEntryPolicy(self.options)
+        self.options.recursion_policy = RecursionPolicy(self.options)
         self.archives = {os.path.realpath(os.curdir): filenames}
 
     def setup_logger(self):
@@ -524,21 +607,6 @@
         handler.setFormatter(formatter)
         self.logger.addHandler(handler)
 
-    def ask_question(self, question, answers):
-        if self.options.batch:
-            return answers['']
-        last_line = question.pop()
-        while True:
-            print "\n".join(question)
-            try:
-                answer = raw_input(last_line)
-            except EOFError:
-                return answers['']
-            try:
-                return answers[answer.lower()]
-            except KeyError:
-                print
-
     def get_extractor(self):
         mimetype, encoding = mimetypes.guess_type(self.current_filename)
         try:
@@ -557,16 +625,9 @@
     def get_handler(self):
         try:
             content, content_name = self.current_extractor.check_contents()
-            if (content == ONE_ENTRY) and (self.options.onedir_policy is None):
-                question = textwrap.wrap("%s contains one entry: %s." %
-                                         (self.current_filename, content_name))
-                question.extend(["You can:",
-                                 " * extract it Inside another directory",
-                                 " * extract it and Rename the directory",
-                                 " * extract it Here",
-                                 "What do you want to do?  (I/r/h) "])
-                self.options.onedir_policy = \
-                    self.ask_question(question, self.policy_answers)
+            if content == ONE_ENTRY:
+                self.options.one_entry_policy.prep(self.current_filename,
+                                                   content_name)
             for handler in handlers:
                 if handler.can_handle(content, self.options):
                     self.current_handler = handler(self.current_extractor,
@@ -576,40 +637,15 @@
         except ExtractorError, error:
             return str(error)
 
-    def get_recursion_policy(self):
-        if len(self.current_extractor.included_archives) > 1:
-            question = ("%s contains %s other archive files." %
-                        (self.current_filename,
-                         len(self.current_extractor.included_archives)))
-        else:
-            question = ("%s contains another archive: %s." %
-                        (self.current_filename,
-                         self.current_extractor.included_archives[0]))
-        question = textwrap.wrap(question)
-        question.extend(["You can:",
-                         " * Always extract included archives",
-                         " * extract included archives this Once",
-                         " * choose Not to extract included archives",
-                         " * neVer extract included archives",
-                         "What do you want to do?  (a/o/N/v) "])
-        self.options.recurse_policy = self.ask_question(question,
-                                                        self.recursive_answers)
-
     def recurse(self):
-        if not self.current_extractor.included_archives:
-            return
-        if self.options.recurse_policy is None:
-            if self.options.recursive:
-                self.options.recurse_policy = RECURSE_ALWAYS
-            else:
-                self.get_recursion_policy()
-        if self.options.recurse_policy in (RECURSE_NOT_NOW, RECURSE_NEVER):
-            return
-        for filename in self.current_extractor.included_archives:
-            tail_path, basename = os.path.split(filename)
-            directory = os.path.join(self.current_directory,
-                                     self.current_handler.target, tail_path)
-            self.archives.setdefault(directory, []).append(basename)
+        archives = self.current_extractor.included_archives
+        self.options.recursion_policy.prep(self.current_filename, archives)
+        if self.options.recursion_policy.ok_to_recurse():
+            for filename in archives:
+                tail_path, basename = os.path.split(filename)
+                directory = os.path.join(self.current_directory,
+                                         self.current_handler.target, tail_path)
+                self.archives.setdefault(directory, []).append(basename)
 
     def report(self, function, *args):
         try:
@@ -628,17 +664,14 @@
             self.failures.append(self.current_filename)
 
     def extract(self):
-        self.options.recurse_policy = None
         first_run = True
         while self.archives:
+            if not first_run:
+                self.options.one_entry_policy.permanent_policy = EXTRACT_WRAP
+            else:
+                first_run = False
             self.current_directory, filenames = self.archives.popitem()
-            self.options.onedir_policy = EXTRACT_WRAP
             for filename in filenames:
-                if first_run:
-                    self.options.onedir_policy = None
-                if self.options.recurse_policy not in (RECURSE_ALWAYS,
-                                                       RECURSE_NEVER):
-                    self.options.recurse_policy = None
                 os.chdir(self.current_directory)
                 self.current_filename = filename
                 success = (self.report(self.get_extractor) and
@@ -649,7 +682,6 @@
                                                        name)) and success)
                     self.recurse()
                 self.record_status(success)
-            first_run = False
 
     def show_contents(self):
         for filename in self.current_extractor.get_filenames():

mercurial