1 """
2 Trade precision for recall
3 """
4
5 try:
6 import xml.etree.cElementTree as ET
7 except:
8 import cElementTree as ET
9 import sys, os
10 thisPath = os.path.dirname(os.path.abspath(__file__))
11 sys.path.append(os.path.abspath(os.path.join(thisPath,"..")))
12 import Utils.ElementTreeUtils as ETUtils
13 import math
14 from optparse import OptionParser
15 import sys
19 if val>=0:
20 val*=boost
21 else:
22 diff=abs(val)*boost-abs(val)
23 val+=diff
24 return val
25
27 if boost < 1.0 and val > 0:
28 if val < (1.0-boost) * classRange[1]:
29 return -val - 1
30
31
32
33 return val
34
def adjustEntity(entityNode,targetLabel,multiplier,classRange=None):
    """Adjust the confidence of targetLabel in entityNode by multiplier.

    entityNode only needs attribute-style get()/set() (e.g. an ElementTree
    element). Its "predictions" attribute is read as comma-separated
    "label:confidence" pairs.

    The pair whose label equals targetLabel gets a new confidence:
      - classRange is None: scaleVal(confidence, multiplier)
      - otherwise:          scaleRange(confidence, multiplier, classRange[label])
    All other pairs are written back verbatim.

    Afterwards "predictions" is rewritten and "type" is set to the label with
    the highest (post-adjustment) confidence; on ties the first label wins.
    A node whose "predictions" attribute is missing or empty is left untouched.
    Nothing is returned; the node is modified in place.
    """
    predictions = entityNode.get("predictions")
    if not predictions:
        return

    maxConfidence = None
    maxLabel = None
    adjustedPairs = []
    for labelConfidence in predictions.split(","):
        # Split on the LAST colon only: the confidence is always the final
        # field, so a label that itself contains ':' still parses.
        label, confidence = labelConfidence.rsplit(":", 1)
        confidence = float(confidence)
        if label == targetLabel:
            if classRange is None:
                confidence = scaleVal(confidence, multiplier)
            else:
                confidence = scaleRange(confidence, multiplier, classRange[label])
            # Only the adjusted pair is re-serialized; the rest keep their
            # original text so their formatting is not disturbed.
            labelConfidence = label + ":" + str(confidence)
        adjustedPairs.append(labelConfidence)
        # Strict '<' keeps the first label when confidences are equal
        if maxConfidence is None or maxConfidence < confidence:
            maxConfidence = confidence
            maxLabel = label

    entityNode.set("predictions", ",".join(adjustedPairs))
    entityNode.set("type", maxLabel)
61
63 classRanges = {}
64 for entity in entities:
65 if entity.get("isName") == "True":
66 continue
67 predictions=entity.get("predictions")
68 if predictions:
69 for labelConfidence in predictions.split(","):
70 label,confidence=labelConfidence.split(":")
71 confidence=float(confidence)
72 if not classRanges.has_key(label):
73 classRanges[label] = [sys.maxint,-sys.maxint]
74 classRanges[label] = [min(classRanges[label][0], confidence), max(classRanges[label][1], confidence)]
75 return classRanges
76
78 classRanges = {1:[sys.maxint,-sys.maxint], 2:[sys.maxint,-sys.maxint]}
79 for prediction in predictions:
80 for cls in [1, 2]:
81 classRanges[cls][0] = min(float(prediction[cls]), classRanges[cls][0])
82 classRanges[cls][1] = max(float(prediction[cls]), classRanges[cls][1])
83 return classRanges
84
86
87 @classmethod
88 - def run(cls,inFile,multiplier=1.0,outFile=None,targetLabel="neg", binary=False):
89 """inFile can be a string with file name (.xml or .xml.gz) or an ElementTree or an Element or an open input stream
90 multiplier adjusts the level of boosting the non-negative predictions, it is a real number (0,inf)
91 multiplier 1.0 does nothing, <1.0 decreases negative class confidence, >1.0 increases negative class confidence
92 the root of the modified tree is returned and, if outFile is a string, written out to outFile as well"""
93 print >> sys.stderr, "##### Recall adjust with multiplier " + str(multiplier)[:5] + " #####"
94 tree=ETUtils.ETFromObj(inFile)
95 if not ET.iselement(tree):
96 assert isinstance(tree,ET.ElementTree)
97 root=tree.getroot()
98 else:
99 root = tree
100
101 if multiplier != -1:
102 if binary:
103 print >> sys.stderr, "Recall binary mode"
104 classRanges = getClassRanges(root.getiterator("entity"))
105 assert len(classRanges.keys()) in [0,2]
106 if len(classRanges.keys()) == 0:
107 print >> sys.stderr, "Warning, recall adjustment skipped because no prediction weights found"
108 else:
109 print >> sys.stderr, "Recall multiclass mode"
110 classRanges = None
111 for entityNode in root.getiterator("entity"):
112 adjustEntity(entityNode,targetLabel,multiplier,classRanges)
113 if outFile:
114 ETUtils.write(root,outFile)
115 return tree
116
if __name__=="__main__":
    # Command-line entry point: parse the options and hand them to RecallAdjust.run.
    desc="Negative class adjustment in entity predictions. Reads from stdin, writes to stdout."
    parser = OptionParser(description=desc)
    parser.add_option("-i", "--input", default=None, dest="input", help="Predictions in interaction XML", metavar="FILE")
    parser.add_option("-o", "--output", default=None, dest="output", help="Predictions in interaction XML", metavar="FILE")
    parser.add_option("-l","--lambda",dest="l",action="store",default=None,type="float",help="The adjustment weight for the negative class. 1.0 does nothing, <1.0 decreases the predictions, >1.0 increases the predictions. No default.")
    parser.add_option("-t","--targetLabel",dest="targetLabel",action="store",default="neg",help="The label of the class to be adjusted. Defaults to 'neg'.")

    (options, args) = parser.parse_args()

    # The adjustment weight has no default, so refuse to run without it.
    if options.l is None:
        sys.stderr.write("You need to give a lambda\n")
        sys.exit(1)

    RecallAdjust.run(options.input,options.l,options.output,options.targetLabel)
133