Package TEES :: Package Classifiers :: Module SVMMultiClassClassifier
[hide private]

Source Code for Module TEES.Classifiers.SVMMultiClassClassifier

  1  import sys,os 
  2  sys.path.append(os.path.dirname(os.path.abspath(__file__))+"/..") 
  3  import subprocess 
  4  import copy 
  5  import tempfile 
  6  import types, copy 
  7  from ExternalClassifier import ExternalClassifier 
  8  import Utils.Settings as Settings 
  9  import Utils.Download as Download 
 10  import Tools.Tool 
 11  #import SVMMultiClassModelUtils 
 12  from Evaluators.AveragingMultiClassEvaluator import AveragingMultiClassEvaluator 
 13   
def install(destDir=None, downloadDir=None, redownload=False, compile=True, updateLocalSettings=False):
    """
    Download SVM-Multiclass, optionally build it with make, and finalize the install.

    @param destDir: base directory of the installation. The string
        "/tools/SVMMultiClass" is appended to it to form the actual install
        directory. Defaults to Settings.DATAPATH.
    @param downloadDir: download directory passed to Download.downloadAndExtract.
        Defaults to "tools/download/" under Settings.DATAPATH.
    @param redownload: forwarded to Download.downloadAndExtract.
    @param compile: if True, the package at Settings.URL["SVM_MULTICLASS_SOURCE"]
        is fetched and "make" is run in the install directory; otherwise the
        package at Settings.URL["SVM_MULTICLASS_LINUX"] is fetched and no
        build step is run.
    @param updateLocalSettings: forwarded to Tools.Tool.finalizeInstall.
    """
    print >> sys.stderr, "Installing SVM-Multiclass"
    if compile:
        url = Settings.URL["SVM_MULTICLASS_SOURCE"]
    else:
        url = Settings.URL["SVM_MULTICLASS_LINUX"]
    if downloadDir is None:
        downloadDir = os.path.join(Settings.DATAPATH, "tools/download/")
    if destDir is None:
        destDir = Settings.DATAPATH
    # The tool always lives in this subdirectory of the chosen base directory.
    destDir += "/tools/SVMMultiClass"

    Download.downloadAndExtract(url, destDir, downloadDir, redownload=redownload)
    if compile:
        print >> sys.stderr, "Compiling SVM-Multiclass"
        Tools.Tool.testPrograms("SVM-Multiclass", ["make"])
        # Run make with cwd= instead of a "cd <dir>; make" shell string: the
        # shell form breaks on paths containing spaces or shell metacharacters
        # and, if the cd fails, runs make in the caller's current directory.
        returnCode = subprocess.call(["make"], cwd=destDir)
        if returnCode != 0:
            print >> sys.stderr, "Warning, make exited with status", returnCode

    # The two test commands are passed through unchanged; how they are run is
    # up to Tools.Tool.finalizeInstall.
    Tools.Tool.finalizeInstall(["svm_multiclass_learn", "svm_multiclass_classify"],
        {"svm_multiclass_learn":"echo | ./svm_multiclass_learn -? > /dev/null",
         "svm_multiclass_classify":"echo | ./svm_multiclass_classify -? > /dev/null"},
        destDir, {"SVM_MULTICLASS_DIR":destDir}, updateLocalSettings)
36
class SVMMultiClassClassifier(ExternalClassifier):
    """
    A wrapper for the Joachims SVM Multiclass classifier.

    The class only configures attributes on top of ExternalClassifier: the
    default evaluator, the parameter format, and the settings key and command
    template for the train and classify programs.
    """

    def __init__(self, connection=None):
        """
        Initialize the base class and set the SVM Multiclass configuration.

        @param connection: passed unchanged to ExternalClassifier.__init__.
        """
        ExternalClassifier.__init__(self, connection=connection)
        # Both programs are located through the same settings key.
        programDirSetting = "SVM_MULTICLASS_DIR"
        self.defaultEvaluator = AveragingMultiClassEvaluator
        # NOTE(review): the %k/%v and %p/%e/%m/%c tokens are presumably
        # placeholders expanded by ExternalClassifier -- confirm there.
        self.parameterFormat = "-%k %v"
        self.trainDirSetting = programDirSetting
        self.trainCommand = "svm_multiclass_learn %p %e %m"
        self.classifyDirSetting = programDirSetting
        self.classifyCommand = "svm_multiclass_classify %e %m %c"
# Disabled code, kept for reference.
#    def filterIds(self, ids, model, verbose=False):
#        # Get feature ids
#        if type(ids) in types.StringTypes:
#            from Core.IdSet import IdSet
#            ids = IdSet(filename=ids)
#        # Get SVM model file feature ids
#        if verbose:
#            print >> sys.stderr, "Reading SVM model"
#        if model.endswith(".gz"):
#            f = gzip.open(model, "rt")
#        else:
#            f = open(model, "rt")
#        supportVectorLine = f.readlines()[-1]
#        f.close()
#        modelIdNumbers = set()
#        for split in supportVectorLine.split():
#            if ":" in split:
#                idPart = split.split(":")[0]
#                if idPart.isdigit():
#                    #print idPart
#                    modelIdNumbers.add(int(idPart))
#        modelIdNumbers = list(modelIdNumbers)
#        modelIdNumbers.sort()
#        # Make a new feature set with only features that are in the model file
#        if verbose:
#            print >> sys.stderr, "Feature set has", len(ids.Ids), "features, highest id is", max(ids._namesById.keys())
#            print >> sys.stderr, "Model has", len(modelIdNumbers), "features"
#            print >> sys.stderr, "Filtering ids"
#        newIds = IdSet()
#        newIds.nextFreeId = 999999999
#        for featureId in modelIdNumbers:
#            featureName = ids.getName(featureId)
#            assert featureName != None, featureId
#            newIds.defineId(featureName, featureId)
#        newIds.nextFreeId = max(newIds.Ids.values())+1
#        # Print statistics
#        if verbose:
#            print >> sys.stderr, "Filtered ids:", len(newIds.Ids), "(original", str(len(ids.Ids)) + ")"
#        return newIds

# Command-line entry point: installs SVM-Multiclass (--install) or runs one of
# the TRAIN, CLASSIFY or OPTIMIZE actions through SVMMultiClassClassifier.
if __name__=="__main__":
    # Import Psyco if available
    try:
        import psyco
        psyco.full()
        print >> sys.stderr, "Found Psyco, using"
    except ImportError:
        print >> sys.stderr, "Psyco not installed"

    from optparse import OptionParser
    import os
    from Utils.Parameters import *
    optparser = OptionParser(description="Joachims SVM Multiclass classifier wrapper")
    optparser.add_option("-e", "--examples", default=None, dest="examples", help="Example File", metavar="FILE")
    optparser.add_option("-a", "--action", default=None, dest="action", help="TRAIN, CLASSIFY or OPTIMIZE")
    optparser.add_option("--optimizeStep", default="BOTH", dest="optimizeStep", help="BOTH, SUBMIT or RESULTS")
    optparser.add_option("--classifyExamples", default=None, dest="classifyExamples", help="Example File", metavar="FILE")
    optparser.add_option("--classIds", default=None, dest="classIds", help="Class ids", metavar="FILE")
    optparser.add_option("-m", "--model", default=None, dest="model", help="path to model file")
    #optparser.add_option("-w", "--work", default=None, dest="work", help="Working directory for intermediate and debug files")
    optparser.add_option("-o", "--output", default=None, dest="output", help="Output directory or file")
    optparser.add_option("-r", "--remote", default=None, dest="remote", help="Remote connection")
    #optparser.add_option("-c", "--classifier", default="SVMMultiClassClassifier", dest="classifier", help="Classifier Class")
    optparser.add_option("-p", "--parameters", default=None, dest="parameters", help="Parameters for the classifier")
    #optparser.add_option("-d", "--ids", default=None, dest="ids", help="")
    #optparser.add_option("--filterIds", default=None, dest="filterIds", help="")
    optparser.add_option("--install", default=None, dest="install", help="Install directory (or DEFAULT)")
    optparser.add_option("--installFromSource", default=False, action="store_true", dest="installFromSource", help="")
    (options, args) = optparser.parse_args()

    if options.install is not None:
        # --install takes DEFAULT, a destination directory, or
        # "destination,download" as a comma-separated pair.
        downloadDir = None
        destDir = None
        if options.install != "DEFAULT":
            if "," in options.install:
                # Split on the first comma only, so that extra commas cannot
                # break the two-value unpacking.
                destDir, downloadDir = options.install.split(",", 1)
            else:
                destDir = options.install
        install(destDir, downloadDir, False, options.installFromSource)
        sys.exit()
#    elif options.filterIds != None:
#        assert options.model != None
#        classifier = SVMMultiClassClassifier()
#        filteredIds = classifier.filterIds(options.filterIds, options.model, verbose=True)
#        if options.output != None:
#            filteredIds.write(options.output)
    else:
        # Report a bad action as a usage error rather than with an assert,
        # which is stripped when Python runs with -O.
        if options.action not in ["TRAIN", "CLASSIFY", "OPTIMIZE"]:
            optparser.error("action (-a) must be TRAIN, CLASSIFY or OPTIMIZE, got " + str(options.action))
        # NOTE(review): "Connection" is not imported explicitly in this module;
        # presumably it comes from the star import above -- confirm, otherwise
        # this line raises NameError.
        classifier = SVMMultiClassClassifier(Connection.getConnection(options.remote))
        if options.action == "TRAIN":
            import time
            trained = classifier.train(options.examples, options.output, options.parameters, options.classifyExamples)
            # Poll every 10 seconds until the job reports a final status.
            status = trained.getStatus()
            while status not in ["FINISHED", "FAILED"]:
                print >> sys.stderr, "Training classifier, status =", status
                time.sleep(10)
                status = trained.getStatus()
            print >> sys.stderr, "Training finished, status =", status
            if trained.getStatus() == "FINISHED":
                trained.downloadPredictions()
                trained.downloadModel()
        elif options.action == "CLASSIFY":
            classified = classifier.classify(options.examples, options.output, options.model, True)
            if classified.getStatus() == "FINISHED":
                classified.downloadPredictions()
        else: # OPTIMIZE
            options.parameters = splitParameters(options.parameters)
            optimized = classifier.optimize(options.examples, options.output, options.parameters, options.classifyExamples, options.classIds, step=options.optimizeStep)