Code: Word2Vec in Spark

Here is a snippet that might be useful if you want to train a Word2Vec model in Spark and save the embeddings it learns. I've annotated the variables with their types and used descriptive placeholder names, so it is clear what each function expects as input.

First you will need to import Word2Vec and Word2VecModel, as well as Dataset for the training data:

    import org.apache.spark.ml.feature.{Word2Vec, Word2VecModel}
    import org.apache.spark.sql.Dataset

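Then you will need to import your Spark session's implicits, since you will be working with Datasets (sparkSession below stands for your SparkSession instance). The implicits supply the encoders that calls such as .as[...] and .map rely on.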

    import sparkSession.implicits._

Then you need to prepare your input data for the algorithm. The training algorithm expects a Dataset of Sequences of Strings. In the example below “getTrainset” is a function that retrieves your training corpus and formats it into the required type.

    val trainset: Dataset[Seq[String]] = getTrainset() 
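If you do not already have such a function, a minimal sketch of a hypothetical getTrainset might look like the following. It assumes a plain-text corpus with one sentence per line and a crude regex tokenizer; the signature (which, unlike the call above, takes the session and a path) and the tokenizer are illustrative assumptions, not part of the original code.

```scala
import org.apache.spark.sql.{Dataset, SparkSession}

// Hypothetical tokenizer: lower-cases a line and splits it on runs of non-word characters.
def tokenize(line: String): Seq[String] =
  line.toLowerCase.split("\\W+").filter(_.nonEmpty).toSeq

// Hypothetical getTrainset: treats every line of a plain-text file as one sentence.
def getTrainset(spark: SparkSession, path: String): Dataset[Seq[String]] = {
  import spark.implicits._
  spark.read.textFile(path).map(line => tokenize(line))
}
```

For real corpora, Spark's built-in org.apache.spark.ml.feature.RegexTokenizer or Tokenizer may be a better fit than the hand-rolled split; note that \W treats non-ASCII letters as separators.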

Now you are ready to start using Word2Vec. First, configure the algorithm with the parameters you have chosen:

    val wordToVec: Word2Vec = new Word2Vec()
      .setInputCol("value")                      // the column of trainset holding the tokens
      .setOutputCol("output-column-name")
      .setStepSize(learningRateDouble)           // the learning rate is a Double
      .setVectorSize(vectorSizeInt)
      .setWindowSize(windowSizeInt)
      .setMaxIter(numberOfIterationsInt)
      .setMinCount(minimumCountForDiscardingInt)
      .setNumPartitions(wordToVecPartitionsInt)
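For reference, here is the same configuration with concrete values. The numbers are arbitrary examples rather than tuned recommendations; the comments give Spark's default for each parameter, and "features" is a hypothetical output column name.

```scala
import org.apache.spark.ml.feature.Word2Vec

// Example values only; each comment notes Spark's default for comparison.
val wordToVec: Word2Vec = new Word2Vec()
  .setInputCol("value")      // the single column of a Dataset[Seq[String]]
  .setOutputCol("features")  // hypothetical name for the per-row averaged vector
  .setStepSize(0.025)        // learning rate, a Double (default 0.025)
  .setVectorSize(50)         // embedding dimension (default 100)
  .setWindowSize(5)          // context words on each side of the target word (default 5)
  .setMaxIter(3)             // training iterations (default 1)
  .setMinCount(2)            // words seen fewer than 2 times are dropped (default 5)
  .setNumPartitions(3)       // partitions used during training (default 1)
```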

Next comes the training step.

    val model: Word2VecModel = wordToVec.fit(trainset)
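Before saving anything, you may want to sanity-check the trained model. The sketch below trains a throwaway model on a made-up three-sentence corpus with a local SparkSession, then calls two Word2VecModel methods: findSynonyms, which returns the nearest words by cosine similarity, and transform, which adds a column holding the average of each row's word vectors. The toy data and the "features" column name are assumptions; on a corpus this small the neighbours are not meaningful.

```scala
import org.apache.spark.ml.feature.{Word2Vec, Word2VecModel}
import org.apache.spark.sql.{DataFrame, SparkSession}

// Local session and toy corpus, for illustration only.
val spark: SparkSession = SparkSession.builder()
  .master("local[1]")
  .appName("word2vec-usage-sketch")
  .getOrCreate()
import spark.implicits._

val corpus = Seq(
  Seq("spark", "is", "fast"),
  Seq("spark", "is", "scalable"),
  Seq("scala", "is", "fun")
).toDS() // Dataset[Seq[String]]: single column named "value"

val toyModel: Word2VecModel = new Word2Vec()
  .setInputCol("value")
  .setOutputCol("features")
  .setVectorSize(10)
  .setMinCount(1) // keep every word; the corpus is tiny
  .setSeed(42L)
  .fit(corpus)

// The 2 words whose vectors have the highest cosine similarity to "spark".
val neighbours: DataFrame = toyModel.findSynonyms("spark", 2) // columns: word, similarity
neighbours.show()

// transform() adds a "features" column: the average of each row's word vectors.
val docVectors: DataFrame = toyModel.transform(corpus)
docVectors.show(false)
```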

Once your model has been trained, you will want to process the embeddings and save them in the format you need. model.getVectors returns a DataFrame with two columns, word (a String) and vector (a Vector), so to work with it in a type-safe way you can map it onto a case class:

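    import org.apache.spark.ml.linalg.Vector

    // Define the case class at top level (not inside a method) so Spark can derive its encoder.
    case class MyCustomType(word: String, vector: Vector) {
        // Keeps word as a String; word.toLong only works if your tokens are numeric IDs.
        def toPair: (String, Array[Double]) = (word, vector.toArray)
    }

    model.getVectors.as[MyCustomType].map(_.toPair)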

Lastly, you will want to save your model. Note that Word2VecModel does have its own save method (model.save(path)), which writes a Spark-specific format that Word2VecModel.load(path) can read straight back into a Word2VecModel. In this example we save to Parquet instead, as you may need a rawer version of the embeddings: one row per word, with its vector as a plain array of doubles.

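    model.getVectors
      .as[MyCustomType]
      .map(_.toPair)                               // tuples with columns _1 (word) and _2 (vector)
      .repartition(partitionsToSaveModelInto)
      .withColumnRenamed("_1", "word")
      .withColumnRenamed("_2", "vector")
      .select("word", "vector")
      .write.parquet(options.outputFile)           // options.outputFile: your output path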
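For comparison, the sketch below exercises both options on a toy model: Spark's native save/load round trip, and a raw Parquet file read back into a plain Map from word to vector. The corpus, vector size and temporary paths are illustrative assumptions, and the tuple conversion is written inline with .as[(String, Vector)] rather than through the case class.

```scala
import java.nio.file.Files
import org.apache.spark.ml.feature.{Word2Vec, Word2VecModel}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.{Dataset, SparkSession}

val spark: SparkSession = SparkSession.builder()
  .master("local[1]")
  .appName("word2vec-io-sketch")
  .getOrCreate()
import spark.implicits._

// Toy model, for illustration only.
val corpus = Seq(Seq("a", "b", "c"), Seq("b", "c", "d")).toDS()
val w2v: Word2VecModel = new Word2Vec()
  .setInputCol("value")
  .setVectorSize(4)
  .setMinCount(1)
  .setSeed(1L)
  .fit(corpus)

// Hypothetical output locations inside a fresh temporary directory.
val outDir = Files.createTempDirectory("w2v-sketch").toString
val nativePath = s"$outDir/native"
val parquetPath = s"$outDir/parquet"

// 1. Spark's own format, which Word2VecModel.load can read back directly.
w2v.write.overwrite().save(nativePath)
val reloaded: Word2VecModel = Word2VecModel.load(nativePath)

// 2. Raw Parquet: one row per word, with the vector stored as array<double>.
w2v.getVectors
  .as[(String, Vector)] // tuple fields are matched to columns by position
  .map { case (word, vector) => (word, vector.toArray) }
  .toDF("word", "vector")
  .write.mode("overwrite").parquet(parquetPath)

// Reading the Parquet back needs no Spark ML classes.
val raw: Dataset[(String, Array[Double])] =
  spark.read.parquet(parquetPath).as[(String, Array[Double])]
val lookup: Map[String, Array[Double]] = raw.collect().toMap
```

The native format is convenient when the model will only be used inside Spark again; the Parquet layout can be read by any tool that understands Parquet arrays.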
