# trecgen2007_score.py
# Scoring program for the NIST TREC 2007 Genomics Track
# Includes the new passage measure PASSAGE2, which treats each individual character as a document
# when computing average precision. The older span-based passage measure is PASSAGE.
# Calculation of PASSAGE2 has been independently verified with parallel implementation by Fabien Campagne.
# Usage: python trecgen2007_score.py [path-to-gold-standard-file] [glob-or-path-to-run-file] > STDOUT
import sys
import glob
import os.path

# turns debugging output on and off...
DEBUG_FLAG = False

# utility function for debugging: writes to stderr only when DEBUG_FLAG is set...
def debug(s):
	"""Write *s* followed by a newline to stderr when DEBUG_FLAG is True; otherwise do nothing."""
	if not DEBUG_FLAG:
		return
	sys.stderr.write("%s\n" % s)
		
# function to remove duplicate elements from a list 
# this function must preserve order!
def removeDups(s):
	"""Return a new list holding the items of *s* with repeats dropped.

	Order is preserved: each item stays at the position of its first occurrence.
	Items must be hashable.
	"""
	seen = set()
	unique = []
	for item in s:
		if item in seen:
			continue
		seen.add(item)
		unique.append(item)
	return unique

# function to read the run file and build the submission data structure.
# this is a dictionary of topics, each entry consisting of a list of ranked nominated passages...
def buildSubmissions(file):
	"""Parse a run file into {topic: [(rank, pmid, offset, length), ...]}.

	Each non-blank line needs at least six whitespace-separated fields: topic,
	PMID, rank, an ignored fourth field, offset and length. Topic, rank, offset
	and length must be integers; the PMID is kept as a string. Entries keep file
	order; the scoring functions sort them later.

	Raises IndexError (too few fields) or ValueError (non-integer field) on a
	malformed line, after writing the line and its parsed fields to stderr.
	"""
	submissionDx = {}
	for line in file:
		# strip newline...
		line = line.strip()
		# skip blank lines...
		if not line:
			continue
		# split fields by whitespace, should split by tabs but so many entrants did not
		# follow the official format that splitting by whitespace is more reliable...
		fields = line.split()
		# only the parsing can fail: IndexError means too few fields, ValueError a non-integer...
		try:
			topic	=	int(fields[0])
			pmid	=	fields[1]
			rank	=	int(fields[2])
			offset =	int(fields[4])
			length =	int(fields[5])
		except (IndexError, ValueError):
			# report and terminate; a bare 'raise' keeps the original traceback
			# (unlike 'raise e' on Python 2) and parses on both Python 2 and 3...
			sys.stderr.write("Unable to correct improperly formated line:\n%s\n" % line)
			sys.stderr.write("Parsed fields are %s\n" % str(fields))
			raise
		submissionDx.setdefault(topic,[]).append((rank, pmid, offset, length))
	return submissionDx

# function to read the gold standard file and build the gold standard data structure.
# this is a dictionary of topics, each entry consisting of a dictionary, keyed by PMID,
# where each PMID is associated with a list of gold standard passage triples that are
# (offset,length,aspects), where offset and length are integers and aspects is a list of MeSH terms...
def buildGoldStd(file):
	"""Parse a tab-separated gold standard file into {topic: {pmid: [(offset, length, aspects), ...]}}.

	Columns are topic, PMID, offset, length and a '|'-separated aspect list; each
	column and each aspect is whitespace-stripped. Lines with fewer than five
	tab-separated columns are skipped. Passages are appended in file order.
	"""
	goldStdDx = {}
	for rawLine in file:
		cols = [col.strip() for col in rawLine.split('\t')]
		if len(cols) <= 4:
			# too few columns to be a gold passage...
			continue
		topic = int(cols[0])
		pmid = cols[1]
		offset = int(cols[2])
		length = int(cols[3])
		aspects = [aspect.strip() for aspect in cols[4].split('|')]
		goldStdDx.setdefault(topic, {}).setdefault(pmid, []).append((offset, length, aspects))
	return goldStdDx

# function to find unique pmids for all topicids in a dictionary,
# it is important that this function preserves rank order!
def buildUniquePmidsByTopic(submissionDx):
	"""Map each topic to its nominated PMIDs, de-duplicated, in rank order.

	Side effect: every list in submissionDx is sorted in place by
	(rank, pmid, offset, length). The first occurrence of a PMID fixes its position.
	"""
	pmidsByTopic = {}
	for topic, quartets in submissionDx.items():
		quartets.sort()
		pmidsByTopic[topic] = removeDups([quartet[1] for quartet in quartets])
	return pmidsByTopic

# return a dictionary where the keys are topics and the values are the average precision for that topic...
def calculateDocAveragePrecision(submissionDx, goldStdDx):
	"""Return {topic: DOCUMENT average precision} for every topic in goldStdDx.

	Nominated PMIDs are de-duplicated in rank order; this sorts submissionDx lists
	in place via buildUniquePmidsByTopic. Precision is accumulated at each PMID that
	has gold passages for the topic. The sum is divided by the number of such gold
	PMIDs, so relevant PMIDs that were never retrieved count as zero.
	"""
	rankedPmids = buildUniquePmidsByTopic(submissionDx)
	apByTopic = {}
	for topic in goldStdDx:
		relevantPmids = goldStdDx[topic]
		hits = 0
		precisionSum = 0.0
		for rank, pmid in enumerate(rankedPmids.get(topic, []), 1):
			if pmid not in relevantPmids:
				debug("DOCUMENT TOPIC:%d RANK:%d PMID:%s => not relevant" % (topic, rank, pmid))
				continue
			hits += 1
			# precision at this point of recall...
			precision = float(hits)/float(rank)
			precisionSum += precision
			debug("DOCUMENT TOPIC:%d RANK:%d PMID:%s => relevant, %d/%d = %0.8f" % (topic, rank, pmid, hits, rank, precision))
		apByTopic[topic] = precisionSum/len(relevantPmids)
	return apByTopic

def findRelevantCharsByTopic(quartet, topic):
	"""Find the first gold passage that a nominated passage overlaps.

	Reads the module-level goldStdDx directly (it is not a parameter) and raises
	KeyError if *topic* is not in it.

	quartet -- (rank, pmid, offset, length) for one nominated passage.

	Spans are treated as [offset, offset + length). The gold passages of the same
	PMID are checked in stored order, and the first one giving a non-zero overlap
	wins. The return value is (relevantOffset, relevantLength, relevantAspects, goldPassage):
	  - relevantOffset and relevantLength give the start and size of that overlap;
	  - relevantAspects is that gold passage's aspect list;
	  - goldPassage is the (pmid, offset, length) key of that gold passage.
	If nothing overlaps, relevantLength is 0 and relevantAspects and goldPassage are
	None. relevantOffset is then not meaningful: a merely touching gold passage can
	leave it set.
	"""
	relevantOffset	=	0
	relevantLength =	0
	relevantAspects = None
	goldPassage = None
	(rank, pmid, runOffset, runLength) = quartet
	runEnd = runOffset + runLength
	goldTriplets = goldStdDx[topic].get(pmid, [])
	for goldTriplet in goldTriplets:
		(goldOffset, goldLength, goldAspects) = goldTriplet
		goldEnd = goldOffset + goldLength
		if runOffset >= goldOffset and runEnd <= goldEnd:
			# run is fully contained in gold passage...
			relevantOffset = runOffset
			relevantLength = runLength
		elif runOffset < goldOffset and runEnd <= goldEnd and runEnd >= goldOffset:
			# run starts before the gold and ends within the gold...
			relevantOffset = goldOffset
			relevantLength =  runEnd - goldOffset
		elif  runOffset >= goldOffset and runOffset <= goldEnd and runEnd > goldEnd:
			# run starts within the gold and ends beyond the gold...
			relevantOffset = runOffset
			relevantLength =  goldEnd - runOffset
		elif runOffset < goldOffset and runEnd > goldEnd:
			# run starts before the gold and ends after the gold...
			relevantOffset = goldOffset
			relevantLength = goldLength
		# touching spans (runEnd == goldOffset or runOffset == goldEnd) give a zero
		# length here and the search continues with the next gold passage...
		if relevantLength != 0:
			# NOTE: a run spanning several gold passages is credited only with its
			# overlap of the first one found...
			relevantAspects = goldAspects
			goldPassage = (pmid, goldOffset, goldLength)
			break
	return (relevantOffset, relevantLength, relevantAspects, goldPassage)

def checkSpanOverlap(span, relevantSpans):
	"""Return how many characters of *span* overlap a span in *relevantSpans*.

	span and every entry of relevantSpans are (pmid, offset, length), treated as
	[offset, offset + length). Only entries with the same PMID are considered. The
	overlap with the first such entry giving a non-zero value is returned, or 0 if
	there is none. Spans that merely touch overlap by 0.
	"""
	(pmid, runOffset, runLength) = span
	runEnd = runOffset + runLength
	for (spanPmid, spanOffset, spanLength) in relevantSpans:
		if spanPmid != pmid:
			continue
		spanEnd = spanOffset + spanLength
		if runOffset >= spanOffset:
			if runEnd <= spanEnd:
				# run lies inside the span...
				overlap = runLength
			elif runOffset <= spanEnd:
				# run starts inside the span and ends beyond it...
				overlap = spanEnd - runOffset
			else:
				# run starts after the span ends...
				overlap = 0
		elif runEnd > spanEnd:
			# run starts before the span and ends after it...
			overlap = spanLength
		elif runEnd >= spanOffset:
			# run starts before the span and ends inside it...
			overlap = runEnd - spanOffset
		else:
			# run ends before the span starts...
			overlap = 0
		if overlap != 0:
			return overlap
	return 0

# return a dictionary where the keys are topics and the values are the passage average precision for that topic...
def calculatePassageAveragePrecision(submissionDx, goldStdDx):
	"""Return {topic: PASSAGE average precision} for every topic in goldStdDx.

	Nominated passages are walked in (rank, pmid, offset, length) order; each
	topic's list in submissionDx is sorted in place. Every passage adds its full
	length to the count of retrieved characters. A passage overlapping a gold
	passage (per findRelevantCharsByTopic) adds the overlap to the count of relevant
	characters, and the current relevant/retrieved ratio is accumulated.

	The sum is divided by the number of relevant passages plus the number of gold
	passages never hit, so missed gold passages count as zeros.
	"""
	apByTopic = {}
	for topic in goldStdDx:
		ranked = submissionDx.get(topic, [])
		ranked.sort()
		charsRelevant = 0
		charsRetrieved = 0
		precisionSum = 0.0
		hits = 0
		hitGoldPassages = set()
		for quartet in ranked:
			(relevantOffset, relevantChars, relevantAspects, goldPassage) = findRelevantCharsByTopic(quartet, topic)
			(rank, pmid, runOffset, runLength) = quartet
			charsRetrieved += runLength
			if relevantChars <= 0:
				debug("PASSAGE TOPIC:%d SPAN:%s.%d.%d => not relevant, %d chars" % (topic, pmid, runOffset, runLength, runLength))
				continue
			# NOTE: gold characters retrieved more than once are credited each time;
			# ideally repeats would be excluded, but none of the top rated systems
			# did this, and the effect on scoring appears minor...
			hits += 1
			charsRelevant += relevantChars
			precisionSum += float(charsRelevant)/float(charsRetrieved)
			hitGoldPassages.add(goldPassage)
			debug("PASSAGE TOPIC:%d SPAN:%s.%d.%d => relevant, %d of %d chars" % (topic, pmid, runOffset, runLength, relevantChars, runLength))
		# every gold passage of the topic, keyed like goldPassage: (pmid, offset, length)...
		allGoldPassages = set()
		for goldPmid, triplets in goldStdDx[topic].items():
			for (offset, length, aspects) in triplets:
				allGoldPassages.add((goldPmid, offset, length))
		missed = len(allGoldPassages - hitGoldPassages)
		apByTopic[topic] = precisionSum/(hits + missed)
		debug("PASSAGE TOPIC:%d RETRIEVED %d of %d RELEVANT PASSAGES" % (topic, hits, hits + missed))
	return apByTopic

def findUniqueAspectsByTopic(topic):
	"""Return the set of distinct aspects over all gold passages of *topic*.

	Reads the module-level goldStdDx; raises KeyError if *topic* is absent.
	"""
	found = set()
	for triplets in goldStdDx[topic].values():
		for (_offset, _length, passageAspects) in triplets:
			found.update(passageAspects)
	return found

# return a dictionary where the keys are topics and the values are the aspect average precision for that topic...
def calculateAspectAveragePrecision(submissionDx,goldStdDx):
	"""Return {topic: ASPECT average precision} for every topic in goldStdDx.

	Nominated passages are walked in sorted order; each topic's list in
	submissionDx is sorted in place. Each passage falls into one of three cases:
	  - non-relevant (no overlapping gold passage): takes one rank position;
	  - relevant, bringing at least one aspect not seen before: takes one position,
	    counts as a hit and adds (number of new aspects) * hits / positions;
	  - relevant, with only already-seen aspects: takes no position.

	The sum is divided by the number of distinct aspects of the topic, so aspects
	never found count as zero. The aspect set comes from findUniqueAspectsByTopic,
	which reads the module-level goldStdDx rather than this parameter.
	"""
	apByTopic = {}
	for topic in goldStdDx:
		seenAspects = set()
		topicAspects = findUniqueAspectsByTopic(topic)
		ranked = submissionDx.get(topic, [])
		ranked.sort()
		hits = 0
		positions = 0
		precisionSum = 0.0
		debug("ASPECT FOR TOPIC:%d = %s" % (topic, str(topicAspects)))
		for quartet in ranked:
			aspects = findRelevantCharsByTopic(quartet, topic)[2]
			if not aspects:
				# non-relevant passage: costs a rank position, contributes nothing...
				positions += 1
				continue
			freshCount = len(set(aspects).difference(seenAspects))
			if freshCount == 0:
				# only previously seen aspects: the position is not counted...
				debug("ASPECT TOPIC:%d NO NEW ASPECTS IN %s" % (topic, str(aspects)))
			else:
				# all new aspects at this rank share one numerator and denominator...
				hits += 1
				positions += 1
				precisionSum += freshCount*float(hits)/float(positions)
				debug("ASPECT TOPIC:%d NEW ASPECTS IN %s, %d of %d times %d" % (topic, str(aspects), hits, positions, freshCount))
			seenAspects.update(aspects)
		missedAspects = set(topicAspects).difference(seenAspects)
		debug("ASPECT TOPIC:%d DID NOT RETRIEVE %d ASPECTS, MISSED %s" % (topic, len(topicAspects) - len(seenAspects), str(missedAspects)))
		apByTopic[topic] = precisionSum/len(topicAspects)
	return apByTopic

# returns True or False depending on whether passage contains any relevant characters or not...
def checkForAnyRelevantChars(topic, pmid, runOffset, runLength):
	"""Cheaply test whether a nominated passage reaches any gold passage.

	Reads the module-level goldStdDx and raises KeyError if *topic* is absent.
	Returns False when *pmid* has no gold passages for *topic*. Otherwise returns
	True as soon as the run [runOffset, runOffset + runLength) reaches one of that
	PMID's gold passages. Spans that merely touch also return True, so this is a
	conservative pre-filter: a True result may still yield no relevant characters.
	"""
	goldForTopic = goldStdDx[topic]
	if pmid not in goldForTopic:
		return False
	runEnd = runOffset + runLength
	for (offset, length, aspects) in goldForTopic[pmid]:
		end = offset + length
		if runOffset >= offset:
			# run starts at/after the gold start: it reaches the gold if it ends
			# inside it or starts no later than the gold end...
			if runEnd <= end or runOffset <= end:
				return True
		elif runEnd > end or offset <= runEnd <= end:
			# run starts before the gold and ends inside or beyond it...
			return True
	return False
	
# trackRelevantChars() has three possible return values:
# 	+1 means character was relevant and not previously seen
# 	-1 means character was not-relevant
# 	0 means character was relevant but was previously seen
def trackRelevantChars(topic, pmid, offset, topicGoldStdDx, trackingGoldStdDx):
	"""Classify one retrieved character and mark it as seen.

	topicGoldStdDx maps pmid -> set of relevant character offsets.
	trackingGoldStdDx maps pmid -> set of relevant offsets not yet retrieved; a
	newly retrieved relevant offset is removed from it. *topic* is not used.

	Returns +1 for a relevant character seen for the first time, 0 for a relevant
	character already seen, and -1 for a non-relevant character.
	"""
	isRelevant = pmid in topicGoldStdDx and offset in topicGoldStdDx[pmid]
	if not isRelevant:
		return -1
	unseen = trackingGoldStdDx[pmid]
	if offset not in unseen:
		# relevant, but credited on an earlier retrieval...
		return 0
	unseen.remove(offset)
	return +1
		
# return a dictionary where the keys are topics and the values are the character-level passage (PASSAGE2) average precision for that topic...
def calculatePassage2AveragePrecision(submissionDx, goldStdDx):
	"""Return {topic: PASSAGE2 (character-level) average precision} for every gold topic.

	Every character of every nominated passage is a ranked item, taken in sorted
	passage order; each topic's list in submissionDx is sorted in place.
	  - A non-relevant character adds to the retrieved count, every time it is retrieved.
	  - A relevant character seen for the first time adds to both counts and
	    accumulates precision.
	  - A relevant character retrieved again is ignored.

	The sum is divided by the number of distinct relevant characters of the topic.
	Raises ZeroDivisionError if the topic's gold passages cover no characters.
	"""
	apByTopic = {}
	for topic in goldStdDx:
		# expand gold passages into sets of relevant character offsets per PMID;
		# unseenChars shrinks as relevant characters are retrieved...
		goldChars = {}
		unseenChars = {}
		for pmid in goldStdDx[topic]:
			for (offset, length, aspects) in goldStdDx[topic][pmid]:
				positions = range(offset, offset + length)
				goldChars.setdefault(pmid, set()).update(positions)
				unseenChars.setdefault(pmid, set()).update(positions)
		relevantRetrieved = 0
		totalRetrieved = 0
		precisionSum = 0.0
		ranked = submissionDx.get(topic, [])
		ranked.sort()
		for (rank, pmid, runOffset, runLength) in ranked:
			if checkForAnyRelevantChars(topic, pmid, runOffset, runLength):
				# walk the passage one character at a time (slow, hence the pre-filter)...
				for offset in range(runOffset, runOffset + runLength):
					verdict = trackRelevantChars(topic, pmid, offset, goldChars, unseenChars)
					if verdict == 0:
						# relevant but already credited: not counted at all...
						continue
					totalRetrieved += 1
					if verdict == 1:
						relevantRetrieved += 1
						precisionSum += float(relevantRetrieved)/float(totalRetrieved)
			else:
				# pre-filter says nothing here is relevant: every character just costs a position...
				totalRetrieved += runLength
		# distinct relevant characters; unretrieved ones thereby count as zeros...
		relevantTotal = sum(len(chars) for chars in goldChars.values())
		apByTopic[topic] = precisionSum/relevantTotal
	return apByTopic

# beginning of main program...
# NOTE: goldStdDx must stay a module-level global: findRelevantCharsByTopic,
# findUniqueAspectsByTopic and checkForAnyRelevantChars read it directly...

# check arguments...
if len(sys.argv) != 3:
	sys.stderr.write("usage:python %s [path-to-gold-standard-file] [glob-or-path-to-run-file] > STDOUT\n" % sys.argv[0])
	# a usage error is a failure: exit non-zero so calling scripts can detect it...
	sys.exit(1)

# save paths...
goldStandardFile = sys.argv[1]
submissionFileGlob = sys.argv[2]

# load gold standard data; 'with' closes the file even if parsing raises...
with open(goldStandardFile, 'r') as goldFile:
	goldStdDx = buildGoldStd(goldFile)
if not goldStdDx:
	# every MAP below divides by the number of gold topics...
	sys.stderr.write("No gold standard entries found in %s\n" % goldStandardFile)
	sys.exit(1)

# glob order is arbitrary; sort so runs are always reported in the same order...
submissionFiles = sorted(glob.glob(submissionFileGlob))
if not submissionFiles:
	# a misspelled path must not look like a successful (empty) scoring run...
	sys.stderr.write("No run files match %s\n" % submissionFileGlob)
	sys.exit(1)

# topics are reported in ascending order for every run and measure...
topics = sorted(goldStdDx.keys())

# output result header...
print('\t'.join(["RUN", "MEASURE", "TOPIC", "SCORE"]))

# loop over submission files, which can be a multifile-glob...
for submissionFile in submissionFiles:
	# save run name as shortname of submission file (text before the first '.')...
	runName = (os.path.split(submissionFile)[1]).split('.')[0]
	# progress report...
	sys.stderr.write("Processing run %s..." % runName)
	# load submission data...
	with open(submissionFile, 'r') as runFile:
		submissionDx = buildSubmissions(runFile)
	# compute topic scores as (measure name, {topic: average precision}), in output order...
	measures = [
		("DOCUMENT", calculateDocAveragePrecision(submissionDx, goldStdDx)),
		("PASSAGE", calculatePassageAveragePrecision(submissionDx, goldStdDx)),
		("ASPECT", calculateAspectAveragePrecision(submissionDx, goldStdDx)),
		("PASSAGE2", calculatePassage2AveragePrecision(submissionDx, goldStdDx)),
	]
	# output per-topic scores, then the mean over all gold topics (MAP)...
	for (measureName, apByTopic) in measures:
		for topic in topics:
			print('\t'.join([runName, measureName, str(topic), "%0.8f" % apByTopic[topic]]))
		meanAP = sum(apByTopic.values())/len(apByTopic)
		print('\t'.join([runName, measureName, "MAP", "%0.8f" % meanAP]))
	# progress report...
	sys.stderr.write("OK.\n")

