Writing a Spark Word Count Program with Unit Tests

We set up a boilerplate project in our previous post. We will now use that project as the base and build a word count program in Spark using Scala.

This post is going to be verbose. If you just want a one-liner to clear an interview question, go right to the bottom.

The steps we will follow will look like this:

Spark Word Count Steps

We will isolate our business logic, in this case the clean and tokenize functions, from the Spark framework logic. This lets us write unit tests that validate our functions in isolation.

Extending the boilerplate template, I have created a Scala object, WordCount, in WordCount.scala. It holds all the code required to compute the frequencies of words. This object should be Serializable, as its methods will be used across worker nodes.

Step 1: Clean Data Method

We clean our data using the following steps:

  • Convert line to lower case
  • Replace comma and period with a space
  • Remove all non-alphanumeric characters, except hyphens
  • Convert multiple spaces to single space
  • Trim spaces from beginning and end of line

The code is as follows:


  def cleanData(line: String): String = {
    line
      .toLowerCase
      .replaceAll("[,.]"," ")
      .replaceAll("[^a-z0-9\\s-]","")
      .replaceAll("\\s+"," ")
      .trim
  }

Example: "This is test, you are testing." -> "this is test you are testing"
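To see what each cleaning step contributes, here is a minimal sketch that applies the same five steps one at a time and keeps every intermediate result. The object name CleanDataTrace and the trace() helper are made up for this illustration and are not part of the project.

```scala
object CleanDataTrace {
  // Same five steps as cleanData(), but every intermediate string is kept.
  // trace() is a hypothetical helper used only for this illustration.
  def trace(line: String): List[String] = {
    val lowerCased    = line.toLowerCase
    val punctuation   = lowerCased.replaceAll("[,.]", " ")           // comma and period become spaces
    val alphanumeric  = punctuation.replaceAll("[^a-z0-9\\s-]", "")  // drop everything except a-z, 0-9, whitespace, hyphen
    val singleSpaced  = alphanumeric.replaceAll("\\s+", " ")         // collapse runs of whitespace
    val trimmed       = singleSpaced.trim
    List(lowerCased, punctuation, alphanumeric, singleSpaced, trimmed)
  }

  def main(args: Array[String]): Unit =
    // Brackets make leading and trailing spaces visible in the output
    trace("$This is text, with 123, from Spark-WordCount.").foreach(step => println(s"[$step]"))
}
```

Note that the order matters: replacing commas and periods with spaces before stripping other characters keeps "text,with" from collapsing into a single token.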

Step 2: Tokenize Data

Tokenizing the data is the process of converting text to tokens, in our case, words. We use a simple split method to split each line on whitespace and generate a list of words.


  def tokenize(line: String): List[String] = {
    line.split("\\s").toList
  }

Example: "this is test you are testing" -> List("this","is","test","you","are","testing")
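One detail of split() is worth checking before moving on: an empty line does not produce an empty list. The sketch below, which copies tokenize() into a throwaway object named TokenizeEdgeCases (a name invented for this example), shows the behaviour. This appears to be the reason the pipeline later filters out empty strings.

```scala
object TokenizeEdgeCases {
  // Copy of the tokenize() method from the post
  def tokenize(line: String): List[String] = line.split("\\s").toList

  def main(args: Array[String]): Unit = {
    // Normal case: one token per word
    println(tokenize("this is test you are testing"))
    // An empty line yields a list holding a single empty string, not an empty list
    println(tokenize("").size)
    // Consecutive spaces yield empty tokens in the middle; cleanData() prevents
    // this case by collapsing whitespace first
    println(tokenize("a  b").size)
  }
}
```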

Step 3: Key Value Generator

Before counting the frequencies of words, we first need to convert our data to key-value pairs. We then reduce (aggregate) these pairs by key to calculate the sum of the values for each key.


  def keyValueGenerator(word: String): (String, Int) = {
    (word, 1)
  }

Example: "test" -> ("test", 1)

Note how we set the value to 1. We simply sum up these ones at the end to get the total count for each word.
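The idea of summing ones per key can be tried on plain Scala collections, without Spark. The following is only a local stand-in for the reduction: countByKey() is a hypothetical helper built on groupBy, not the Spark reduceByKey() API.

```scala
object LocalReduceByKey {
  // Copy of the keyValueGenerator() method from the post
  def keyValueGenerator(word: String): (String, Int) = (word, 1)

  // Hypothetical local equivalent of reducing by key with addition:
  // group the pairs by word, then sum the ones in each group.
  def countByKey(pairs: Seq[(String, Int)]): Map[String, Int] =
    pairs.groupBy(_._1).map { case (word, group) => word -> group.map(_._2).sum }

  def main(args: Array[String]): Unit = {
    val words = List("this", "is", "test", "this", "test", "this")
    // "this" appears three times, "test" twice and "is" once
    println(countByKey(words.map(keyValueGenerator)))
  }
}
```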

Step 4: Read, Compute Word Count and Write

We now come to the Spark part of the code. This is where we use the above functions to do our word count. This will go into Driver.scala.

  • We first read the text file; Spark's textFile() method returns a Dataset of String
    • The spark.implicits import statement provides the encoders for the Dataset
  • We then use our cleanData() method to clean every line
    • The map() function applies cleanData() to every line in the text file
  • We then use the tokenize() method with flatMap() to split lines into words
    • flatMap() takes one input record and generates zero or more output records
    • In our case, it takes a line of text, gets a List of String from tokenize() and expands that list into separate records
    • Thus we again get a Dataset of String, but each record is now an individual word rather than a full line
  • We use the filter() function to drop any empty strings
  • We then use our keyValueGenerator() method to convert the Dataset to key-value pairs
  • After that, we apply reduceByKey(), with addition as the reduction function, to sum the values for each key
    • Note that we had to convert our Dataset to an RDD to perform the reduction
  • We then convert the RDD to a DataFrame, specifying the column names so that the data is easy to read after we write it to CSV
  • Finally, we write the data to CSV, setting the "header" option to true so that the column names appear as a header row in the output

  def run(spark: SparkSession, inputFilePath: String, outputFilePath: String): Unit = {
    import spark.implicits._
    val data: Dataset[String] = spark.read.textFile(inputFilePath)
    val words: Dataset[String] = data
      .map(cleanData)
      .flatMap(tokenize)
      .filter(_.nonEmpty)
    val wordFrequencies: DataFrame = words
      .map(keyValueGenerator)
      .rdd.reduceByKey(_ + _)
      .toDF("word", "frequency")
    wordFrequencies.write.option("header","true").csv(outputFilePath)
    LOG.info(s"Result successfully written to $outputFilePath")
  }
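The difference between map() and flatMap() described in the steps above can be seen on an ordinary Scala List, which behaves analogously for this purpose. This is a sketch on local collections only; the object name MapVsFlatMap and the sample lines are made up for the example.

```scala
object MapVsFlatMap {
  // Copy of the tokenize() method from the post
  def tokenize(line: String): List[String] = line.split("\\s").toList

  // Two already-cleaned lines, used as sample input
  val lines: List[String] = List("this is test", "you are testing")

  // map() keeps one output element per input line, so we get a list of lists
  val nested: List[List[String]] = lines.map(tokenize)

  // flatMap() expands each list into separate elements, giving one flat list of words
  val flat: List[String] = lines.flatMap(tokenize)

  def main(args: Array[String]): Unit = {
    println(nested.size) // number of lines
    println(flat.size)   // number of words
  }
}
```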

Step 5: Driver Main Method

As you would have observed, we take the input and output paths as parameters to our run() method, which is called from the main() method.

Below is the code for the main() method from the Driver class.


  def main(args: Array[String]): Unit = {
    if(args.length != 2) {
      println("Invalid usage")
      println("Usage: spark-submit --master yarn spark-wordcount-1.0.jar /path/to/input/file.txt /path/to/output/directory")
      LOG.error(s"Invalid number of arguments, arguments given: [${args.mkString(",")}]")
      System.exit(1)
    }
    val spark: SparkSession = SparkSession.builder().appName(JOB_NAME).getOrCreate()
    run(spark, args(0), args(1))
  }
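If you want to unit test the argument check as well, one option is to pull it out of main() into a pure function. The sketch below is an assumption on my part rather than code from the project: JobArgs and parseArgs() are hypothetical names, and main() would still be responsible for printing the usage message and exiting.

```scala
object ArgsSketch {
  // Hypothetical holder for the two paths expected on the command line
  final case class JobArgs(inputFilePath: String, outputFilePath: String)

  // Returns the parsed paths, or an error message when the argument count is wrong.
  // No printing or System.exit here, so the function is easy to test.
  def parseArgs(args: Array[String]): Either[String, JobArgs] = args match {
    case Array(input, output) => Right(JobArgs(input, output))
    case other =>
      Left(s"Invalid number of arguments, arguments given: [${other.mkString(",")}]")
  }

  def main(args: Array[String]): Unit =
    parseArgs(args) match {
      case Right(jobArgs) => println(s"Would run with $jobArgs")
      case Left(error)    => println(error)
    }
}
```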

Step 6: Testing Individual Methods

We will now write unit tests to check that the methods cleanData(), tokenize() and keyValueGenerator() work as expected.

I am using a WordCountTest.scala class to write the unit tests. The code below should be self-explanatory. This class follows a structure similar to DriverTest.scala from the boilerplate.


  test("Test data cleaning") {
    assertResult("this is text with 123 from spark-wordcount")(WordCount.cleanData("$This is text, with 123, from Spark-WordCount."))
  }
  test("Test tokenizer") {
    List("tokenized","this","is") should contain theSameElementsAs WordCount.tokenize("this is tokenized")
  }
  test("Test key value generator") {
    assertResult(("test", 1))(WordCount.keyValueGenerator("test"))
  }

Step 7: Testing Spark Output

We should not write unit tests that check whether the Spark API itself works as expected. If it did not, the release would not have been published for you to download.

Instead, we will write a simple unit test to check that our methods work correctly when stitched together using the Spark APIs.

In this case, I have created a dummy text file with 1000 words. The test checks that no data is dropped after the file goes through the above methods.

  test("Verify no data is dropped") {
    import implicits._
    val data: Dataset[String] = spark.read.textFile(TEST_INPUT)
    val words: Dataset[String] = data
      .map(WordCount.cleanData)
      .flatMap(WordCount.tokenize)
      .filter(_.nonEmpty)
    assertResult(1000L)(words.count())
  }
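When building such a test, it can help to compute the expected values with a Spark-free reference first. The following sketch composes copies of the three methods on an in-memory Seq. The object LocalWordCount, its helper names and the sample lines are assumptions made for this illustration.

```scala
object LocalWordCount {
  // Copies of the three methods from the post
  def cleanData(line: String): String =
    line.toLowerCase
      .replaceAll("[,.]", " ")
      .replaceAll("[^a-z0-9\\s-]", "")
      .replaceAll("\\s+", " ")
      .trim
  def tokenize(line: String): List[String] = line.split("\\s").toList
  def keyValueGenerator(word: String): (String, Int) = (word, 1)

  // Same clean -> tokenize -> filter chain as the test, on a local Seq
  def words(lines: Seq[String]): Seq[String] =
    lines.map(cleanData).flatMap(tokenize).filter(_.nonEmpty)

  // Local stand-in for the key-value reduction (groupBy instead of reduceByKey)
  def wordCount(lines: Seq[String]): Map[String, Int] =
    words(lines).map(keyValueGenerator)
      .groupBy(_._1)
      .map { case (word, group) => word -> group.map(_._2).sum }

  def main(args: Array[String]): Unit = {
    // Hypothetical sample input, including an empty line
    val sample = Seq("This is test, you are testing.", "", "Test this!")
    println(words(sample).size)
    println(wordCount(sample))
  }
}
```

The word count from words() is the number you would expect Spark's count() to return for the same input, and wordCount() gives the expected per-word frequencies.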

That should give you an idea of how to write unit tests in Spark.

The full source code, along with execution instructions, can be found at https://git.barrelsofdata.com/barrelsofdata/spark-wordcount.

Word Count for Interviews

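When asked this question in an interview, you can't tell the above story to the interviewer. So, use the one-liner below to clear it. It assumes a spark-shell session, where spark and its implicits are already in scope.

spark.read.textFile("text.txt")
    .flatMap(_.split(" "))
    .map((_, 1))                // pair every word with a count of 1
    .rdd.reduceByKey(_ + _)     // reduceByKey is only available on an RDD of pairs
    .toDF("word", "frequency")  // back to a DataFrame so that show() is available
    .show(truncate = false)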