1  """ 
Trade precision for recall by rescaling the confidence of one prediction class (by default "neg") in entity predictions.
  3  """ 
  4   
  5  try: 
  6      import xml.etree.cElementTree as ET 
  7  except: 
  8      import cElementTree as ET 
  9  import sys, os 
 10  thisPath = os.path.dirname(os.path.abspath(__file__)) 
 11  sys.path.append(os.path.abspath(os.path.join(thisPath,".."))) 
 12  import Utils.ElementTreeUtils as ETUtils 
 13  import math 
 14  from optparse import OptionParser 
 15  import sys 
 19      if val>=0:  
 20          val*=boost 
 21      else:  
 22          diff=abs(val)*boost-abs(val) 
 23          val+=diff 
 24      return val 
  25   
 27      if boost < 1.0 and val > 0: 
 28          if val < (1.0-boost) * classRange[1]: 
 29              return -val - 1 
 30       
 31       
 32       
 33      return val 
  34   
 35 -def adjustEntity(entityNode,targetLabel,multiplier,classRange=None): 
  36      """Adjust the confidence of targetLabel in entityNode by multiplier""" 
 37      predictions=entityNode.get("predictions") 
 38      if not predictions:  
 39          return 
 40      maxConfidence=None 
 41      maxLabel=None 
 42      labMod=[]  
 43      for labelConfidence in predictions.split(","): 
 44          label,confidence=labelConfidence.split(":") 
 45          confidence=float(confidence) 
 46          if label!=targetLabel:  
 47              labMod.append(labelConfidence) 
 48          else: 
 49              if classRange == None:  
 50                  confidence=scaleVal(float(confidence),multiplier)  
 51              else:  
 52                  confidence=scaleRange(float(confidence),multiplier, classRange[label])  
 53              labMod.append(label+":"+str(confidence)) 
 54          if maxConfidence==None or maxConfidence<confidence: 
 55              maxConfidence=confidence 
 56              maxLabel=label 
 57   
 58       
 59      entityNode.set("predictions",",".join(labMod)) 
 60      entityNode.set("type",maxLabel) 
  61       
 63      classRanges = {} 
 64      for entity in entities: 
 65          if entity.get("isName") == "True": 
 66              continue 
 67          predictions=entity.get("predictions") 
 68          if predictions: 
 69              for labelConfidence in predictions.split(","): 
 70                  label,confidence=labelConfidence.split(":") 
 71                  confidence=float(confidence) 
 72                  if not classRanges.has_key(label): 
 73                      classRanges[label] = [sys.maxint,-sys.maxint] 
 74                  classRanges[label] = [min(classRanges[label][0], confidence), max(classRanges[label][1], confidence)] 
 75      return classRanges 
  76   
 78      classRanges = {1:[sys.maxint,-sys.maxint], 2:[sys.maxint,-sys.maxint]} 
 79      for prediction in predictions: 
 80          for cls in [1, 2]: 
 81              classRanges[cls][0] = min(float(prediction[cls]), classRanges[cls][0]) 
 82              classRanges[cls][1] = max(float(prediction[cls]), classRanges[cls][1]) 
 83      return classRanges        
  84   
 86   
 87      @classmethod 
 88 -    def run(cls,inFile,multiplier=1.0,outFile=None,targetLabel="neg", binary=False): 
  89          """inFile can be a string with file name (.xml or .xml.gz) or an ElementTree or an Element or an open input stream 
 90          multiplier adjusts the level of boosting the non-negative predictions, it is a real number (0,inf) 
 91          multiplier 1.0 does nothing, <1.0 decreases negative class confidence, >1.0 increases negative class confidence 
 92          the root of the modified tree is returned and, if outFile is a string, written out to outFile as well""" 
 93          print >> sys.stderr, "##### Recall adjust with multiplier " + str(multiplier)[:5] + " #####" 
 94          tree=ETUtils.ETFromObj(inFile) 
 95          if not ET.iselement(tree): 
 96              assert isinstance(tree,ET.ElementTree) 
 97              root=tree.getroot() 
 98          else: 
 99              root = tree 
100           
101          if multiplier != -1: 
102              if binary: 
103                  print >> sys.stderr, "Recall binary mode" 
104                  classRanges = getClassRanges(root.getiterator("entity")) 
105                  assert len(classRanges.keys()) in [0,2] 
106                  if len(classRanges.keys()) == 0: 
107                      print >> sys.stderr, "Warning, recall adjustment skipped because no prediction weights found" 
108              else: 
109                  print >> sys.stderr, "Recall multiclass mode" 
110                  classRanges = None 
111              for entityNode in root.getiterator("entity"): 
112                  adjustEntity(entityNode,targetLabel,multiplier,classRanges) 
113          if outFile: 
114              ETUtils.write(root,outFile) 
115          return tree 
 116   
if __name__=="__main__":
    # Command-line entry point: parse options and run RecallAdjust on the input.
    desc="Negative class adjustment in entity predictions. Reads from stdin, writes to stdout."
    parser = OptionParser(description=desc)
    parser.add_option("-i", "--input", default=None, dest="input", help="Predictions in interaction XML", metavar="FILE")
    parser.add_option("-o", "--output", default=None, dest="output", help="Output file for the adjusted predictions in interaction XML", metavar="FILE")
    parser.add_option("-l","--lambda",dest="l",action="store",default=None,type="float",help="The adjustment weight for the negative class. 1.0 does nothing, <1.0 decreases the predictions, >1.0 increases the predictions. No default.")
    parser.add_option("-t","--targetLabel",dest="targetLabel",action="store",default="neg",help="The label of the class to be adjusted. Defaults to 'neg'.")
    # Exposes RecallAdjust.run's binary mode; off by default, so existing invocations are unchanged.
    parser.add_option("-b","--binary",dest="binary",action="store_true",default=False,help="Binary mode: adjust the target label relative to the per-label confidence ranges found in the input (expects predictions over exactly two labels).")

    (options, args) = parser.parse_args()

    if options.l is None:
        print >> sys.stderr, "You need to give a lambda"
        sys.exit(1)

    RecallAdjust.run(options.input,options.l,options.output,options.targetLabel,binary=options.binary)
133