for txt_file in ground_truth_files_list:  # iterate over the ground-truth files
    file_id = txt_file.split(".txt", 1)[0]
    file_id = os.path.basename(os.path.normpath(file_id))
    # every ground-truth file must have a matching detection-results file
    temp_path = os.path.join(DR_PATH, (file_id + ".txt"))
    if not os.path.exists(temp_path):
        error_msg = "Error. File not found: {}\n".format(temp_path)
        error(error_msg)
    lines_list = file_lines_to_list(txt_file)  # read the file line by line
    bounding_boxes = []
    is_difficult = False
    already_seen_classes = []
    for line in lines_list:
        try:
            if "difficult" in line:
                class_name, left, top, right, bottom, _difficult = line.split()
                is_difficult = True
            else:
                class_name, left, top, right, bottom = line.split()
        except:
            # fallback for class names that contain spaces: parse the numbers from the right
            if "difficult" in line:
                line_split = line.split()
                _difficult = line_split[-1]
                bottom = line_split[-2]
                right = line_split[-3]
                top = line_split[-4]
                left = line_split[-5]
                class_name = ""
                for name in line_split[:-5]:
                    class_name += name + " "
                class_name = class_name[:-1]
                is_difficult = True
            else:
                line_split = line.split()
                bottom = line_split[-1]
                right = line_split[-2]
                top = line_split[-3]
                left = line_split[-4]
                class_name = ""
                for name in line_split[:-4]:
                    class_name += name + " "
                class_name = class_name[:-1]
        bbox = left + " " + top + " " + right + " " + bottom
        if is_difficult:
            bounding_boxes.append({"class_name": class_name, "bbox": bbox, "used": False, "difficult": True})
            is_difficult = False
        else:
            bounding_boxes.append({"class_name": class_name, "bbox": bbox, "used": False})
            # count ground-truth boxes per class
            if class_name in gt_counter_per_class:
                gt_counter_per_class[class_name] += 1
            else:
                gt_counter_per_class[class_name] = 1
            # count, per class, how many images contain at least one box of that class
            if class_name not in already_seen_classes:
                if class_name in counter_images_per_class:
                    counter_images_per_class[class_name] += 1
                else:
                    counter_images_per_class[class_name] = 1
                already_seen_classes.append(class_name)
    # dump the parsed boxes of this image to a temporary JSON file
    with open(TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json", 'w') as outfile:
        json.dump(bounding_boxes, outfile)

gt_classes = list(gt_counter_per_class.keys())
gt_classes = sorted(gt_classes)
n_classes = len(gt_classes)
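For reference, each line of a ground-truth .txt file has the form class_name left top right bottom, optionally followed by the literal word difficult; the except branch above handles class names that contain spaces by parsing the numbers from the right. Below is a minimal, standalone sketch of that parsing, with two made-up sample lines for illustration:

import json

sample_lines = [
    "car 100 120 200 260",
    "traffic light 50 30 80 90 difficult",   # multi-word class name plus difficult flag
]

bounding_boxes = []
for line in sample_lines:
    parts = line.split()
    if parts[-1] == "difficult":
        left, top, right, bottom = parts[-5:-1]
        class_name = " ".join(parts[:-5])
        bounding_boxes.append({"class_name": class_name,
                               "bbox": " ".join([left, top, right, bottom]),
                               "used": False, "difficult": True})
    else:
        left, top, right, bottom = parts[-4:]
        class_name = " ".join(parts[:-4])
        bounding_boxes.append({"class_name": class_name,
                               "bbox": " ".join([left, top, right, bottom]),
                               "used": False})

print(json.dumps(bounding_boxes, indent=2))  # same structure as the dumped _ground_truth.json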
Reading the detection-results files
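The detection-results files follow the companion format class_name confidence left top right bottom. For each class, the script gathers every detection of that class across all files, sorts them by confidence in descending order, and dumps them to <class_name>_dr.json in the temp folder. The following is a simplified sketch of that step, assuming single-word class names; DR_PATH and TEMP_FILES_PATH are illustrative stand-ins for the script's own paths, and gt_classes comes from the ground-truth pass above:

import glob
import json
import os

DR_PATH = "input/detection-results"   # illustrative path
TEMP_FILES_PATH = ".temp_files"       # illustrative path

for class_name in gt_classes:
    bounding_boxes = []
    for txt_file in glob.glob(os.path.join(DR_PATH, "*.txt")):
        file_id = os.path.splitext(os.path.basename(txt_file))[0]
        with open(txt_file) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                # expected per-line format: class_name confidence left top right bottom
                tmp_class_name, confidence, left, top, right, bottom = line.split()
                if tmp_class_name == class_name:
                    bounding_boxes.append({"confidence": confidence,
                                           "file_id": file_id,
                                           "bbox": " ".join([left, top, right, bottom])})
    # highest confidence first, so TP/FP are later assigned in confidence order
    bounding_boxes.sort(key=lambda x: float(x["confidence"]), reverse=True)
    with open(os.path.join(TEMP_FILES_PATH, class_name + "_dr.json"), "w") as outfile:
        json.dump(bounding_boxes, outfile)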
Let's walk through how this evaluation runs:
Outermost loop: for class_index, class_name in enumerate(gt_classes), which iterates over the classes.
Second loop: for idx, detection in enumerate(dr_data), which iterates over the detections of that class; each detection records which detection file (image) it came from, e.g. 1.txt, 2.txt.
Third loop: for obj in ground_truth_data, which takes the ground-truth boxes of that image one by one, compares each with the detection box by IoU, and keeps the match with the largest IoU (restricted to the current class within that image).
After the third loop finishes, the detection is judged to be a TP or, failing that, an FP.
Once the second loop completes, the TP and FP counts are summarized and Precision and Recall are computed.
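A condensed sketch of those three loops is shown below. It reuses the temporary JSON files written earlier, inlines the IoU computation, and leaves out the handling of "difficult" boxes and the write-back of the updated "used" flags, so it illustrates the matching logic rather than reproducing the script's exact code:

import json
import os

MINOVERLAP = 0.5  # the usual VOC threshold; the real script exposes it as a parameter

for class_index, class_name in enumerate(gt_classes):               # 1st loop: classes
    with open(os.path.join(TEMP_FILES_PATH, class_name + "_dr.json")) as f:
        dr_data = json.load(f)                                       # detections, sorted by confidence
    tp = [0] * len(dr_data)
    fp = [0] * len(dr_data)
    for idx, detection in enumerate(dr_data):                        # 2nd loop: detections of this class
        file_id = detection["file_id"]
        with open(os.path.join(TEMP_FILES_PATH, file_id + "_ground_truth.json")) as f:
            ground_truth_data = json.load(f)
        ovmax, gt_match = -1.0, None
        bb = [float(x) for x in detection["bbox"].split()]
        for obj in ground_truth_data:                                # 3rd loop: GT boxes of that image
            if obj["class_name"] != class_name:
                continue
            bbgt = [float(x) for x in obj["bbox"].split()]
            # intersection rectangle and IoU (the +1 follows the pixel convention of the script)
            iw = min(bb[2], bbgt[2]) - max(bb[0], bbgt[0]) + 1
            ih = min(bb[3], bbgt[3]) - max(bb[1], bbgt[1]) + 1
            if iw > 0 and ih > 0:
                ua = ((bb[2] - bb[0] + 1) * (bb[3] - bb[1] + 1)
                      + (bbgt[2] - bbgt[0] + 1) * (bbgt[3] - bbgt[1] + 1)
                      - iw * ih)
                ov = iw * ih / ua
                if ov > ovmax:
                    ovmax, gt_match = ov, obj                        # keep the best-overlapping GT box
        if ovmax >= MINOVERLAP and gt_match is not None and not gt_match.get("used", False):
            tp[idx] = 1                                              # first good match of this GT box: TP
            gt_match["used"] = True                                  # a second detection of it becomes FP
        else:
            fp[idx] = 1                                              # no, insufficient, or duplicate match: FP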
cumsum = 0
for idx, val in enumerate(fp):
    fp[idx] += cumsum
    cumsum += val
cumsum = 0
for idx, val in enumerate(tp):
    tp[idx] += cumsum
    cumsum += val
rec = tp[:]
for idx, val in enumerate(tp):
    rec[idx] = float(tp[idx]) / np.maximum(gt_counter_per_class[class_name], 1)
prec = tp[:]
for idx, val in enumerate(tp):
    prec[idx] = float(tp[idx]) / np.maximum((fp[idx] + tp[idx]), 1)
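As a toy check of this step (the numbers are made up): with per-detection flags tp = [1, 0, 1, 1], fp = [0, 1, 0, 0] and 4 ground-truth boxes for the class, the snippet below reproduces the cumulative sums and the resulting precision/recall values; max(..., 1) plays the same role as np.maximum(..., 1), guarding against division by zero:

tp = [1, 0, 1, 1]   # per-detection TP flags, ordered by descending confidence
fp = [0, 1, 0, 0]   # per-detection FP flags
n_gt = 4            # number of ground-truth boxes of this class

cumsum = 0
for idx, val in enumerate(fp):
    fp[idx] += cumsum
    cumsum += val
cumsum = 0
for idx, val in enumerate(tp):
    tp[idx] += cumsum
    cumsum += val
# tp is now [1, 1, 2, 3], fp is now [0, 1, 1, 1]

rec = [t / max(n_gt, 1) for t in tp]                # [0.25, 0.25, 0.5, 0.75]
prec = [t / max(f + t, 1) for f, t in zip(fp, tp)]  # [1.0, 0.5, 0.666..., 0.75]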
Next, the AP, F1 score and other metrics are computed. The AP calculation code is as follows:
def voc_ap(rec, prec):
    """
    --- Official matlab code VOC2012---
    mrec=[0 ; rec ; 1];
    mpre=[0 ; prec ; 0];
    for i=numel(mpre)-1:-1:1
        mpre(i)=max(mpre(i),mpre(i+1));
    end
    i=find(mrec(2:end)~=mrec(1:end-1))+1;
    ap=sum((mrec(i)-mrec(i-1)).*mpre(i));
    """
    rec.insert(0, 0.0)   # insert 0.0 at beginning of list
    rec.append(1.0)      # insert 1.0 at end of list
    mrec = rec[:]
    prec.insert(0, 0.0)  # insert 0.0 at beginning of list
    prec.append(0.0)     # insert 0.0 at end of list
    mpre = prec[:]
    """
    This part makes the precision monotonically decreasing
    (goes from the end to the beginning)
    matlab: for i=numel(mpre)-1:-1:1
                mpre(i)=max(mpre(i),mpre(i+1));
    """
    for i in range(len(mpre)-2, -1, -1):
        mpre[i] = max(mpre[i], mpre[i+1])
    """
    This part creates a list of indexes where the recall changes
    matlab: i=find(mrec(2:end)~=mrec(1:end-1))+1;
    """
    i_list = []
    for i in range(1, len(mrec)):
        if mrec[i] != mrec[i-1]:
            i_list.append(i)  # if it was matlab would be i + 1
    """
    The Average Precision (AP) is the area under the curve
    (numerical integration)
    matlab: ap=sum((mrec(i)-mrec(i-1)).*mpre(i));
    """
    ap = 0.0
    for i in i_list:
        ap += ((mrec[i]-mrec[i-1])*mpre[i])
    return ap, mrec, mpre
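Continuing the toy precision/recall values from above, voc_ap can be called directly; note that it mutates the lists passed to it (it inserts the 0.0/1.0 sentinels), so copies are passed here:

rec = [0.25, 0.25, 0.5, 0.75]
prec = [1.0, 0.5, 2/3, 0.75]
ap, mrec, mpre = voc_ap(rec[:], prec[:])  # pass copies: voc_ap modifies its arguments in place
print(ap)  # prints 0.8125, the area under the monotone precision-recall envelope

Because the precision curve is first made monotonically non-increasing, the 0.5 dip at the second detection does not reduce the AP; only the envelope values at the recall change points contribute to the sum.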