Package TEES :: Package Core :: Module ExampleUtils
[hide private]

Source Code for Module TEES.Core.ExampleUtils

  1  """ 
  2  Tools for writing and reading classifier example files 
  3   
  4  These functions read and write machine learning example files and convert 
  5  examples into final data forms. The memory representation for each 
  6  example is a 4-tuple (or list) of the format: (id, class, features, extra). id is a string, 
  7  class is an int (-1 or +1 for binary) and features is a dictionary of int:float pairs, where 
  8  the int is the feature id and the float is the feature value. 
  9  Extra is a dictionary of String:String pairs, for additional information about the  
 10  examples. 
 11  """ 
 12   
 13  import sys, os, itertools 
 14  import Split 
 15  import types 
 16  #from IdSet import IdSet 
 17  #thisPath = os.path.dirname(os.path.abspath(__file__)) 
 18  #sys.path.append(os.path.abspath(os.path.join(thisPath,".."))) 
 19  #import Utils.InteractionXML.IDUtils as IDUtils 
 20  #import Utils.Libraries.combine as combine 
 21  import types 
 22  import gzip 
 23  #try: 
 24  #    import xml.etree.cElementTree as ET 
 25  #except ImportError: 
 26  #    import cElementTree as ET 
 27  #import Utils.ElementTreeUtils as ETUtils 
 28  import RecallAdjust 
def gen2iterable(genfunc):
    """
    Makes a multi-use iterator generator. See http://bugs.python.org/issue5973
    for details.

    The decorated function no longer returns a one-shot generator. It returns
    an object whose __iter__ calls genfunc again with the original arguments,
    so the result can be iterated several times and each pass starts from
    the beginning.
    """
    import functools  # local import: only needed while decorating

    # functools.wraps keeps the name and docstring of the wrapped generator
    # function visible on the decorated one.
    @functools.wraps(genfunc)
    def wrapper(*args, **kwargs):
        class _iterable(object):
            def __iter__(self):
                return genfunc(*args, **kwargs)
        return _iterable()
    return wrapper
def isDuplicate(example1, example2):
    """
    Returns True when both examples have the same class (index 1) and the
    same feature dictionary (index 2). The id and the extra data are not
    compared.
    """
    differs = example1[1] != example2[1] or example1[2] != example2[2]
    return not differs
48
def removeDuplicates(examples):
    """
    Removes all but one of the examples that have the same class and
    identical feature vectors.

    The first example of each group of duplicates is kept, and the kept
    examples stay in their original order. The input is not modified; a new
    list is returned.

    Duplicates are found with a hash lookup instead of comparing every pair,
    so the class and the feature ids and values must be hashable.
    """
    seen = set()
    newExamples = []
    for example in examples:
        # Same class + same feature items <=> isDuplicate() would be True
        signature = (example[1], frozenset(example[2].items()))
        if signature not in seen:
            seen.add(signature)
            newExamples.append(example)
    return newExamples
63
def normalizeFeatureVectors(examples):
    """
    Scales the feature values of each example in place so that the absolute
    values sum to one.

    Each feature value is divided by the sum of the absolute values of all
    features of that example. When that sum is zero (no features, or all
    values zero) the divisor is 1.0, so the values only become floats.
    Nothing is returned; the feature dictionaries are modified.
    """
    for example in examples:
        features = example[2]
        total = sum((abs(value) for value in features.values()), 0.0)
        if total == 0.0:
            total = 1.0
        # Only existing keys are reassigned, so iterating the dictionary
        # while updating it is safe.
        for key in features:
            features[key] = float(features[key]) / total
73
def copyExamples(examples):
    """
    Returns a new list with a new 4-item list for every example.

    The feature dictionary (index 2) is copied with dict.copy(). The id,
    the class and the extra dictionary (index 3) are the same objects as in
    the original example.
    """
    return [[item[0], item[1], item[2].copy(), item[3]] for item in examples]
79
def appendExamples(examples, file):
    """
    Writes examples to an open text file object, one example per line.

    Each line has the form
        <class> <featureId>:<value> ... # id:<id> <extraKey>:<extraValue> ...
    with the feature ids in ascending order. Only extra values that are
    strings are written. Examples whose class is None are skipped, and a
    warning with their count is written to stderr after all examples.

    An extra key named "id" triggers an AssertionError, because the id must
    be given as example[0].
    """
    noneClassCount = 0
    for example in examples:
        # None-value as a class indicates a class that did not match an existing id,
        # in situations where new ids cannot be defined, such as predicting. An example
        # with class == None should never get this far, ideally it should be filtered
        # in the ExampleBuilder, but this at least prevents a crash.
        if example[1] is None:
            noneClassCount += 1
            continue
        features = example[2]
        # None-value as a key indicates a feature that did not match an existing id,
        # in situations where new ids cannot be defined, such as predicting. It is
        # dropped before sorting because None cannot be ordered against ints on Python 3.
        keys = sorted(key for key in features if key is not None)
        # Class first, then the features
        parts = [str(example[1])]
        for key in keys:
            parts.append(" " + str(key) + ":" + str(features[key]))
        # Comment area
        parts.append(" # id:" + example[0])
        for extraKey, extraValue in example[3].items():
            assert extraKey != "id"  # id must be defined as example[0]
            if isinstance(extraValue, str):
                parts.append(" " + str(extraKey) + ":" + extraValue)
        parts.append("\n")
        # The line is written only once it is complete
        file.write("".join(parts))
    if noneClassCount != 0:
        sys.stderr.write("Warning, " + str(noneClassCount) + " examples had an undefined class.\n")
111
def appendExamplesBinary(examples, file):
    """
    Writes the feature ids of the examples to an open binary file object.

    For every example one native int holding the number of features is
    written, followed by the feature ids in ascending order, each as a
    native int. Only the ids are stored: the class, the feature values, the
    example id and the extra data are not written.
    """
    import struct
    for example in examples:
        keys = sorted(example[2])
        file.write(struct.pack("1i", len(keys)))
        file.write(struct.pack(str(len(keys)) + "i", *keys))


def writeExamples(examples, filename, commentLines=None):
    """
    Writes examples to a new file with appendExamples.

    filename     -- output path; a name ending in ".gz" is written with gzip
    commentLines -- optional strings, each written before the examples as a
                    line starting with "# "

    The file is closed even if writing fails.
    """
    if filename.endswith(".gz"):
        f = gzip.open(filename, "wt")
    else:
        f = open(filename, "wt")
    try:
        if commentLines is not None:
            for commentLine in commentLines:
                f.write("# " + commentLine + "\n")
        appendExamples(examples, f)
    finally:
        f.close()
139
def writePredictions(predictions, exampleFileName):
    """
    Unfinished: reads an example file and scans its leading comment lines.

    The file is opened for reading only. Opening it for writing, as was done
    before, emptied the example file and made the following readlines() call
    fail. The scan stops at the first line that does not start with "#".
    Lines containing "#commentColumns:" are recognized but nothing is done
    with them yet, and the predictions argument is not used. Returns None.
    """
    if exampleFileName.endswith(".gz"):
        f = gzip.open(exampleFileName, "rt")
    else:
        f = open(exampleFileName, "rt")
    try:
        exampleLines = f.readlines()
    finally:
        f.close()
    for line in exampleLines:
        if line[0] != "#":
            break
        if line.find("#commentColumns:") != -1:
            pass  # placeholder, no handling implemented
152
def getIdsFromFile(filename):
    """
    Returns the comment areas of the example lines of an example file.

    Lines starting with "#" are skipped. For every other line the text after
    the last "#" is stripped of surrounding whitespace and added to the
    result, so a line ending in "# id:d1.s1.x1" yields "id:d1.s1.x1". A line
    without "#" yields the whole stripped line. Files ending in ".gz" are
    read with gzip. The file is closed before returning.
    """
    if filename.endswith(".gz"):
        f = gzip.open(filename, "rt")
    else:
        f = open(filename, "rt")
    ids = []
    try:
        for line in f:
            if line[0] == "#":
                continue
            ids.append(line.rsplit("#", 1)[-1].strip())
    finally:
        f.close()
    return ids
165
@gen2iterable
def readExamples(filename, readFeatures=True):
    """
    Reads examples from an example file, yielding one
    [id, class, features, extra] list per example line.

    filename     -- input path; a name ending in ".gz" is read with gzip
    readFeatures -- when False the feature columns are not parsed and the
                    features dictionary of every example is empty

    Lines starting with "#" are skipped. On the other lines the text before
    the first "#" holds the class (an int) followed by featureId:value
    pairs (int:float). The text after the last "#" holds key:value pairs:
    the "id" pair becomes the example id and all others go into the extra
    dictionary as strings. A line without "#" yields an example with id
    None and no extra data.

    The file is closed when iteration finishes, stops early or fails.
    """
    if filename.endswith(".gz"):
        f = gzip.open(filename, "rt")
    else:
        f = open(filename, "rt")
    try:
        for line in f:
            if line[0] == "#":
                continue
            splits = line.split("#")
            # With no "#" on the line the only split is the data itself,
            # which must not be parsed as a comment area.
            if len(splits) > 1:
                commentSplits = splits[-1].split()
            else:
                commentSplits = []
            exampleId = None
            extra = {}
            for commentSplit in commentSplits:
                # Split on the first ":" only, so values may contain colons
                key, value = commentSplit.split(":", 1)
                if key == "id":
                    exampleId = value
                else:
                    extra[key] = value
            splits2 = splits[0].split()
            classId = int(splits2[0])
            features = {}
            if readFeatures:
                for item in splits2[1:]:
                    featureId, featureValue = item.split(":")
                    features[int(featureId)] = float(featureValue)
            yield [exampleId, classId, features, extra]
    finally:
        f.close()
198
def makeCorpusDivision(corpusElements, fraction=0.5, seed=0):
    """
    Divides the documents of a corpus by passing the keys of
    corpusElements.documentsById to makeDivision together with fraction
    and seed. Returns the dictionary produced by makeDivision.
    """
    return makeDivision(corpusElements.documentsById.keys(), fraction, seed)
202
def makeCorpusFolds(corpusElements, folds=10):
    """
    Assigns the documents of a corpus to folds by passing the keys of
    corpusElements.documentsById to makeFolds. Returns the dictionary
    produced by makeFolds.
    """
    return makeFolds(corpusElements.documentsById.keys(), folds)
206
def makeExampleDivision(examples, fraction=0.5, seed=0):
    """
    Divides the documents that the examples come from into sets.

    examples -- examples whose id (index 0) is a dot-separated string; the
                document id is the id without its last two dot-separated
                parts
    fraction -- passed on to makeDivision
    seed     -- passed on to makeDivision; the default 0 is the value
                makeDivision used when no seed could be given here

    Returns the dictionary produced by makeDivision, keyed by document id.
    """
    documentIds = set()
    for example in examples:
        documentIds.add(example[0].rsplit(".", 2)[0])
    return makeDivision(list(documentIds), fraction, seed)
213
def makeExampleFolds(examples, folds=10):
    """
    Assigns the documents that the examples come from to folds.

    The document id of an example is its id (index 0) without the last two
    dot-separated parts. The distinct document ids are passed to makeFolds
    as a list, and the dictionary it produces is returned.
    """
    uniqueDocumentIds = set(item[0].rsplit(".", 2)[0] for item in examples)
    return makeFolds(list(uniqueDocumentIds), folds)
220
def makeDivision(ids, fraction=0.5, seed=0):
    """
    Maps every id to the entry at the same position in the sample returned
    by Split.getSample(len(ids), fraction, seed). Returns the mapping as a
    dictionary.
    """
    sample = Split.getSample(len(ids), fraction, seed)
    return dict((itemId, sample[position]) for position, itemId in enumerate(ids))
227
def makeFolds(ids, folds=10):
    """
    Maps every id to the entry at the same position in the sample returned
    by Split.getFolds(len(ids), folds). Returns the mapping as a dictionary.
    """
    sample = Split.getFolds(len(ids), folds)
    return dict((itemId, sample[position]) for position, itemId in enumerate(ids))
234
def divideExamples(examples, division=None):
    """
    Groups examples into sets according to a division of their documents.

    examples -- examples whose id (index 0) is a dot-separated string; the
                document id is the id without its last two dot-separated
                parts
    division -- dictionary from document id to set key; when None, one is
                made with makeExampleDivision(examples)

    Returns a dictionary from set key to the list of examples of that set,
    in their original order. Examples whose document id is not in the
    division are left out.
    """
    if division is None:
        division = makeExampleDivision(examples)

    exampleSets = {}
    for example in examples:
        documentId = example[0].rsplit(".", 2)[0]
        if documentId in division:
            exampleSets.setdefault(division[documentId], []).append(example)
    return exampleSets
247
def divideExampleFile(exampleFileName, division, outputDir):
    """
    Splits the example lines of an example file into one file per set.

    exampleFileName -- input path; a name ending in ".gz" is read with gzip
    division        -- dictionary from document id to set key
    outputDir       -- directory for the output files, which are named
                       "set" followed by the set key

    Lines starting with "#" are not copied. For the other lines the stripped
    text after the last "#" is taken as the example id, and the document id
    is that text without its last two dot-separated parts. A document id
    missing from the division raises KeyError. All opened files are closed
    even when an error occurs.
    """
    if exampleFileName.endswith(".gz"):
        f = gzip.open(exampleFileName, "rt")
    else:
        f = open(exampleFileName, "rt")
    try:
        lines = f.readlines()
    finally:
        f.close()

    divisionFiles = {}
    try:
        for line in lines:
            if line[0] == "#":
                continue
            # NOTE(review): the whole comment area is used as the id, so for
            # lines written as "# id:<id>" the document id keeps the "id:"
            # prefix -- confirm that the division keys of callers match.
            exampleId = line.split("#")[-1].strip()
            documentId = exampleId.rsplit(".", 2)[0]
            setKey = division[documentId]
            if setKey not in divisionFiles:
                divisionFiles[setKey] = open(outputDir + "/set" + str(setKey), "wt")
            divisionFiles[setKey].write(line)
    finally:
        for outFile in divisionFiles.values():
            outFile.close()
267
@gen2iterable
def loadPredictions(predictionsFile, recallAdjust=None, classRanges=None, threshold=None):
    """
    Reads classifier predictions from a file, yielding one list per line.

    predictionsFile -- input path; a name ending in ".gz" is read with gzip
    recallAdjust    -- when not None and not 1.0, the strength in column 1
                       is rescaled with RecallAdjust and the predicted class
                       is then set to the index of the largest strength
    classRanges     -- class ranges for RecallAdjust.scaleRange; when None
                       and a line has exactly three columns, they are first
                       computed from the whole file, which is then read again
    threshold       -- when not None, the class becomes 1 if the strength in
                       column 1 exceeds the threshold, otherwise the index of
                       the largest strength from column 2 onwards

    A line with one column yields [float] (binary classification;
    recallAdjust must be None or 1.0 here). Other lines yield
    [class, strength1, strength2, ...]; the class is an int, or a list of
    ints when the first column contains commas (multilabel). Strengths given
    as "N/A" stay strings.

    The file is closed when iteration finishes, stops early or fails.
    """
    if predictionsFile.endswith(".gz"):
        f = gzip.open(predictionsFile, "rt")
    else:
        f = open(predictionsFile, "rt")
    try:
        for line in f:
            splits = line.split()
            if len(splits) == 1:  # true binary
                assert recallAdjust is None or recallAdjust == 1.0  # not implemented for binary classification
                yield [float(splits[0])]
            elif len(splits) == 3 and (recallAdjust is not None and recallAdjust != 1.0) and classRanges is None:  # SVM multiclass two class "binary" classification
                # Go through all the predictions to get the ranges
                predictions = [splits]
                for line in f:
                    predictions.append(line.split())
                f.close()  # end first iteration
                classRanges = RecallAdjust.getClassRangesFromPredictions(predictions)
                # Load predictions again with the range information. The
                # threshold is passed on so that it is applied on this path too.
                for yieldedValue in loadPredictions(predictionsFile, recallAdjust, classRanges, threshold):
                    yield yieldedValue
                break
            else:  # multiclass
                if "," in splits[0]:  # multilabel
                    pred = [[]]
                    for value in splits[0].split(","):
                        pred[0].append(int(value))
                else:
                    pred = [int(splits[0])]
                for split in splits[1:]:
                    if split != "N/A":
                        split = float(split)
                    pred.append(split)
                # Recall adjust
                if recallAdjust is not None and recallAdjust != 1.0:
                    if classRanges is None:
                        pred[1] = RecallAdjust.scaleVal(pred[1], recallAdjust)
                    else:  # SVM multiclass two class "binary" classification
                        pred[1] = RecallAdjust.scaleRange(pred[1], recallAdjust, classRanges[1])
                    # Column 1 changed, so pick the strongest class again
                    maxStrength = pred[1]
                    pred[0] = 1
                    for i in range(2, len(pred)):
                        if pred[i] > maxStrength:
                            maxStrength = pred[i]
                            pred[0] = i
                # Thresholding
                if threshold is not None:
                    if pred[1] > threshold:
                        pred[0] = 1
                    else:
                        maxStrength = pred[2]
                        pred[0] = 2
                        for i in range(2, len(pred)):
                            if pred[i] > maxStrength:
                                maxStrength = pred[i]
                                pred[0] = i
                # Return the prediction
                yield pred
    finally:
        f.close()
356