本文译自:http://www.cs.bilkent.edu.tr/~canf/CS533/hwSpring14/eightMinPresentations/handoutMMR.pdf
写在前面的话
MMR算法目的是减少排序结果的冗余,同时保证结果的相关性。最早应用于文本摘要提取和信息检索等领域。一般用于在找出用户与商品的相关性之后,依据商品间的相似性对商品列表的多样性进行处理。在推荐场景下体现在,给用户推荐相关商品的同时,保证推荐结果的多样性,即排序结果存在着相关性与多样性的权衡。通过阅读本文你能明白该算法的实现细节。
MMR
MMR的全称为最大边缘相关模型(Maximal Marginal Relevance)。在MMR模型中,同时将相关性和多样性进行衡量。因此,可以方便的调节相关性和多样性的权重来满足偏向“需要相似的内容”或者偏向“需要不同方面的内容”的要求。它的数学公式如下:
其中R是输入的列表,Di是集合R的成员,S是当前返回的结果集。Sim1(Di,Q)可以是协同过滤算法出来的相关度,也可以是其他算法模型预测出来的相关度。Sim2(Di,Dj)可以是余弦相似度、海明距离等可以度量的算法计算出来的值,它内部迭带的是Di和Dj,其中Dj是S的成员。步骤为外层对R列表进行迭代,计算Sim1(Di,Q)的值,内层对S列表进行迭代计算Sim2(Di,Dj)的值,然后对R列表进行排序后的结果取最大值,得到的i对应的元素Di就是我们本次迭代需要的值,我们可以将Di从R列表中删除并加入到S列表中,然后进行下一次的迭代。
内层对列表S进行迭代,先计算Di与S集合中的每个元素Dj之间的相似度取出最大值,然后外层与Sim1(Di,Q)按上面进行计算。
假设我们有一个包含5个文档di的数据库和一个查询q,给定一个对称的相似度度量,我们计算相似度值如下。提前假设用户设定的λ的值为0.5:
S是一个对称矩阵。
第一次迭代
目前我们的结果集S是空的。因此,方程的后半部分,即S内的最大成对相似度,将为零。第一次迭代时,MMR方程简化为:
MMR = arg max (Sim (di, q))
d1与q相似度最大,因此我们取其加入到S中,现在S = {d1}
第二次迭代
由于S = {d1},求S中某元素到给定di的最大距离就是sim(d1,di)。
对于d2:
sim(d1, d2) = 0.11
sim (d2, q) = 0.90
然后 MMR = λ0.90 – (1-λ)0.11 = 0.395
同样,d3, 4, 5的MMR值分别为0.135,-0.35和0.19。由于d2具有最大的MMR,我们将其添加到S中,此时S = {d1, d2}。
第三次迭代
这次S = {d1, d2}。我们应该找到max。对于方程的第二部分,为sim (di, d1)和sim (di, d2)。
对于d3:
max{sim (d1, d3), sim (d2, d3)} =
max {0.23, 0.29} = 0.29
sim (d3, q) = 0.50
MMR = 0.5*0.5 - 0.5*0.29 = 0.105
同理,其他MMR的值计算为:
d4: -0.35, d5: 0.06
d3的MMR最大,因此S = {d1, d2, d3}。
如果我们完全没有多样性(λ=1),那么我们的S将是{d1, d2, d5}。注意,不同情况的总体两两相似度为:
sim (d1, d2) + sim(d1, d3) + sim (d2, d3) = 0.63
而非多样性版本的总两两相似度为0.87。我们有效地使结果集中的项之间的差异更大。还要注意,查询的总相似度已经从2.44减少到2.31。为了多样性,我们牺牲了一些准确性。
附MMR的java实现
/**
* 迭代
*
* @param R
* @param matrix
* @param topNum
* @param lambda
* @return
*/
static List<ClassIdScore> iterator(LinkedList<ClassIdScore> R, double[][] matrix, int topNum, double lambda) {
AtomicInteger counter = new AtomicInteger(0);
Stopwatch stopwatch = Stopwatch.createStarted();
List<ClassIdScore> S = new ArrayList(topNum);
while (S.size() < topNum && R.size() > 0) {
double maxMMR = 0d;
ClassIdScore maxClassIdScore = R.get(0);
// 进行迭代
for (int i = 0; i < R.size(); i++) {
ClassIdScore di = R.get(i);
double maxSim2 = 0d;
for (int j = 0; j < S.size(); j++) {
counter.getAndIncrement();
ClassIdScore dj = S.get(j);
double sim2Score = matrix[di.getIndex()][dj.getIndex()];
if (j == 0) {
maxSim2 = sim2Score;
}
if (sim2Score > maxSim2) {
maxSim2 = sim2Score;
}
}
// 计算MMR
double mmr = lambda * di.score - (1 - lambda) * maxSim2;
if (i == 0) {
maxMMR = mmr;
}
if (mmr > maxMMR) {
maxMMR = mmr;
// System.out.println("第" + i + "次-----------------mmr:" + mmr + " " + JSONObject.toJSONString(di));
maxClassIdScore = di;
}
}
R.remove(maxClassIdScore);
S.add(maxClassIdScore);
}
stopwatch.stop();
// System.out.println("----删除耗时:" + removeMills + " 迭代耗时:" + stopwatch.elapsed(TimeUnit.MILLISECONDS) + " 迭代次数为:" + counter.get());
return S;
}