Package TEES :: Package Core :: Module RecallAdjust
[hide private]

Source Code for Module TEES.Core.RecallAdjust

  1  """ 
  2  Trade precision for recall 
  3  """ 
  4   
  5  try: 
  6      import xml.etree.cElementTree as ET 
  7  except: 
  8      import cElementTree as ET 
  9  import sys, os 
 10  thisPath = os.path.dirname(os.path.abspath(__file__)) 
 11  sys.path.append(os.path.abspath(os.path.join(thisPath,".."))) 
 12  import Utils.ElementTreeUtils as ETUtils 
 13  import math 
 14  from optparse import OptionParser 
 15  import sys 
def scaleVal(val, boost=1.0):
    """Scale a confidence value by ``boost``, treating both signs alike.

    A non-negative ``val`` is multiplied by ``boost``. A negative ``val`` is
    shifted by the same amount its magnitude would change under scaling,
    i.e. by ``abs(val) * boost - abs(val)``. Consequently ``boost > 1`` moves
    any value upward and ``boost < 1`` moves any value downward, whatever
    its sign.

    :param val: number to scale
    :param boost: scaling factor; 1.0 leaves ``val`` unchanged
    :return: the scaled value
    """
    if not val >= 0:
        # Negative input: apply the magnitude change as an additive shift
        # so that the direction of the boost is the same as for positives.
        magnitude = abs(val)
        return val + (magnitude * boost - magnitude)
    return val * boost
25
def scaleRange(val, boost, classRange):
    """Push small positive confidences below zero when ``boost`` < 1.0.

    Only values strictly between 0 and ``(1.0 - boost) * classRange[1]`` are
    affected; they are mapped to ``-val - 1`` (which is always below -1).
    Every other value is returned unchanged, and so is every value when
    ``boost >= 1.0`` -- no adjustment is implemented for ``boost > 1.0``.

    :param val: confidence value to adjust
    :param boost: adjustment factor
    :param classRange: indexable whose element [1] is the upper bound of the
        class's confidences (getClassRanges builds these as [min, max])
    :return: the adjusted confidence
    """
    # classRange is only indexed when the first two conditions hold.
    shrinking_positive = boost < 1.0 and val > 0
    if shrinking_positive and val < (1.0 - boost) * classRange[1]:
        return -val - 1
    return val
34
def adjustEntity(entityNode, targetLabel, multiplier, classRange=None):
    """Adjust the confidence of ``targetLabel`` in an entity's predictions.

    The entity's ``predictions`` attribute is a comma-separated list of
    ``label:confidence`` pairs. The confidence of ``targetLabel`` is rescaled,
    the attribute is rewritten (other entries are kept verbatim and in their
    original order), and the entity's ``type`` attribute is set to the label
    with the highest resulting confidence; on ties the first such label wins.
    Entities with a missing or empty ``predictions`` attribute are left
    untouched.

    :param entityNode: element supporting ``get``/``set`` (e.g. an
        ElementTree element)
    :param targetLabel: label whose confidence is adjusted
    :param multiplier: adjustment factor passed to scaleVal / scaleRange
    :param classRange: None selects multiclass mode (scaleVal). Otherwise it
        maps a label to its [min, max] confidence range and scaleRange is
        used (binary mode); for a dict, a missing ``targetLabel`` entry
        raises KeyError.
    """
    predictions = entityNode.get("predictions")
    if not predictions:  # nothing to do
        return
    maxConfidence = None
    maxLabel = None
    modified = []  # rebuilt "label:confidence" entries
    for labelConfidence in predictions.split(","):
        label, confidence = labelConfidence.split(":")
        confidence = float(confidence)
        if label != targetLabel:
            # Keep untouched entries exactly as they appeared in the input.
            modified.append(labelConfidence)
        else:
            if classRange is None:  # multiclass
                confidence = scaleVal(confidence, multiplier)
            else:  # binary
                confidence = scaleRange(confidence, multiplier, classRange[label])
            modified.append(label + ":" + str(confidence))
        # Strict '<' keeps the earliest label when confidences tie.
        if maxConfidence is None or maxConfidence < confidence:
            maxConfidence = confidence
            maxLabel = label

    entityNode.set("predictions", ",".join(modified))
    entityNode.set("type", maxLabel)
61
def getClassRanges(entities):
    """Collect the [min, max] prediction confidence of every label.

    Entities with ``isName="True"`` and entities without a ``predictions``
    attribute are ignored. ``predictions`` is parsed as a comma-separated
    list of ``label:confidence`` pairs.

    :param entities: iterable of elements supporting ``get`` (for example
        the entity elements of an interaction XML tree)
    :return: dict mapping each label to ``[min_confidence, max_confidence]``
    """
    classRanges = {}
    for entity in entities:
        if entity.get("isName") == "True":
            continue
        predictions = entity.get("predictions")
        if not predictions:
            continue
        for labelConfidence in predictions.split(","):
            label, confidence = labelConfidence.split(":")
            confidence = float(confidence)
            bounds = classRanges.get(label)
            if bounds is None:
                # Seed with the first observed value rather than an integer
                # sentinel (sys.maxint is Python 2 only and would clamp any
                # confidence larger than it).
                classRanges[label] = [confidence, confidence]
            else:
                bounds[0] = min(bounds[0], confidence)
                bounds[1] = max(bounds[1], confidence)
    return classRanges
76
def getClassRangesFromPredictions(predictions):
    """Compute the [min, max] confidence range of classes 1 and 2.

    :param predictions: iterable of indexable rows; elements [1] and [2] of
        each row are converted with float() and taken as the confidences of
        classes 1 and 2 (presumably classifier output rows -- TODO confirm)
    :return: dict ``{1: [min, max], 2: [min, max]}``; for empty input both
        ranges are ``[inf, -inf]``
    """
    # Float infinities are neutral for min/max on any float, unlike the
    # Python 2-only sys.maxint, which clamps values larger than it.
    classRanges = {
        1: [float("inf"), float("-inf")],
        2: [float("inf"), float("-inf")],
    }
    for prediction in predictions:
        for cls in (1, 2):
            value = float(prediction[cls])
            bounds = classRanges[cls]
            bounds[0] = min(value, bounds[0])
            bounds[1] = max(value, bounds[1])
    return classRanges
84
class RecallAdjust:
    """Adjust the confidence of one class in the entity predictions of an XML tree."""

    @classmethod
    def run(cls, inFile, multiplier=1.0, outFile=None, targetLabel="neg", binary=False):
        """Adjust the confidence of ``targetLabel`` in all entity predictions.

        :param inFile: file name (.xml or .xml.gz), ElementTree, Element or
            open input stream; loaded with ETUtils.ETFromObj
        :param multiplier: adjustment factor, a real number in (0, inf).
            1.0 does nothing. In multiclass mode < 1.0 decreases and > 1.0
            increases the target label's confidence (see scaleVal). In binary
            mode only values < 1.0 have an effect (see scaleRange). The
            special value -1 disables the adjustment entirely.
        :param outFile: if given, the tree root is written there with
            ETUtils.write
        :param targetLabel: label whose confidence is adjusted
        :param binary: adjust relative to per-label confidence ranges
            (scaleRange) instead of scaling directly (scaleVal)
        :return: the object returned by ETUtils.ETFromObj (an ElementTree or
            an Element), with its entity elements modified in place
        """
        sys.stderr.write("##### Recall adjust with multiplier " + str(multiplier)[:5] + " #####\n")
        tree = ETUtils.ETFromObj(inFile)
        if ET.iselement(tree):
            root = tree
        else:
            assert isinstance(tree, ET.ElementTree)
            root = tree.getroot()

        if multiplier != -1:
            classRanges = None
            adjust = True
            if binary:
                sys.stderr.write("Recall binary mode\n")
                classRanges = getClassRanges(root.getiterator("entity"))
                assert len(classRanges) in [0, 2]
                if len(classRanges) == 0:
                    sys.stderr.write("Warning, recall adjustment skipped because no prediction weights found\n")
                    # Really skip, as the warning states: with an empty range
                    # table adjustEntity would raise KeyError on any remaining
                    # target-label prediction.
                    adjust = False
            else:
                sys.stderr.write("Recall multiclass mode\n")
            if adjust:
                # NOTE(review): getClassRanges ignores isName="True" entities,
                # but they are still adjusted here; in binary mode such an
                # entity whose target label has no range raises KeyError.
                for entityNode in root.getiterator("entity"):
                    adjustEntity(entityNode, targetLabel, multiplier, classRanges)
        if outFile:
            ETUtils.write(root, outFile)
        return tree
if __name__ == "__main__":
    # Command-line entry point: parse options and run the adjustment once.
    desc = "Negative class adjustment in entity predictions. Reads the interaction XML given with -i and writes the result to the file given with -o (if any)."
    parser = OptionParser(description=desc)
    parser.add_option("-i", "--input", default=None, dest="input", help="Predictions in interaction XML", metavar="FILE")
    parser.add_option("-o", "--output", default=None, dest="output", help="Adjusted predictions in interaction XML", metavar="FILE")
    parser.add_option("-l", "--lambda", dest="l", action="store", default=None, type="float", help="The adjustment weight for the negative class. 1.0 does nothing, <1.0 decreases the predictions, >1.0 increases the predictions. No default.")
    parser.add_option("-t", "--targetLabel", dest="targetLabel", action="store", default="neg", help="The label of the class to be adjusted. Defaults to 'neg'.")
    # Exposes RecallAdjust.run's existing binary mode; off by default, so
    # previous invocations behave exactly as before.
    parser.add_option("-b", "--binary", dest="binary", action="store_true", default=False, help="Binary mode: adjust relative to the per-label confidence ranges instead of scaling directly.")

    (options, args) = parser.parse_args()

    if options.l is None:
        sys.stderr.write("You need to give a lambda\n")
        sys.exit(1)

    RecallAdjust.run(options.input, options.l, options.output, options.targetLabel, binary=options.binary)