song recommendation system by spark

16 November 2019

cassandra에 노래와 사용자가 선호하는 노래의 정보가 들어있는 상태에서
성향이 비슷한 유저가 듣는 노래를 종합해 추천하는 시스템을 만들어보도록 하겠습니다.

먼저 아래와 같이 해당 곡들의 id를 모아 저장할 accumulator를 만들어줍니다.

package org.shashaka.test

import java.util
import java.util.Collections

import org.apache.spark.util.AccumulatorV2

import scala.collection.JavaConversions._

class MapAccumulator[T] extends AccumulatorV2[T, java.util.Map[T, java.lang.Integer]] {
  private val _map: java.util.Map[T, java.lang.Integer] = Collections.synchronizedMap(new util.HashMap[T, java.lang.Integer]())

  override def isZero: Boolean = _map.isEmpty

  override def copyAndReset(): MapAccumulator[T] = new MapAccumulator

  override def copy(): MapAccumulator[T] = {
    val newAcc = new MapAccumulator[T]
    _map.synchronized {
      newAcc._map.putAll(_map)
    }
    newAcc
  }

  override def reset(): Unit = _map.clear()

  override def add(v: T): Unit = _map.synchronized {
    val old = _map.put(v, 1)
    if (old != null) {
      _map.put(v, 1 + old)
    }
  }

  override def merge(other: AccumulatorV2[T, java.util.Map[T, java.lang.Integer]]): Unit = other match {
    case o: MapAccumulator[T] => {
      for ((k, v) <- o.value) {
        val old = _map.put(k, v)
        if (old != null) {
          _map.put(k, old + v)
        }
      }
    }
    case _ => throw new UnsupportedOperationException(
      s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
  }

  override def value: java.util.Map[T, java.lang.Integer] = _map.synchronized {
    java.util.Collections.unmodifiableMap(new util.HashMap[T, java.lang.Integer](_map))
  }

  def getKeyOfMaxValue(): java.lang.Integer = {
    var maxValKey: java.lang.Integer = 0;
    var maxVal: java.lang.Integer = 0;
    _map.keys.foreach(key =>
      if (_map.get(key) > maxVal) {
        maxVal = _map.get(key)
        maxValKey = Integer.parseInt(key.toString)
      }
    )
    maxValKey;
  }
}
그리고 나서, 아래와 같이 현재 유저가 선호하는 곡의 id를 기준으로,
다른 유저들과 현재 유저의 similarity를 계산하고,
비슷한 성향 값이 높은 순서대로 sorting해주도록 합니다.
이후에, 상위 100개의 곡 리스트를 모으고 모두 합해 가장 많이 듣고 있는 노래 하나를 고릅니다.
package org.shashaka.test

import com.datastax.spark.connector._
import org.apache.spark.{SparkConf, SparkContext}

import scala.collection.mutable
import scala.collection.mutable.ListBuffer

object SparkApp extends App {

  override def main(args: Array[String]) {
    val conf = new SparkConf(true)
      .set("spark.cassandra.connection.host", "my_cassandra_ip")
      .set("spark.cassandra.auth.username", "cassandra")
      .set("spark.cassandra.auth.password", "cassandra")

    val sc = new SparkContext("local", "my_app", conf)

    val rdd = toSparkContextFunctions(sc).cassandraTable("my_keyspace", "user_rankings")

    val sampleMap = getSampleMap

    val accumulator = new MapAccumulator[Int]
    sc.register(accumulator)

    val grades = rdd
      .map(row =>
        getSimilarity(sampleMap, row.getMap[Int, Int]("rank"))
      ).filter(list => list.nonEmpty)
      .sortBy(_.keys, false)
      .collect()
      .take(100)
      .foreach(map => {
        map.values.foreach(list => list.foreach(value => accumulator.add(value)))
      })

    println("\nvalue : " + accumulator.getKeyOfMaxValue())

    sc.stop()
  }

  private def getSampleMap: Map[Int, Int] = {
    var sampleMap = new mutable.HashMap[Int, Int]();
    sampleMap.put(1, 1)
    sampleMap.put(3, 4)
    sampleMap.put(2, 7)
    sampleMap.put(5, 4)
    sampleMap.put(8, 2)
    return sampleMap.toMap
  }

  private def getSimilarity(a: Map[Int, Int], b: Map[Int, Int]): Map[Double, List[Int]] = {

    var sum = 0.0

    for ((k, v) <- a) {
      val aValue = a.getOrElse(k, 0).doubleValue()
      val bValue = b.getOrElse(k, 0).doubleValue()

      if (bValue == 0.0) {
      } else if (aValue > bValue) {
        sum += bValue / aValue
      } else {
        sum += aValue / bValue
      }
    }

    val songList = new ListBuffer[Int]()

    for ((k, v) <- b) {
      if (!a.contains(k)) {
        songList += k
      }
    }

    return Map[Double, List[Int]](sum -> songList.toList)
  }
}
위 코드를 running 시키면 아래와 같이 하나의 곡 id가 나오는 것을 확인할 수 있습니다.
value : 6