ALS（交替最小二乘法，Alternating Least Squares）电影推荐

阅读更多

spark-submit --class com.ones.soc.cf.MoiveRecommender --master yarn --num-executors 3 --driver-memory 5g --executor-memory 4g /root/bigData.jar 2 5 0.01 /ones/mldata/1u.user /ones/mldata/1u.data /ones/result/1

程序参数依次为：rank（隐因子个数）numIterations（迭代次数）lambda（正则化系数）用户数据路径 评分数据路径 输出目录；上例即 rank=2、迭代 5 次、lambda=0.01。



package com.ones.soc.cf


import com.ones.soc.json.JSONObject
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{Path, FileSystem}
import org.apache.spark.SparkContext._
import org.apache.spark.SparkConf
import org.apache.spark.mllib.recommendation._
import org.apache.spark.rdd.{ PairRDDFunctions, RDD }
import org.apache.spark.SparkContext
import scala.collection.mutable.HashMap
import java.util.List
import java.util.ArrayList
/**
  * Created by tom
  */
object MoiveRecommender {

  /** Number of movies recommended for each user. */
  val numRecommender = 10

  /**
    * Bundle of ALS training parameters.
    *
    * NOTE(review): `run` does not use this class; it reads its settings positionally from
    * `args`. Kept for source compatibility with any external callers.
    */
  case class Params(
                     input: String = null,
                     numIterations: Int = 20,
                     lambda: Double = 1.0,
                     rank: Int = 10,
                     numUserBlocks: Int = -1,
                     numProductBlocks: Int = -1,
                     implicitPrefs: Boolean = false,
                     userDataInput: String = null)

  /** Message reported when the required positional arguments are missing. */
  private val Usage =
    "Usage: MoiveRecommender <rank> <numIterations> <lambda> <userDataInput> <ratingsInput> <outputDir>"

  /** Program entry point; delegates to [[run]]. */
  def main(args: Array[String]): Unit = {
    run(args)
  }

  /**
    * Trains an explicit-feedback ALS model, writes the top recommendations for every user in
    * the user file, and prints the model's mean squared error on its training data.
    *
    * @param args positional arguments:
    *             0 rank, 1 numIterations, 2 lambda,
    *             3 user data path ('|'-separated lines with five fields, the first being the id),
    *             4 ratings path (tab-separated `user item rating timestamp` lines),
    *             5 output directory (deleted first if it already exists)
    * @throws IllegalArgumentException if fewer than six arguments are given
    */
  def run(args: Array[String]): Unit = {
    require(args != null && args.length >= 6, Usage)

    // Parse everything before touching HDFS so bad arguments never delete a previous result.
    val rank = args(0).toInt
    val numIterations = args(1).toInt
    val lambda = args(2).toDouble
    val userDataInput = args(3)
    val input = args(4)
    val outpath = args(5)
    // Not exposed on the command line; -1 block counts let Spark choose the parallelism.
    val implicitPrefs = false
    val numUserBlocks = -1
    val numProductBlocks = -1

    val fs = FileSystem.get(new Configuration())
    // Remove the previous run's output directory. The length guard refuses to delete "/".
    if (outpath != null && outpath.trim().length > 1) {
      val output = new Path(outpath)
      if (fs.exists(output)) {
        fs.delete(output, true)
      }
    }

    // Cluster mode: the master and environment come from spark-submit.
    // For local debugging, add .setMaster("local[*]").
    val conf = new SparkConf().setAppName("Moive Recommendation")
    val context = new SparkContext(conf)
    try {
      // MovieLens ratings are on a 1-5 scale:
      // 5 must see, 4 will enjoy, 3 it's okay, 2 fairly bad, 1 awful.
      // Lines that are not exactly four tab-separated fields are skipped rather than
      // failing the whole job. The RDD is cached because both training and evaluation use it.
      val ratings = context.textFile(input)
        .map(_.split("\t"))
        .collect { case Array(user, item, rate, _) =>
          Rating(user.toInt, item.toInt, rate.toDouble)
        }
        .cache()

      // Build the ALS model; ALS.train(ratings, rank, numIterations) is the simpler equivalent API.
      val model = new ALS()
        .setRank(rank)
        .setIterations(numIterations)
        .setLambda(lambda)
        .setImplicitPrefs(implicitPrefs)
        .setUserBlocks(numUserBlocks)
        .setProductBlocks(numProductBlocks)
        .run(ratings)

      predictMoive(userDataInput, context, model, fs, outpath)
      evaluateMode(ratings, model)
    } finally {
      // Always release cluster resources, even if training or prediction fails.
      context.stop()
    }
  }

  /**
    * Predicts a rating for every (user, product) pair in `ratings` and prints the mean squared
    * error against the actual ratings.
    *
    * The pairs come from the training data, so this is a training-set error, not a held-out
    * estimate.
    */
  private def evaluateMode(ratings: RDD[Rating], model: MatrixFactorizationModel): Unit = {
    val usersProducts = ratings.map(r => (r.user, r.product))

    val predictions = model.predict(usersProducts).map {
      case Rating(user, product, rate) => ((user, product), rate)
    }

    // Pair each actual rating with its prediction, keyed by (user, product).
    val ratesAndPreds = ratings.map {
      case Rating(user, product, rate) => ((user, product), rate)
    }.join(predictions)

    val mse = ratesAndPreds.map { case (_, (actual, predicted)) =>
      val err = actual - predicted
      err * err
    }.mean()

    println("Mean Squared Error = " + mse)
  }

  /**
    * Recommends `numRecommender` movies for each user listed in `userDataInput` and writes the
    * result to `outpath/result.txt`, one CRLF-terminated line per user:
    * `user=<id>\tvalue=<product>:<rating>,<product>:<rating>,...,`
    *
    * Users without any ratings have no latent factors in the model and cannot be scored by
    * `recommendProducts`, so they are skipped instead of failing the job.
    */
  private def predictMoive(userDataInput: String, context: SparkContext, model: MatrixFactorizationModel,
                           fs: FileSystem, outpath: String): Unit = {
    // Ids of the users the model learned factors for.
    val knownUsers = model.userFeatures.keys.collect().toSet

    // The user file is '|'-separated with five fields; only the first (the user id) is used.
    // Lines with a different field count are ignored.
    val userIds = context.textFile(userDataInput)
      .map(_.split("\\|"))
      .collect { case Array(id, _, _, _, _) => id }
      .collect()
      .map(_.toInt)

    val sb = new StringBuilder
    userIds.filter(id => knownUsers.contains(id)).foreach { userId =>
      val value = model.recommendProducts(userId, numRecommender)
        .map(r => s"${r.product}:${r.rating},")
        .mkString
      sb.append(s"user=$userId\tvalue=$value").append("\r\n")
    }
    outputHdfs(fs, sb.toString(), outpath)

    // Saving to an HBase `recommender` table (rowKey = user id, column `t:info` = the value
    // string) would need a separate helper such as HbaseUtil.saveListMap; not implemented here.
  }

  /**
    * Writes `text` as UTF-8 to `textdir/result.txt`, overwriting any existing file.
    *
    * Best-effort: a non-fatal failure is reported on stderr rather than rethrown, so the caller
    * keeps going. The output stream is always closed so the handle is released and the write
    * is completed.
    */
  def outputHdfs(fs: FileSystem, text: String, textdir: String): Unit = {
    import scala.util.control.NonFatal
    try {
      val out = fs.create(new Path(textdir + "/result.txt"), true)
      try {
        val bytes = text.getBytes("UTF-8")
        out.write(bytes, 0, bytes.length)
        out.hflush()
      } finally {
        out.close()
      }
    } catch {
      case NonFatal(e) =>
        System.err.println("Failed to write " + textdir + "/result.txt: " + e)
        e.printStackTrace()
    }
  }
}


分享到:
评论

相关推荐

Global site tag (gtag.js) - Google Analytics