Currying Functions in Scala

Currying is a means of transforming a function that takes more than one argument into a chain of calls to functions, each of which takes a single argument.
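Here is a minimal sketch of that transformation for a two-argument function. It uses a hypothetical helper named curry. Scala's standard library already provides the same thing as the curried method on function values with two or more parameters.

```scala
object CurryDefinition {
  // Hypothetical helper (not part of the standard library): turns a
  // two-argument function into a chain of single-argument functions.
  def curry[A, B, C](f: (A, B) => C): A => B => C =
    a => b => f(a, b)

  val multiply: (Int, Int) => Int = (x, y) => x * y
  val curriedMultiply: Int => Int => Int = curry(multiply)

  // Supplying the first argument returns a function still waiting for the second.
  val triple: Int => Int = curriedMultiply(3)

  def main(args: Array[String]): Unit = {
    println(multiply(3, 4))        // 12
    println(curriedMultiply(3)(4)) // 12
    println(triple(5))             // 15
  }
}
```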

Consider the function below, which calculates the final price of a product. VAT and service charge are both given as percentages of the product price. The function takes three parameters:

  • VAT (vat) for the region
  • Service charge (serviceCharge) of the shop
  • Product price (productPrice)

  def finalPrice(vat: Double, serviceCharge: Double, productPrice: Double): Double = {
    productPrice + productPrice * serviceCharge / 100 + productPrice * vat / 100
  }
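Calling finalPrice for a few products makes the repetition visible. The product prices below are illustrative. The VAT (20) and service charge (12.5) match the values used later in this post.

```scala
object UncurriedPricing {
  def finalPrice(vat: Double, serviceCharge: Double, productPrice: Double): Double =
    productPrice + productPrice * serviceCharge / 100 + productPrice * vat / 100

  def main(args: Array[String]): Unit = {
    // VAT and service charge are repeated on every call, although only the price changes.
    println(finalPrice(20, 12.5, 120)) // 159.0
    println(finalPrice(20, 12.5, 250)) // 331.25
    println(finalPrice(20, 12.5, 200)) // 265.0
  }
}
```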

Look at finalPrice again, though. A shopkeeper has to supply all three values every time they calculate a final price, even though:

  • the VAT rate is already fixed for a country
  • the service charge is constant for a given shop

So let us make life a little easier for our client by defining a curried version of finalPrice, with one parameter list per argument:

  def finalPriceCurried(vat: Double)(serviceCharge: Double)(productPrice: Double): Double = {
    productPrice + productPrice * serviceCharge / 100 + productPrice * vat / 100
  }
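When all three argument lists are applied at once, the curried method returns the same result as finalPrice. The difference is that each argument list can also be supplied separately. A quick check, reusing the values from this post:

```scala
object CurriedEquivalence {
  def finalPrice(vat: Double, serviceCharge: Double, productPrice: Double): Double =
    productPrice + productPrice * serviceCharge / 100 + productPrice * vat / 100

  def finalPriceCurried(vat: Double)(serviceCharge: Double)(productPrice: Double): Double =
    productPrice + productPrice * serviceCharge / 100 + productPrice * vat / 100

  def main(args: Array[String]): Unit = {
    // One argument list per parameter: finalPriceCurried(vat)(serviceCharge)(productPrice)
    val curried = finalPriceCurried(20)(12.5)(120)
    val uncurried = finalPrice(20, 12.5, 120)
    println(curried == uncurried) // true
  }
}
```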

We take this approach because vat and serviceCharge rarely change. Currying lets us split the method. We declare a new val, vatApplied, by passing only the vat argument to finalPriceCurried. The trailing underscore tells the compiler to turn the partially applied method into a function value.

val vatApplied = finalPriceCurried(20) _
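An alternative to the underscore is to declare the function type of vatApplied explicitly. When a function type is expected, the compiler performs the conversion itself. The following is a sketch that, as far as we know, compiles under both Scala 2.13 and Scala 3. The type also shows what vatApplied holds: a function that takes the service charge and returns a function from the product price to the final price.

```scala
object ExplicitFunctionType {
  def finalPriceCurried(vat: Double)(serviceCharge: Double)(productPrice: Double): Double =
    productPrice + productPrice * serviceCharge / 100 + productPrice * vat / 100

  // With an explicit function type, no trailing underscore is needed.
  // Double => Double => Double reads as Double => (Double => Double).
  val vatApplied: Double => Double => Double = finalPriceCurried(20)

  def main(args: Array[String]): Unit =
    println(vatApplied(12.5)(120)) // 159.0
}
```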

Next, we pass the service charge to vatApplied and leave the product price for the shopkeeper to supply whenever they need it.

val serviceChargeApplied = vatApplied(12.5)
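Because vatApplied fixes only the VAT, it can be reused by several shops in the same country. In this sketch, the second shop and its 10% service charge are hypothetical:

```scala
object SharedVat {
  def finalPriceCurried(vat: Double)(serviceCharge: Double)(productPrice: Double): Double =
    productPrice + productPrice * serviceCharge / 100 + productPrice * vat / 100

  // VAT fixed once for the country.
  val vatApplied = finalPriceCurried(20) _

  // Each shop fixes its own service charge on top of the shared VAT.
  val shopA = vatApplied(12.5) // service charge used in this post
  val shopB = vatApplied(10)   // hypothetical second shop

  def main(args: Array[String]): Unit = {
    println(shopA(120)) // 159.0
    println(shopB(120)) // 156.0
  }
}
```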

Let us test serviceChargeApplied by calculating the final price of a product that costs 120:

val finalProductPrice = serviceChargeApplied(120)
println(finalProductPrice) // 159.0
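The printed value follows directly from the formula in finalPriceCurried with vat = 20, serviceCharge = 12.5 and productPrice = 120:

```latex
120 + \frac{120 \times 12.5}{100} + \frac{120 \times 20}{100} = 120 + 15 + 24 = 159
```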

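We have gone from a method that takes three parameters to a function that takes one. Because the method is split this way, we don't have to provide all the arguments at the same time; each one can be supplied whenever it becomes available. Strictly speaking, rewriting finalPrice as finalPriceCurried is the currying. Supplying only some of the arguments, as we did with vatApplied and serviceChargeApplied, is known as partial application.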
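One practical benefit of a one-argument function is that it plugs straight into higher-order functions such as map. Here is a sketch with an illustrative list of prices:

```scala
object PriceList {
  def finalPriceCurried(vat: Double)(serviceCharge: Double)(productPrice: Double): Double =
    productPrice + productPrice * serviceCharge / 100 + productPrice * vat / 100

  // VAT and service charge fixed up front; only the price is left open.
  val serviceChargeApplied: Double => Double = finalPriceCurried(20)(12.5)

  val prices = List(120.0, 250.0, 200.0)

  // serviceChargeApplied has type Double => Double, so it can be passed to map directly.
  val finalPrices = prices.map(serviceChargeApplied)

  def main(args: Array[String]): Unit =
    println(finalPrices) // List(159.0, 331.25, 265.0)
}
```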

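We can also convert an existing multi-parameter method into a curried function. First turn the method into a function value with an underscore, then call curried on it (curried is defined on the function types that take two or more parameters). For example: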

  def add3Num(num1: Int, num2: Int, num3: Int): Int = {
    num1 + num2 + num3
  }

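  // add3Num _ turns the method into a Function3; .curried returns Int => Int => Int => Int
  val curriedAdd3Num = (add3Num _).curried

  println(add3Num(1, 2, 3)) // 6

  val addFirstNum = curriedAdd3Num(1) // Int => Int => Int
  val addSecondNum = addFirstNum(2)   // Int => Int
  val addThirdNum = addSecondNum(3)   // Int
  println(addThirdNum) // 6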
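Roughly speaking, .curried produces the same shape as a chain of nested lambdas written by hand. The following sketch compares the two for add3Num:

```scala
object CurriedByHand {
  def add3Num(num1: Int, num2: Int, num3: Int): Int = num1 + num2 + num3

  // Library version: curried on the Function3 value produced by add3Num _
  val viaLibrary: Int => Int => Int => Int = (add3Num _).curried

  // Hand-written equivalent: each lambda takes one argument and returns the next lambda.
  val byHand: Int => Int => Int => Int =
    num1 => num2 => num3 => add3Num(num1, num2, num3)

  def main(args: Array[String]): Unit = {
    println(viaLibrary(1)(2)(3)) // 6
    println(byHand(1)(2)(3))     // 6
  }
}
```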

Complete Program:

object CurryingExample {

  def finalPrice(vat: Double, serviceCharge: Double, productPrice: Double): Double = {
    productPrice + productPrice * serviceCharge / 100 + productPrice * vat / 100
  }

  def finalPriceCurried(vat: Double)(serviceCharge: Double)(productPrice: Double): Double = {
    productPrice + productPrice * serviceCharge / 100 + productPrice * vat / 100
  }

  def add2Num(num1: Int)(num2: Int): Int = {
    num1 + num2
  }

  def add3Num(num1: Int, num2: Int, num3: Int): Int = {
    num1 + num2 + num3
  }

  def main(args: Array[String]): Unit = {

    val vatApplied = finalPriceCurried(20) _
    val serviceApplied = vatApplied(12.5)
    val finalPriceAppliedForPizza = serviceApplied(250) // 331.25
    val finalPriceAppliedForPasta = serviceApplied(200) // 265.0
    println(finalPriceAppliedForPizza)
    println(finalPriceAppliedForPasta)

    val finalPriceForPizza = finalPrice(20, 12.5, 250)
    val finalPriceForPasta = finalPrice(20, 12.5, 200)
    println(finalPriceForPizza) // 331.25
    println(finalPriceForPasta) // 265.0

    val addNum1 = add2Num(5) _
    val sumOf2Num = addNum1(10)
    println(sumOf2Num) // 15

    val curriedAdd3Num = (add3Num _).curried
    println(add3Num(1, 2, 3)) // 6
    val addFirstNum = curriedAdd3Num(1)
    val addSecondNum = addFirstNum(2)
    val addThirdNum = addSecondNum(3)
    println(addThirdNum) // 6
  }
}
