Comparing Features with Cosine Similarity in Python (NumPy)
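Some background first: we have a dish-recognition project built on deep neural networks. The simplified pipeline works like this. A deep learning model extracts dish features from images of dishes, and those features are stored in a feature library. At recognition time, we capture the image of the dish under the camera and extract its features. We then compare those features against the library using cosine similarity to obtain similarity scores.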
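Cosine similarity compares the direction of two feature vectors and ignores their length: cos(a, b) = a·b / (‖a‖ ‖b‖). It is 1 when the vectors point the same way, 0 when they are orthogonal and -1 when they are opposite.

Below is a minimal NumPy illustration using made-up 3-dimensional vectors. The vectors and the helper name cosine_similarity are illustrative and do not come from the project. The example also shows the trick the registration and comparison code relies on: once both vectors are L2-normalized, their cosine similarity is simply their dot product.

```python
import numpy as np


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity between two 1-D feature vectors: a.b / (|a| * |b|)."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


# Toy 3-dimensional "features"; real model features are far longer.
query = np.array([1.0, 2.0, 2.0])
same_direction = np.array([2.0, 4.0, 4.0])  # query scaled by 2
orthogonal = np.array([2.0, -1.0, 0.0])     # dot product with query is 0

print(cosine_similarity(query, same_direction))  # 1.0
print(cosine_similarity(query, orthogonal))      # 0.0

# Equivalent form used by featRegister/featsCompare: normalize once, then a plain dot product.
q_unit = query / np.linalg.norm(query)
s_unit = same_direction / np.linalg.norm(same_direction)
print(float(q_unit @ s_unit))  # approximately 1.0
```

Normalizing the library once at registration time means that each comparison costs only a matrix product.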
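Let's start with the feature registration logic:

```python
import pickle

import numpy as np


def featRegister(tag):
    """
    Build the feature library from `tag`, a dict of {dishFeaId: pickled feature stored as a latin1 str}.

    repo is a dict, the feature library, e.g.
    {
        "baozi": array([[2.9554849, ..., 0.45501372], [1.6043618, 0., ..., 0.]], dtype=float32),
        "fish head": array([[1.6043618, ..., 0.]], dtype=float32),
        "rice": array([[0., 0., 2.9554849, ..., 0., 0., 0.45501372]], dtype=float32),
    }
    """
    # 1. Build the feature library (each feat is assumed to be a 1-D vector)
    repo = {}
    for dishFeaId, featureSrc in tag.items():
        if not featureSrc:
            continue
        feat = pickle.loads(featureSrc.encode("latin1"))
        repo[dishFeaId] = np.array([feat])

    if not repo:
        # Nothing registered: np.concatenate would fail on an empty list,
        # and featsCompare returns None when repo_count is 0.
        return np.empty((0, 0), dtype=np.float32), [], 0

    # 2. Compute all_repo_feats, active_repo_ids and repo_count
    repo_concat = np.concatenate([repo[key] for key in repo], 0)
    # L2-normalize every row so that dot products become cosine similarities
    all_repo_feats = np.array([x / np.linalg.norm(x) for x in repo_concat])
    # One library ID per row of all_repo_feats
    active_repo_ids = []
    for repo_id in repo:
        active_repo_ids += [repo_id] * len(repo[repo_id])
    repo_count = len(repo)
    return all_repo_feats, active_repo_ids, repo_count
```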
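This function builds the repo dict and restores each stored feature from its pickled string with pickle.loads (that is, it deserializes them). It then stacks all features into one matrix with np.concatenate and L2-normalizes every row. It returns three values:

- all_repo_feats: the normalized feature matrix.
- active_repo_ids: the library ID of each row.
- repo_count: the number of registered entries.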
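featRegister expects each value in tag to be a str that featureSrc.encode("latin1") turns back into pickle bytes. The project's storage side is not shown, so the sketch below assumes the strings were produced with pickle.dumps(feat).decode("latin1").

The following names are hypothetical:

- the helpers serialize_feature and deserialize_feature
- the "dishN_feaN" keys
- the random 8-dimensional features

The latin1 codec maps every byte value from 0 to 255 to the code point with the same value. Any byte string therefore survives the bytes → str → bytes round trip unchanged. The sketch also shows a vectorized version of the per-row normalization.

```python
import pickle

import numpy as np


def serialize_feature(feat: np.ndarray) -> str:
    # Hypothetical storage-side helper (assumption): pickle to bytes, then map
    # each byte to one character via latin1 so the result can be stored as text.
    return pickle.dumps(feat).decode("latin1")


def deserialize_feature(src: str) -> np.ndarray:
    # Same steps featRegister uses: latin1 back to the original bytes, then unpickle.
    return pickle.loads(src.encode("latin1"))


rng = np.random.default_rng(0)
# Hypothetical library of 3 entries keyed "dishId_feaId", each an 8-dim float32 vector.
tag = {
    f"dish{i}_fea{i}": serialize_feature(rng.random(8, dtype=np.float32))
    for i in range(3)
}

feats = np.stack([deserialize_feature(src) for src in tag.values()])  # shape (3, 8)

# Row-wise L2 normalization in one call; matches the list comprehension in featRegister.
unit = feats / np.linalg.norm(feats, axis=1, keepdims=True)
print(unit.shape)                    # (3, 8)
print(np.linalg.norm(unit, axis=1))  # every row has (approximately) unit length
```

With keepdims=True the row norms have shape (3, 1), so broadcasting divides each row by its own norm.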
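Next, we simulate the query features that need to be compared:

```python
# featsN: dict of query features, serialized the same way as `tag` above
feats = []
for dishFeaId, featureSrc in featsN.items():
    feat = pickle.loads(featureSrc.encode("latin1"))
    feats.append(np.array(feat))
```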
Then comes the logic that compares the features using cosine similarity:

```python
import uuid

import numpy as np


def featsCompare(feats, feat_boxes, all_repo_feats, active_repo_ids, repo_count):
    """Compare query features with the library and attach the top dishes to each box.

    feats:       query feature vectors, one per detected box
    feat_boxes:  one dict per box; a "dishes" list is added to each
    all_repo_feats, active_repo_ids, repo_count: the return values of featRegister
    """
    if repo_count <= 0:
        return None

    # L2-normalize the queries; the library rows are already normalized,
    # so this matrix product is the cosine similarity, shape (queries, library rows).
    feats = np.array([x / np.linalg.norm(x) for x in feats])
    cos_sim = np.matmul(feats, all_repo_feats.T)

    # Per query: library row indices and scores, from most to least similar.
    ids_indices_sorted = np.argsort(cos_sim, axis=1)[:, ::-1]
    ids_score_sorted = np.take_along_axis(cos_sim, ids_indices_sorted, axis=1)

    repo_ids = np.array(active_repo_ids)
    resultList = []
    for i, resultN in enumerate(feat_boxes):
        # Keep the first (= highest) score of each library ID, as a percentage.
        idList, scoreList = [], []
        for j, score in zip(ids_indices_sorted[i], ids_score_sorted[i]):
            repo_id = str(repo_ids[j])
            if repo_id not in idList:
                idList.append(repo_id)
                scoreList.append(round(float(100 * score), 2))

        # Filter the dishId / feaId / score mapping
        dishId_feaId_score = {}
        for dishId_feaId, score in zip(idList, scoreList):
            parts = dishId_feaId.split("_")
            if len(parts) == 1:
                dishId = feaId = parts[0]
            elif len(parts) == 2:
                dishId, feaId = parts
            else:
                # Unexpected ID format: fall back to randomly generated IDs
                dishId = uuid.uuid1().hex
                feaId = uuid.uuid1().hex
            if dishId not in dishId_feaId_score:
                # Highest-ranked entry for this dishId, with its feaId and score
                dishId_feaId_score[dishId] = [feaId, score]
            if len(dishId_feaId_score) >= 10:
                break

        resultN["dishes"] = [
            {"dishId": dishId, "score": score}
            for dishId, (feaId, score) in dishId_feaId_score.items()
        ]
        resultList.append(resultN)
    return resultList
```
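The loops in featsCompare can be condensed. Below is a minimal vectorized sketch of the same core idea. It is not a drop-in replacement, and it rests on these assumptions:

- Library IDs have the form "dishId" or "dishId_feaId", and the part before the first underscore is the dish ID. The original's random-UUID fallback for other formats is not reproduced.
- Each dish keeps its highest score, as a percentage rounded to two decimals.
- Ranking stops after top_k dishes (10 by default, as in the original).

The function name top_dishes is hypothetical. It returns plain (dishId, score) lists instead of filling feat_boxes.

```python
import numpy as np


def top_dishes(query_feats, repo_feats, repo_ids, top_k=10):
    """Rank library dishes for each query by their best cosine similarity.

    query_feats: (m, d) array-like; repo_feats: (n, d) array with L2-normalized rows;
    repo_ids: n strings like "dishId" or "dishId_feaId" (assumed format).
    Returns one list per query of (dishId, score_percent) pairs, best first.
    """
    q = np.asarray(query_feats, dtype=np.float64)
    q = q / np.linalg.norm(q, axis=1, keepdims=True)
    sim = q @ np.asarray(repo_feats).T       # (m, n) cosine similarities
    order = np.argsort(-sim, axis=1)         # column indices, most similar first
    results = []
    for row, idx in zip(sim, order):
        best = {}                            # dishId -> first (= best) score seen
        for j in idx:
            dish_id = repo_ids[j].split("_")[0]
            if dish_id not in best:
                best[dish_id] = round(float(100 * row[j]), 2)
                if len(best) >= top_k:
                    break
        results.append(list(best.items()))
    return results


# Usage with a hypothetical 2-D library whose rows are already unit length.
repo = np.array([[1.0, 0.0], [0.0, 1.0], [0.6, 0.8]])
ids = ["rice_1", "bun_1", "rice_2"]
print(top_dishes([[3.0, 4.0]], repo, ids))  # [[('rice', 100.0), ('bun', 80.0)]]
```

Because the IDs are visited in descending score order, the first time a dishId appears is also its best match. This is the same reason the original's "if dishId not in dishId_feaId_score" check keeps the top score.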
Some of the code above is business logic written for testing, but the overall approach and the main code work as intended.
The above is for reference only.