Sunday, August 8, 2010

Delimited Continuations

Scala's delimited continuations, introduced in version 2.8, can be used to implement all sorts of interesting control constructs.

This is a very long blog post. It took me quite a while to get my head around Scala's reset and shift operators. To help others hopefully avoid the stumbling blocks I encountered, I have tried here to start with the basics and build up from there in some detail. If you want a shorter explanation, see the Resources section at the end of this post for pointers to some other blog entries that are more succinct.

Mechanics

In order to use Scala's delimited continuations, you must use version 2.8, and you must use the continuations (or CPS) compiler plugin. You do this by specifying a command line option when running both the compiler and the runtime:

$ scalac -P:continuations:enable ${sourcefiles}
$ scala -P:continuations:enable ${classname}
In your source code, you must import the appropriate continuations elements, which you can do most simply by using a wildcard to import everything:
import scala.util.continuations._
If you forget to do the import you will get an error message similar to this:
<console>:6: error: not found: value reset
       reset {
       ^

Continuation Passing Style (CPS)

In order to understand how Scala's delimited continuations work, you have to understand the "continuation passing style", or CPS.

Consider this code in which a method makes a subroutine call:
def main {
    pre
    sub()
    post
}
def sub() {
    substuff
}
where pre and post represent all of the code in main respectively before and after the call to sub, and substuff represents all of the code in sub.

When the sub method gets called, the system, in effect, instructs the processor to execute the sub code, then to continue execution within main immediately after the call to sub.

We can conceptually refactor the code in main so that all of the stuff in pre is in a separate method, and all of the post stuff is in a separate method. We can further refactor the code so that each section (pre, sub, post) takes in all of its input data as arguments and passes all of its data changes out as an aggregate return value (such as a Map or Tuple) of the method for that section. Adding arguments and return value to main, we have something that looks like this:
def main(m:M):Z = {
    val x:X = pre(m)
    val y:Y = sub(m,x)
    val z:Z = post(m,x,y)
    return z
}
def sub(m:M,x:X):Y = {
    val y:Y = substuff(m,x)
    return y
}
Now, instead of the system automatically continuing execution at post after finishing sub, let's make that explicit in our code by passing the chunk of code that calls post as an extra argument to sub. We will then modify sub so that, after doing all of its calculations and generating the values it would have returned to main as y, it instead calls post with its arguments as specified, and returns as its own value the return value of post, which is z in main.
def main(m:M):Z = {
    val x:X = pre(m)
    val z:Z = sub(m,x, { post(m,x,_) } )
    return z
}
def sub(m:M,x:X, subCont: (Y) => Z):Z = {
    val y:Y = substuff(m,x)
    val z:Z = subCont(y)
    return z
}
When we pass the code fragment containing post to sub, Scala generates a closure that captures the values available to post at that point, including m and x, so that when that closure is evaluated later it can get those values.

Note that the main method no longer sees y, the original return value from sub, so it can't be explicitly passed to post; instead, we use a placeholder, which is filled in by the code in sub that calls post. We can rewrite that line to use the more explicit function syntax (where, for convenience, we use y as our parameter name):
    val z:Z = sub(m,x, { (y:Y) => post(m,x,y) } )
The gist of CPS is that we don't use return. Rather than calling a subroutine and having it return to us, as is the case in the normal Direct Style, we pass a continuation to the subroutine for it to execute when it is done.
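To make this first step concrete, here is a minimal sketch in plain Scala that needs no compiler plugin. The bodies of pre, substuff and post are hypothetical integer stand-ins (M, X, Y and Z are all taken to be Int). Only the shape of the calls follows the code above.

```scala
object FirstCpsStep {
  // Hypothetical stand-ins for pre, substuff and post (M, X, Y and Z are all Int here)
  def pre(m: Int): Int = m + 1
  def substuff(m: Int, x: Int): Int = m * x
  def post(m: Int, x: Int, y: Int): Int = m + x + y

  // Direct Style: sub returns y, and main itself continues with post
  def subDirect(m: Int, x: Int): Int = substuff(m, x)
  def mainDirect(m: Int): Int = {
    val x = pre(m)
    val y = subDirect(m, x)
    post(m, x, y)
  }

  // First CPS step: main hands "the rest of main" to sub as subCont,
  // and sub calls it instead of returning y
  def subCps(m: Int, x: Int, subCont: Int => Int): Int = {
    val y = substuff(m, x)
    subCont(y)
  }
  def mainCps(m: Int): Int = {
    val x = pre(m)
    subCps(m, x, (y: Int) => post(m, x, y)) // the closure captures m and x
  }
}
```

Both versions compute the same value. The only difference is who runs post: in mainDirect, main runs it after sub returns; in mainCps, sub runs it by calling the closure it was handed.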

Nested CPS

In the above example we have only taken the first step in converting to CPS. To be able to take advantage of CPS, we need to complete the transformation.

At the top, our main method is still returning a value. Since we have no return in CPS, how do we handle this? The answer is that the topmost level can not return a value. Let's add a top-level wrapper like this:
def prog(m:M) {
    val z:Z = main(m)
    println(z)
    System.exit(z.exitValue)
}
Now we can make the same CPS transformation on prog and main as we did before on main and sub:
def prog(m:M) {
    main(m, { (z:Z) =>
        println(z)
        System.exit(z.exitValue)
    })
}

def main(m:M, mainCont:(Z)=>Unit):Unit = {
    val x:X = pre(m)
    val z:Z = sub(m,x, { (y:Y) => post(m,x,y) } )
    mainCont(z)
}
We are still using a return statement in sub, with code in main following the return from sub. To fix this, we need to push the mainCont in main into the continuation we pass to sub. We modify both main and sub to do this:
def main(m:M, mainCont:(Z)=>Unit):Unit = {
    val x:X = pre(m)
    sub(m,x, { (y:Y) =>
        val z:Z = post(m,x,y)
        mainCont(z)
    })
}

def sub(m:M,x:X, subCont: (Y) => Unit) {
    val y:Y = substuff(m,x)
    subCont(y)
}
We have now threaded our top-level continuation - the one that includes the call to System.exit - all the way down to sub, so when we execute the subCont in sub, it will first execute the post method with the code in main that originally appeared after sub, then it will execute the code in prog that originally appeared after the call to main, which will call println and then exit the program by calling System.exit.

If we wanted to convert substuff to CPS, we would apply the same transformation to it and sub, after which the call from sub to substuff would pass an additional argument which was the continuation of the rest of sub, which includes the continuation passed from main to sub, which in turn includes the continuation passed from prog into main.

As you can see, each continuation that we pass down to another subroutine always includes the continuations for all of the callers. In other words, every continuation includes all of the rest of the program to be executed after the called subroutine is done. The other important point is that in every method where we call a subroutine using CPS, that call is always the very last thing in the method.
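Here is the same sketch fully threaded, again in plain Scala with hypothetical stand-in bodies. So that it can run anywhere, prog records its output in a buffer instead of printing and calling System.exit. Every call to a continuation, or to a method that takes one, is the last expression in its method.

```scala
import scala.collection.mutable.ListBuffer

object FullCps {
  // Records what prog would print; System.exit is left out so the sketch can run anywhere
  val output = ListBuffer.empty[String]

  // Hypothetical stand-ins for pre, substuff and post
  def pre(m: Int): Int = m + 1
  def substuff(m: Int, x: Int): Int = m * x
  def post(m: Int, x: Int, y: Int): Int = m + x + y

  def sub(m: Int, x: Int, subCont: Int => Unit): Unit = {
    val y = substuff(m, x)
    subCont(y) // calling the continuation is the last thing sub does
  }

  def main(m: Int, mainCont: Int => Unit): Unit = {
    val x = pre(m)
    // mainCont is pushed inside the continuation handed to sub
    sub(m, x, (y: Int) => {
      val z = post(m, x, y)
      mainCont(z)
    })
  }

  def prog(m: Int): Unit =
    main(m, (z: Int) => output += s"z = $z")
}
```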

Full versus Delimited Continuations

In the discussion above we have assumed that the entire program is converted over to CPS. This is the classical definition of continuations, which can be referred to as full continuations. However, using CPS in languages (such as Scala) that were not specifically designed for it can be awkward, so it would be nicer if we could restrict the use of CPS to the specific areas in our code where we want to use it.

This is exactly the intent of a delimited continuation. Rather than attempting to capture the entire remainder of the program execution in a continuation, we only capture the remaining execution of the program up to a specified point.

If we reexamine the start of our sample program, the prog method, we see that the only difference between it and any arbitrary method is that we can't return a Direct Style value from it. If we remove the call to System.exit, we can call prog from normal Direct Style code, with CPS being used within prog and all of its converted subroutines. Program execution within the CPS code proceeds normally using CPS, each method ending by passing a continuation along to the next method. After the last continuation is finally executed, the CPS code is done and control returns to the caller of prog.
def prog(m:M) {
    main(m, { (z:Z) =>
        println(z)
    })
}

Uses

We have gone to a lot of trouble to restructure our code to use CPS while keeping the functionality the same. Now we can examine how we can make changes to the code that are only possible because it uses CPS.

The key ability that CPS gives us is that we have an explicit object (the continuation) representing the remainder of execution of our program (or, in the case of a delimited continuation, of a portion of our program). In the code sample above, we executed that continuation once we reached the end of the line in sub. But what would happen if, instead of executing the continuation at that point, we just saved it somewhere, such as into a singleton?
object ContinuationSaver {
    var savedContinuation:Option[()=>Unit] = None
    def save(saveCont: =>Unit) = savedContinuation = Some(saveCont _)
}
def sub(m:M,x:X, subCont: (Y) => Unit) {
    val y:Y = substuff(m,x)
    ContinuationSaver.save { subCont(y) }
}
After sub saves the continuation, it is done, and in fact the entire delimited continuation is done; control returns to the caller of prog. But in ContinuationSaver we still have the continuation that represents execution of the remainder of that portion of the program, which we can execute later. In effect, we have placed the execution of that code into suspended animation, to be revived at some later time of our choosing.

Not only can we call the continuation later, we can call it multiple times. We can also write a more sophisticated ContinuationSaver that can save multiple continuations and keep track of which ones we should execute later, including the order and whether to call them multiple times. We can even save the continuations to persistent storage or move them to another computer, as is done by Swarm.
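Here is a runnable sketch of the saving idea in plain Scala. The hypothetical sub below parks its continuation instead of calling it. As a result, prog returns without doing anything visible, and the saved code runs only when (and as often as) we invoke it. The by-name parameter is wrapped as () => saveCont, which has the same effect as saveCont _.

```scala
import scala.collection.mutable.ListBuffer

object SavedContinuation {
  val output = ListBuffer.empty[String]

  object ContinuationSaver {
    var savedContinuation: Option[() => Unit] = None
    // saveCont is by-name, so wrapping it in a function delays its evaluation
    def save(saveCont: => Unit): Unit = savedContinuation = Some(() => saveCont)
  }

  // Hypothetical sub: instead of calling subCont(y), it saves the call for later
  def sub(x: Int, subCont: Int => Unit): Unit = {
    val y = x * 2
    ContinuationSaver.save { subCont(y) }
  }

  def prog(x: Int): Unit =
    sub(x, (y: Int) => output += s"resumed with y = $y")
}
```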

CPS With Return

In pure CPS, there are no returns. But code in Scala does return, even when we are using CPS. In the previous section I used the phrase "control returns to the caller of prog." This happens in the normal way, by having each of the intervening methods return to its caller until the stack unwinds to the first CPS call. I have assumed that each CPS method returns no value (Unit), but there is nothing preventing us from adding code to each method in the transformed CPS chain to make it return a value.

The examples above demonstrate a transformation from Direct Style code to CPS code, and that transformation always results in code that returns Unit. A CPS method that returns a value is therefore not something that transformation technique can produce; it is something we add by hand.

What happens if we add a return value to our CPS code? In our examples above, the execution of the continuation was always the last thing in the subroutine. If we keep this as our default behavior, then when we change the CPS methods to return a value, the return value from the last CPS method in a chain of continuations will propagate back up through the chain of CPS callers all the way out to the topmost CPS method, and will appear to the Direct Style code as the value of that outermost method. Of course, one of the intervening CPS methods might modify or replace that value as it passes through.

For example, let's take the most recent version of sub above (the one that saves the continuation for later execution) and make it return an Int value:
object ContinuationSaver {
    var numberOfSavedContinuations = 0
    var savedContinuation:Option[()=>Unit] = None
    def save(saveCont: =>Unit):Int = {
        savedContinuation = Some(saveCont _)
        numberOfSavedContinuations = numberOfSavedContinuations + 1
        numberOfSavedContinuations
    }
}
def sub(m:M,x:X, subCont: (Y) => Unit):Int = {
    val y:Y = substuff(m,x)
    ContinuationSaver.save { subCont(y) }
}
We also change the rest of the methods in our calling chain to allow us to propagate this value all the way out. Since the call to sub is the last call in main, all we need to do is change the return type on main to match the return type of sub. Likewise, since the call to main is the last call in prog, we change the return type of prog to match the return type of main:
def prog(m:M):Int = {
    main(m, { (z:Z) =>
        println(z)
    })
}

def main(m:M, mainCont:(Z)=>Unit):Int = {
    val x:X = pre(m)
    sub(m,x, { (y:Y) =>
        val z:Z = post(m,x,y)
        mainCont(z)
    })
}
We could, if we wanted to, modify main to make a change to the value returned by sub before passing it back as its own return value, or we could make main return something else entirely.

If you think about the CPS code as having been created by transforming some Direct Style code, you can see that the untransformed code had its original return type, and the now-CPS transformed code has a (potentially different) transformed return type.
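Here is the Int-returning version, extending the previous plain Scala sketch. The count produced by save is the value prog hands back to its Direct Style caller. Meanwhile, the saved continuation still holds the rest of the computation. The arithmetic in sub and main is a hypothetical stand-in.

```scala
import scala.collection.mutable.ListBuffer

object CpsWithReturn {
  val output = ListBuffer.empty[String]

  object ContinuationSaver {
    var numberOfSavedContinuations = 0
    var savedContinuation: Option[() => Unit] = None
    def save(saveCont: => Unit): Int = {
      savedContinuation = Some(() => saveCont)
      numberOfSavedContinuations += 1
      numberOfSavedContinuations // this value travels back out to prog's caller
    }
  }

  def sub(x: Int, subCont: Int => Unit): Int =
    ContinuationSaver.save { subCont(x * 2) }

  // The call to sub is main's last expression, so main simply returns sub's Int
  def main(x: Int, mainCont: Int => Unit): Int =
    sub(x + 1, (y: Int) => mainCont(y + 1))

  def prog(x: Int): Int =
    main(x, (z: Int) => output += s"z = $z")
}
```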

Reset and Shift

Finally, we have enough background to understand Scala's reset and shift keywords.

The Scala implementation of delimited continuations was created by Tiark Rompf of EPFL, and is described in his explanatory paper on Delimited Continuations in Scala with co-authors Ingo Maier and Martin Odersky. There are also some quotes below from some of Tiark's posts.

Reset is the keyword that demarcates the limits of the delimited continuation. Within the body of the reset, the code is CPS code; the return value of reset is not CPS.

Shift is the keyword that indicates the bottoming out of the CPS path. The body of the shift is not CPS code, but its untransformed return value is CPS. The shift call gets passed as its argument the continuation that has been collected from all of the callers out to the (dynamically) enclosing reset.

Reset and shift are thus the keywords that take you from Direct Style to CPS, and from CPS to Direct Style, respectively. All of the code between reset and shift is CPS. Any method that includes shift must be marked as CPS, and any method that calls a CPS method must be marked as CPS, until you reach the enclosing reset call.

When you use reset and shift in your code, the continuations compiler plugin transforms your code in a manner similar to the CPS transformation I described above. All of the code from the end of the shift block to the end of the enclosing method or reset block is packaged up as a closure and passed to the body of the shift block as the continuation function.

Let's break down some examples of reset and shift in Scala.
reset {
  shift { k: (Int=>Int) =>
    k(7)
  } + 1
}
The shift statement tells the compiler plugin to restructure the code as in our CPS examples, by converting the code after the shift call into a continuation that gets passed as an argument to the shift. To make it easier to see what that means in this case, let's do that code transformation in a few steps.

First, we assign the result of the shift call to a variable and use that variable later in the code:
reset {
  var r = shift { k: (Int=>Int) =>
    k(7)
  }
  r + 1
}
Second, we convert all of the code following the shift into a function and call it:
reset {
  var r = shift { k: (Int=>Int) =>
    k(7)
  }
  def f(x:Int) = x + 1
  f(r)
}
The function f is our continuation function that represents all of the code between the end of the shift block and the end of the enclosing reset block. Finally, we transform the code as is done by the compiler plugin, binding our continuation function f(x) to the shift parameter k, and making the return value of the fully transformed code be the return value of the body of the shift:
reset {
  def f(x:Int) = x + 1
  f(7)
}
Now we can easily see that the return value is 8.

We can apply the same transformations to
reset {
  shift { k: (Int=>Int) =>
    k(k(k(7)))
  } + 1
}
to get
reset {
  def f(x:Int) = x + 1
  f(f(f(7)))
}
from which we can quickly calculate that this will return a value of 10.

All of our transformations have no effect on anything outside of the reset; for example,
reset {
  shift { k: (Int=>Int) =>
    k(7)
  } + 1
} * 2
just multiplies the return value of the reset expression by 2, so the result of this code snippet would be 16.
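These transformations can be replayed in plain Scala by writing the continuation out as an ordinary function and passing it to the shift body by hand. This is a hand translation for illustration, not the code the plugin actually generates.

```scala
object PlusOneByHand {
  // Continuation for "shift { ... } + 1": everything after the shift, up to the end of the reset
  val k: Int => Int = x => x + 1

  // The shift bodies, with k passed in explicitly; each body's value is the reset's value
  def body1(k: Int => Int): Int = k(7)
  def body3(k: Int => Int): Int = k(k(k(7)))

  val r1: Int = body1(k)      // reset { shift { k => k(7) } + 1 }
  val r3: Int = body3(k)      // reset { shift { k => k(k(k(7))) } + 1 }
  val r2x: Int = body1(k) * 2 // "* 2" is outside the reset, so it is not part of k
}
```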

Tiark's paper gives this interesting example:
reset {
  shift { k: (Int=>Int) =>
    k(k(k(7))); "done"
  } + 1
}
and points out that the value of this code snippet is "done". The continuation function k is called three times, but the value of that expression is discarded. If we apply our code transformations as before, we see that this transforms into:
reset {
  def f(x:Int) = x + 1
  f(f(f(7))); "done"
}
which makes it more obvious why the result of this code snippet is "done".
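Instrumenting the hand-written continuation makes the point explicit. In this plain Scala sketch, k really is called three times, yet the value of the whole expression is the last value of the shift body.

```scala
object DoneByHand {
  var calls = 0
  // The continuation ("+ 1"), instrumented to count how often the shift body calls it
  val k: Int => Int = x => { calls += 1; x + 1 }
  // The shift body: k is called three times, but its result is discarded
  def body(k: Int => Int): String = { k(k(k(7))); "done" }
  val result: String = body(k)
}
```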

A key detail to note here is that the value of the evaluated reset block is not the value of the last expression in that block, as it is in most code. Instead, the value of the evaluated reset block is the value of the last expression in the shift block that gets executed within that reset block. Execution of the body of the shift is always the last thing that happens within the enclosing reset block.

When you look at a shift block and see its return value being used in an expression, as in the "shift + 1" examples above, remember that, due to code transformation, that "return" from the shift block never actually happens as a return. Instead, once execution reaches the shift block, the code after that block gets passed to it as a continuation; if the code in the shift block calls the continuation, the value which is passed as an argument to the continuation appears as the value being returned from the shift block. Thus the type of the argument passed to the shift block's continuation function is the same as the type of the return value of the shift in the source code, and the type of the return value of that continuation function is the same as the type of the last value in the untransformed reset block that encloses the shift block.

There are thus three types associated with shift:
  • The type of the argument to pass to the continuation, which is the same as the syntactic return type of the shift in the source code.
  • The type of the return from the continuation, which is the same as the return type of all of the code that follows the shift block in the source code (i.e. the type of the last value in the block of code between the shift block and the end of the function or reset block containing the shift block). This is called the untransformed return type.
  • The type of the last value in the shift block, which becomes the type of the return value of the enclosing function or reset block. This is called the transformed return type.
In the signature for shift, the above three types appear as A, B and C, respectively:
def shift[A, B, C](fun: ((A) => B) => C): A @scala.util.continuations.cpsParam[B,C]
The two types in the cpsParam annotation always represent the untransformed and the transformed return types, respectively. The CPS annotations are described in more detail below.

The signature for reset only uses two types: the first type is the untransformed type of the code block passed to reset, which matches the B type of shift, and the second type is the type of the transformed code block, which matches the C type of shift, and is also the real return type of the reset block to its caller. The scaladoc for reset uses parameter type names A and C, but I write it here using B and C so that the signature of the ctx by-name parameter matches the signature of the return value of shift:
def reset[B, C](ctx: => B @scala.util.continuations.cpsParam[B,C]): C   
Here's where those types appear:
C = reset { ...; A = shift { k:(A=>B) => ...; C } ...; B }  
In the following example, A=Int, B=String and C=Boolean:
def is123(n:Int):Boolean = {
  reset {
    shift { k : (Int=>String) =>
      (k(n) == "123")
    }.toString
  }
}
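You can see the three type parameters without the compiler plugin by modeling a shift as a plain value: a function that takes the continuation (A => B) and produces the shift's result (C). The Shift case class, its map method and this reset are hypothetical names. They sketch the idea and are not the plugin's API or its actual implementation. Here, map plays the role of the code written after the shift, which gets folded into the continuation.

```scala
object ShiftSketch {
  // A: value the shift appears to return; B: result of the continuation; C: result of the shift body
  final case class Shift[A, B, C](run: (A => B) => C) {
    // Code written after the shift becomes part of the continuation
    def map[A1](f: A => A1): Shift[A1, B, C] =
      Shift[A1, B, C](k => run(a => k(f(a))))
  }

  // reset supplies the final continuation (identity) and returns the shift body's C
  def reset[B, C](s: Shift[B, B, C]): C = s.run(identity)

  // reset { shift { k: (Int=>Int) => k(7) } + 1 }
  val eight: Int = reset(Shift[Int, Int, Int](k => k(7)).map(_ + 1))

  // The is123 example: A = Int, B = String, C = Boolean
  def is123(n: Int): Boolean =
    reset(Shift[Int, String, Boolean](k => k(n) == "123").map(_.toString))
}
```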

Annotations

As you saw above, the signatures for reset and shift include the cpsParam annotation. The compiler plugin uses this type annotation to select what pieces of code to transform to CPS; in Tiark's paper this is referred to as a "type-directed selective CPS transform." If you just use reset and shift without any subroutine calls, you may never need to explicitly use a CPS annotation. But if you put any shift calls into subroutines, as described below, then you will need to use a CPS annotation.

The base annotation is cpsParam[-B, +C]. This annotation tells the compiler that the corresponding block of code has an untransformed return value of type B and a transformed return value of type C, as described in the discussion of the types for reset and shift above.

To simplify the annotation for the common case where the transformed return type is the same as the untransformed type, the continuations package defines the convenience type cps:
type cps[A] = cpsParam[A, A]
If you are looking at old posts on the web, be aware that the cpsParam annotation used to be called simply cps; the old cps annotation was renamed to cpsParam and the new one-type-parameter cps type alias was added.

In the Uses section above we discussed the possibility of saving away the continuation for later execution, after which control returns to the caller. If we do this, we can't return a value from the suspended code to the original caller, since that code has not been executed yet, and the eventual executor of the continuation may not know where it came from, so it too is likely not to care about a return value.

In order to simplify the source code for this typical case, the Scala continuations library includes a special annotation type, suspendable:
type suspendable = cpsParam[Unit, Unit]
In addition to being more succinct, this annotation type can be used to make it clear that this function may suspend its continuation so that it can finish execution later.

Nested Shift

In all of the above examples, the shift block appears directly inside the reset block, and the cpsParam type of the reset block must match the cpsParam type of the shift block.

What happens if you put the shift block in a separate function and call that function from the reset block? In this case, the function containing the shift block must be marked as a CPS function by using the cpsParam annotation on its return type, and that cpsParam type must be the same as the cpsParam type of the enclosed shift block. When this function is invoked from within the reset block, the compiler plugin knows how to transform that block such that the code after the call to the CPS function becomes part of a continuation which is passed in to the CPS function, just as in the Nested CPS examples above.
def is123(n:Int):Boolean = {
  reset {
    is123sub(n)
  }
}

def is123sub(n:Int):String @cpsParam[String,Boolean] = {
    shift { k : (Int=>String) =>
      (k(n) == "123")
    }.toString
}
The function containing the shift block can be refactored to push that shift block down into another function, in which case that new function must also have the same signature as the original function and the shift block. Thus the entire chain of functions between the reset and the shift are all tied together with the same CPS signature.

What if you have an existing CPS function, but you want to call it and change its return type? If you were to follow the pattern of regular code, you might start by trying something like this in order to return floating point 1 or 0 rather than the Boolean true or false returned by a reset block that just calls is123sub.
//this won't compile
def is123f(n:Int):Float = {
    reset {
        val x = is123sub(n)
        if (x) 1.0f else 0.0f
    }
}
This does not work as expected; the line of code following the call to is123sub is not operating on what will be the return value of the reset block, despite it being the last statement in that block. Instead, due to the code transformation described above that is being done by the CPS compiler plugin, code added after the call to is123sub gets bundled up as part of the continuation passed to the shift block within is123sub. The code that follows the call to the CPS function must end with a type that matches the first parameter of the cpsParam part of the signature of the function; in this case, String. The untransformed return type of is123sub is also String, so in this case the block of code that follows the call to is123sub must take a String (as the return value of the call to is123sub) and must also return a String (which becomes the result of the call to the continuation k inside the shift block within is123sub).

If we want to intercept the Boolean value that is being calculated in the shift block within is123sub, we must do that from within another shift block. The body of a shift block is written in Direct Style, and our subroutine is123sub is CPS, so we can't call it from within the new shift block. What we have to do is to put the new shift block before the call to is123sub. The call to is123sub then becomes part of the continuation that is passed to the new shift block, and we can add code within the new shift block that receives the transformed result of the shift block in is123sub and converts it as desired.

To see the control flow a little more clearly, you can execute this code snippet:
reset {
    println("A")
    shift { k1: (Unit=>Unit) =>
        println("B")
        k1()
        println("C")
    }
    println("D")
    shift { k2: (Unit=>Unit) =>
        println("E")
        k2()
        println("F")
    }
    println("G")
}
Here's the output the above code produces:
A
B
D
E
G
F
C
You can see from the order of execution that the second shift block is being executed as part of the continuation that is passed to the first shift block. Despite the fact that one appears before the other in the source code, the two shift blocks are actually nested. The compiler plugin notices this and handles them slightly differently to prevent the nested shift block from escaping from the enclosing reset block.
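Writing the two continuations out by hand reproduces the same order and makes the nesting visible: k1 contains the second shift, and k2 exists only inside k1. This is a plain Scala sketch with a buffer standing in for println, not output from the plugin.

```scala
import scala.collection.mutable.ListBuffer

object ShiftOrderByHand {
  val out = ListBuffer.empty[String] // stands in for println

  def run(): Unit = {
    out += "A"
    // k1: everything after the first shift, which includes the second shift
    val k1: Unit => Unit = _ => {
      out += "D"
      // k2: everything after the second shift
      val k2: Unit => Unit = _ => out += "G"
      out += "E"; k2(()); out += "F" // body of the second shift
    }
    out += "B"; k1(()); out += "C"   // body of the first shift
  }
}
```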

To show how all of the types thread together, here is a little piece of code with explicit type annotations on the reset and shift blocks in which you can see sets of places for which the same type needs to be used. The assert statements help show how the values are getting passed around.
def nestedShifts[T1,T2,T3,T4,T5](t1:T1,t2:T2,t3:T3,t4:T4,t5:T5):T2 = {
    reset[T1,T2] {
        val s1:T3 = shift[T3,T5,T2] { k1: (T3=>T5) =>
            val r1:T5 = k1(t3)
            assert(r1==t5)
            t2  //this is the return value of nestedShifts
        }
        assert(s1==t3)
        val s2:T4 = shift[T4,T1,T5] { k2: (T4=>T1) =>
            val r2:T1 = k2(t4)
            assert(r2==t1)
            t5
        }
        assert(s2==t4)
        t1
    }
}
If you get a compiler error when nesting CPS functions like this, try modifying the code to assign the value of the nested CPS function to a local variable, then end with that variable:
def is123f(n:Int):Float = {
    reset {
        val x = shift { k:(Int=>Boolean) =>
            if (k(n)) 1.0f else 0.0f
        }
        val r = is123sub(x)
        r
    }
}
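The same nesting can be written out by hand in plain Scala. In this sketch, is123sub takes the code that follows its call as an explicit extra parameter (cont, a name introduced here). This shows why a caller of is123sub cannot see its Boolean result directly: only the outer shift body can convert it to a Float, because that body receives the Boolean from its own continuation.

```scala
object Is123fByHand {
  // is123sub, CPS-converted by hand. cont is whatever code follows the call in the
  // caller, up to the reset (String => String, matching cpsParam[String,Boolean])
  def is123sub(n: Int)(cont: String => String): Boolean = {
    val k: Int => String = i => cont(i.toString) // the shift's continuation: ".toString", then cont
    k(n) == "123"                                 // the shift body; its value is the Boolean result
  }

  // reset { is123sub(n) }: nothing follows the call, so the final continuation is identity
  def is123(n: Int): Boolean = is123sub(n)(identity)

  // The outer shift from is123f: its continuation runs is123sub and yields the Boolean,
  // which the outer shift body converts to a Float
  def is123f(n: Int): Float = {
    val k: Int => Boolean = x => is123sub(x)(identity)
    if (k(n)) 1.0f else 0.0f
  }
}
```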
If you leave out the val r and just end the reset block with the call to is123sub, you will get an error such as this:
<console>:13: error: type mismatch;
 found   : String @scala.util.continuations.cpsParam[String,Boolean]
 required: String @scala.util.continuations.cpsParam[String,Float]
           is123sub(x)
                   ^

Control Construct Restrictions

Because of the code transformation performed by the continuations compiler plugin, there are some control constructs that can not be used when calling a CPS function.

Using return statements in a CPS function is unlikely to do what you expect, and may cause type mismatch compiler errors, so you should not use them.

When using an if statement, you may get an error like this:
Foo.scala:21: error: then and else parts must both be cps code or neither of them
Tiark's advice is to avoid explicit return, and possibly to wrap the non-CPS value in shiftUnit so that both branches are CPS code.

The compiler plugin does not handle try blocks, so you can't catch exceptions within CPS code. Those exceptions will be propagated out to the enclosing reset block and can be caught there - unless the continuation is suspended and executed later, in which case any exceptions would be propagated to the reset block of the code doing that later execution.

You need to be careful when using looping constructs. As Tiark says,
Capturing delimited continuations inside a while loop turns the loop basically into a general recursive function.
Tiark's post has the details, but basically each invocation of shift within a looping construct allocates another stack frame, so after "looping" many times you will likely get a StackOverflowError.

Some looping constructs can not be used with a shift inside them. To quote Tiark again:
In a reset block you can do anything, but shifts are not allowed everywhere. The limitation is that everything on the call path between a shift and its enclosing reset must be "shift-aware". That rules out the regular foreach, map and filter methods because they know nothing about continuations, so they can't call closures containing shift.

Advice

As I mentioned at the start of this post, it took me some time to feel that I had a good understanding of how reset and shift work. You may not get it in one reading of this post. As with any new coding concept, the best way to gain a working understanding is to try using it in some of your own code. You will need patience; the CPS error messages are not always clear.

If you are interested in playing with control constructs, such as actors or generators, then you should definitely take the time to understand reset and shift. You might also want to take a look at Swarm.

On the other hand, you may never need to deal with reset and shift. Now that they are available in Scala, I expect some people will create libraries that build on reset and shift to present APIs for developers that are simpler to understand. Still, even when using those simpler APIs you may find that an understanding of the content of this post will be useful.

Resources

Updated 2010-08-09 to fix error pointed out by mgm7734.
Updated 2010-09-26 to fix error pointed out by Nikolay.

9 comments:

mgm7734 said...

Error:

The A type is the syntactic (or "untransformed") return value of the shift block, but the C type is the real return value of the transformed shift block.

should be

The B type is...

Jim McBeath said...

mgm7734: Thanks for pointing out that problem. I had meant that to refer to the "A = shift" portion, but did not properly describe that. I have removed that sentence and instead added the "Here's where those types appear" example.

Tiark said...

Excellent writeup!

toland said...

Crystal explanations especially the examples around the parameterization of shift and rest! Brilliant writing :)

Nikolay said...

Thank you very much for clean explanation of complex things! Btw, there are small typo in section "Uses":

def save(saveCont: =>Unit) = savedContinuation = saveCont _

should be

def save(saveCont: =>Unit) = savedContinuation = Some(saveCont _)

Jim McBeath said...

Nikolay: Thanks, I have made that correction in the post.

Dimitri said...

Best explanation I've come across yet. Thanks!

Cay Horstmann said...

I didn't have to use -P:continuations:enable with scala, only with scalac. This is with version 2.8.1.

Lachlan said...

This is the best introduction to Scala continuations that I've yet read. Well done.