Writing a Language in Truffle. Part 4: Adding Features the Truffle Way

:: java, truffle, tutorial

I ended last time with a lisp that had the bare minimum of features and had reached an acceptable speed. Now it’s time to make Mumbler a more useful language with a couple of new features: arbitrary precision integers and, what no lisp should be without, tail call optimization.

I don’t want to undo all the work it took to make Mumbler fast, so I’m going to show how Truffle can help to include these features and still keep the language fast.

Arbitrary Precision Arithmetic

If you go way back to the first post I stated that I was going to limit Mumbler’s number types to long in the interest of simplifying the implementation of the interpreter. Well now I have the interpreter written and the users are (theoretically) clamouring for more robust arithmetic. Let’s take a look at how we would implement it.

The thing I don’t want to do is replace our use of long with BigInteger throughout our interpreter. That would wreak havoc on performance when most uses of numbers will stay within the size of a long. What I want is to only fall back on BigInteger if the user specifies a number that can’t fit in a long or an arithmetic operation results in a number that is too big.

Adding BigInteger to Mumbler’s Types

Since I’m adding a new data type to Mumbler’s interpreter I need to update the class that defines all the built-in types. So I crack open MumblerTypes and add BigInteger.

@TypeSystem({long.class, boolean.class, BigInteger.class, MumblerFunction.class,
    MumblerSymbol.class, MumblerList.class})
public class MumblerTypes {
    @ImplicitCast
    public static BigInteger castBigInteger(long value) {
        return BigInteger.valueOf(value);
    }
}

The change is simple enough. I add BigInteger.class to the list in @TypeSystem. Remember that the order of the types matters. Types that appear earlier bind more tightly than later ones. If I put BigInteger before long then operations would try to return BigInteger first, succeed, and never get to long. Bad news for performance.

The other change is the new castBigInteger method. The method does a simple conversion from long to BigInteger. The key is the @ImplicitCast annotation. The annotation tells Truffle a long can be converted to a BigInteger, and use this method if we ever need to do that.

The last thing I need to do to make BigInteger a first class citizen is allow all MumblerNode objects to return a BigInteger object. So we just add another method to our base node.

@TypeSystemReference(MumblerTypes.class)
@NodeInfo(language = "Mumbler Language", description = "The abstract base node for all expressions")
public abstract class MumblerNode extends Node {
    // other execute methods...

    public BigInteger executeBigInteger(VirtualFrame virtualFrame)
            throws UnexpectedResultException {
        return MumblerTypesGen.expectBigInteger(this.execute(virtualFrame));
    }
}

That’s it. Now all nodes can return BigInteger numbers. One problem: nothing yet returns BigInteger numbers. Where would we want to return arbitrarily long numbers? Why, when we add, subtract or multiply two numbers and they no longer fit in a long. Of course!

The Truffle Implementation

The first problem is Java’s +, - and * operators just wrap around if we go past MAX/MIN_VALUE. Thankfully, Java 8 added new methods to java.lang.Math that will throw an exception if the operation overflows. Since Truffle supports Java 7 they backported these methods to their com.oracle.truffle.api.ExactMath class. Now we can fall back on BigInteger if the value won’t fit in a long. For example, the AddBuiltinNode’s add method now looks like this.

public long add(long value0, long value1) {
    return ExactMath.addExact(value0, value1);
}
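To make the difference concrete, here is a small plain-Java sketch (not Mumbler code) that uses java.lang.Math.addExact, the Java 8 method that ExactMath mirrors. The class and method names are made up for illustration.

```java
// Illustration only: uses java.lang.Math (Java 8+), not Truffle's ExactMath.
final class OverflowDemo {
    /** Ordinary addition: silently wraps around on overflow. */
    static long wrappingAdd(long a, long b) {
        return a + b;
    }

    /** Reports whether a + b falls outside the range of a long. */
    static boolean overflows(long a, long b) {
        try {
            Math.addExact(a, b);
            return false;
        } catch (ArithmeticException e) {
            // addExact throws instead of wrapping around.
            return true;
        }
    }

    public static void main(String[] args) {
        // Wraps from the largest long to the smallest one.
        System.out.println(wrappingAdd(Long.MAX_VALUE, 1) == Long.MIN_VALUE); // true
        System.out.println(overflows(Long.MAX_VALUE, 1));                     // true
        System.out.println(overflows(1, 2));                                  // false
    }
}
```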

We still need to catch the exception. Truffle’s @Specialization annotation has a rewriteOn field where we can say “if an ArithmeticException is thrown rewrite the AST and upcast long to BigInteger.” Of course, if we’re going to upcast to BigInteger we’ll need a method to add two BigInteger objects.

@NodeInfo(shortName = "+")
@GenerateNodeFactory
public abstract class AddBuiltinNode extends BuiltinNode {
    @Specialization(rewriteOn = ArithmeticException.class)
    public long add(long value0, long value1) {
        return ExactMath.addExact(value0, value1);
    }

    @Specialization
    protected BigInteger add(BigInteger value0, BigInteger value1) {
        return value0.add(value1);
    }
}

This is the new AddBuiltinNode class. It contains the old add method for long but now uses ExactMath.addExact, and the @Specialization annotation now states to rewrite the AST on an ArithmeticException. We’ve also added add for BigInteger. Since we don’t expect any exception here it just adds and returns.

That’s all we need to do to upgrade addition to gracefully move to BigInteger if long becomes too small. The changes to subtraction and multiplication work the same way.
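Outside of Truffle, the same “try long first, fall back to BigInteger” idea can be sketched in plain Java, here for multiplication. This is only an approximation of what the specializations do: Truffle rewrites the node once so later calls go straight to the BigInteger version, while this sketch retries the fast path on every call. The class and method names are hypothetical.

```java
import java.math.BigInteger;

// Plain-Java approximation of the long -> BigInteger fallback (not Truffle code).
final class FallbackArithmetic {
    /**
     * Multiplies two longs. Returns a Long when the product fits in a long
     * and a BigInteger otherwise.
     */
    static Object multiply(long a, long b) {
        try {
            return Math.multiplyExact(a, b); // fast path, boxed to Long
        } catch (ArithmeticException overflow) {
            // Slow path: redo the operation with arbitrary precision.
            return BigInteger.valueOf(a).multiply(BigInteger.valueOf(b));
        }
    }

    public static void main(String[] args) {
        System.out.println(multiply(6, 7));              // 42
        System.out.println(multiply(Long.MAX_VALUE, 2)); // 18446744073709551614
    }
}
```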

Adding Literal BigInteger Numbers

The final piece of the arbitrary precision puzzle is allowing users to write any number and have Mumbler create the proper number type. The code to create literal long nodes is in place. I just need to add a literal node for BigInteger if the number can’t be cast to long. First the literal BigIntegerNode class.


public class BigIntegerNode extends MumblerNode {
    public final BigInteger value;

    public BigIntegerNode(BigInteger value) {
        this.value = value;
    }

    @Override
    public BigInteger executeBigInteger(VirtualFrame virtualFrame) {
        return this.value;
    }

    @Override
    public Object execute(VirtualFrame virtualFrame) {
        return this.value;
    }
}

Very simple. Just return the BigInteger object. So where do I create these objects? I have to modify the Reader to create BigIntegerNode objects if LongNode won’t work. I’ll optimistically try to create a long and only if there’s an exception will Reader fall back on BigIntegerNode. Here’s the relevant part of Reader.readNumber.

private static Convertible readNumber(PushbackReader pstream)
        throws IOException {
    // read number from PushbackReader...
    try {
        return new LiteralConvertible(new LongNode(
                Long.valueOf(buffer.toString(), 10)));
    } catch (NumberFormatException e) {
        // Number doesn't fit in a long. Using BigInteger.
        return new LiteralConvertible(new BigIntegerNode(
                new BigInteger(buffer.toString(), 10)));
    }
}

If Long.valueOf throws a NumberFormatException (which in this method will only happen if the number is too big to fit into a long value) then we create a BigIntegerNode. We’re done. We don’t need to make any other changes. Now you can add to your heart’s content!
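To see exactly where the switch from long to BigInteger happens, here is a standalone sketch of the same parsing strategy without the Reader or node classes. It assumes the input is already a well-formed decimal integer, as it is in readNumber; the class name is made up.

```java
import java.math.BigInteger;

// Standalone sketch of the "try long, then BigInteger" literal parsing.
final class NumberLiterals {
    /** Parses decimal digits into a Long if they fit, else a BigInteger. */
    static Number parse(String digits) {
        try {
            return Long.valueOf(digits, 10);
        } catch (NumberFormatException e) {
            // Assumption: digits is well formed, so the only failure is overflow.
            return new BigInteger(digits, 10);
        }
    }

    public static void main(String[] args) {
        // Long.MAX_VALUE still parses as a Long...
        System.out.println(parse("9223372036854775807").getClass().getSimpleName());
        // ...but one past it becomes a BigInteger.
        System.out.println(parse("9223372036854775808").getClass().getSimpleName());
    }
}
```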

A Lisp Birthright: Tail Call Optimization

Perhaps “birthright” is a little strong. There are lisps out there without tail call optimization (TCO), like Clojure, though it’s not for lack of trying. The JVM doesn’t natively support tail call optimization, but now with Truffle we have a way around that. So many languages on the JVM would love to have TCO, and now we have salvation from imperative purgatory. Hooray!

For a language without explicit loop constructs like for or while, Mumbler needs to be extra careful not to blow the call stack. This kinda makes TCO a requirement. A quick test on my machine shows Mumbler fills up the call stack after about 800 iterations. So don’t expect to do anything more than 800 times or your program will crash. Not very pleasant.

To build this feature I had to bone up on TCO. As a user of Scheme, I was familiar with how to use TCO, but not how I would go about implementing it. In fact, I wasn’t sure what constituted TCO. I mostly relied on TCO when I wrote recursive functions, but after reading the Wikipedia page on Tail Calls I realized that you could do TCO on any function call that’s the last (leaf) node of the AST. So let’s go and optimize all function calls in the tail position.

What is Tail Call Optimization

But first, what does it mean to optimize a tail call? Like I said, the goal is to make sure we don’t get a StackOverflowError, especially if we have a recursive function that will have a lot of iterations. So how do we avoid this? The basic idea behind tail call optimization is:

Once you have nothing left to do in your current function (aside from the final function call) you don't need that function's scope anymore. You can jump back to the caller with the arguments and the lexical scope and call the final function from the caller.

This way, a recursive function will take only two entries in the call stack: the caller’s frame and the current function being executed. Of course, if the function makes other function calls that aren’t in the tail position, those will add to the call stack. That’s why you sometimes have to rearrange a function body so it is optimized for tail calls.

Every time the last function is about to be called, you first jump back to the caller and call it from there. For example, say we were computing the factorial of a number, a prototypical lisp function. The code would look something like this:

(define factorial (lambda (n product)
                    (if (< n 2)
                        product
                        (factorial (- n 1) (* n product)))))

Without TCO Mumbler would create a new frame on the call stack for every call to factorial like:

<main>             <main>             <main>
(factorial 3 1)    (factorial 3 1)    (factorial 3 1)
                   (factorial 2 3)    (factorial 2 3)
                                      (factorial 1 6)

You can see how quickly this grows even though we don’t need the intermediate frames. It would be great if we could have a stack more like:

<main>             <main>             <main>
(factorial 3 1)    (factorial 2 3)    (factorial 1 6)

The one constant in all this is the top frame. The top frame can make all the intermediate calls on behalf of the functions and receive the result when we reach the terminal state (when n is less than 2). So how do we do that in Truffle?
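As a point of comparison before the Truffle version, this is what the flattened stack amounts to in Java terms: the tail call becomes “update the arguments and go around again”. It is a hand-written equivalent of the factorial example, not something Truffle generates, and the class name is made up.

```java
// Hand-written comparison of a tail-recursive function and its loop form.
final class FactorialLoop {
    /** Mirrors (factorial n product). Java adds a frame for every call. */
    static long factorialRecursive(long n, long product) {
        if (n < 2) {
            return product;
        }
        return factorialRecursive(n - 1, n * product); // call in tail position
    }

    /** Same computation, with the tail call replaced by reusing one frame. */
    static long factorialLoop(long n, long product) {
        while (n >= 2) {
            product = n * product; // new second argument, computed from the old n
            n = n - 1;             // new first argument
        }
        return product;
    }

    public static void main(String[] args) {
        System.out.println(factorialRecursive(3, 1)); // 6
        System.out.println(factorialLoop(3, 1));      // 6
    }
}
```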

Tail Call Optimization in Truffle

So how do we kick out of a function call? Well, we could call return with some special object that says “This is a tail call. Complete function call”, but that would get tedious having to check for a special object on every call and it won’t be efficient. Truffle doesn’t make us do that thankfully. Truffle uses Java’s other strategy for unwinding the call stack: exceptions.

Truffle has a special exception class ControlFlowException that should be used for all flow control. Graal has special knowledge of this class and its children so it can optimize away all the internal function calls. This way, control structures like for and while or even structures like break and return can be as fast in Graal as they are in regular Java.

So let’s create our special tail call exception.

public class TailCallException extends ControlFlowException {
    public final CallTarget callTarget;
    public final Object[] arguments;

    public TailCallException(CallTarget callTarget, Object[] arguments) {
        this.callTarget = callTarget;
        this.arguments = arguments;
    }
}

That was easy. The class contains the function (CallTarget) that’s going to be called plus all the arguments. All the arguments have been evaluated before we throw the exception. The CallTarget objects are created when we build our functions. You can read the previous post on function creation to see how it’s created. With these two pieces of information I can make a function call. Technically, we need a VirtualFrame object but we’ll get that from the caller.

Starting a Tail Call

So when we’re about to execute that final node in the function’s AST, we want to throw a TailCallException instead. All function calls occur in the InvokeNode.execute method so let’s update that code.

@Override
@ExplodeLoop
public Object execute(VirtualFrame virtualFrame) {
    MumblerFunction function = this.evaluateFunction(virtualFrame);
    CompilerAsserts.compilationConstant(this.argumentNodes.length);

    Object[] argumentValues = new Object[this.argumentNodes.length + 1];
    argumentValues[0] = function.getLexicalScope();
    for (int i=0; i<this.argumentNodes.length; i++) {
        argumentValues[i+1] = this.argumentNodes[i].execute(virtualFrame);
    }

    if (CompilerAsserts.compilationConstant(this.isTail())) {
        throw new TailCallException(function.callTarget, argumentValues);
    } else {
        return this.call(virtualFrame, function.callTarget, argumentValues);
    }
}

The method starts like before: evaluate the function, evaluate the arguments. After that we check if the node is in the tail position. I’ll show how that’s set later; for now assume it’s set to the correct value. There’s a call to CompilerAsserts.compilationConstant to tell Truffle this value is constant. A function call in the tail position isn’t going to move so we may as well eke out as much performance as we can. If the node is in the tail position then we create a TailCallException and throw it. If not we make the call like normal.

Catching a Tail Call

Starting a tail call was straightforward, but how do we catch a TailCallException? Furthermore, where do we catch a TailCallException?

I want to catch the exception in the body of the caller. So, when I call a function there’s a chance that function may say, “Hey, take care of this function call for me.” Okay, so when we call a function we have to wrap it in a try/catch in case it throws a TailCallException. But wait, what about the function call from TailCallException? Couldn’t that also throw another TailCallException? Yes! In the factorial example above the caller (<main>) ended up making all three calls to factorial, two of them in response to a TailCallException. So not only do we have to handle the normal function call, but we have to be prepared to catch any number of TailCallExceptions. How will I do that? Let’s look at the InvokeNode.call method.

public Object call(VirtualFrame virtualFrame, CallTarget callTarget,
        Object[] arguments) {
    while (true) {
        try {
            return this.dispatchNode.executeDispatch(virtualFrame,
                    callTarget, arguments);
        } catch (TailCallException e) {
            callTarget = e.callTarget;
            arguments = e.arguments;
        }
    }
}

You can see buried underneath all the wrappings is a call to our dispatch node that will take care of the actual call. Outside of that we catch the TailCallException and keep trying until we return normally. We basically keep going around and call the dispatch node however many times it takes until a TailCallException is not thrown. If a call throws a TailCallException we catch it and start over with the new values.
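The throw-and-retry loop can be tried out without Truffle. The sketch below is a plain-Java stand-in: TailCall plays the role of TailCallException, a java.util.function.Function stands in for CallTarget, and call mirrors InvokeNode.call. All of these names are illustrative. The exception is created without a stack trace, a common Java technique for exceptions used purely as control flow.

```java
import java.util.function.Function;

// Plain-Java stand-in for the tail call loop (no Truffle classes involved).
final class Trampoline {
    /** Carries the next function to call and its already-evaluated arguments. */
    static final class TailCall extends RuntimeException {
        final Function<long[], Object> target;
        final long[] arguments;

        TailCall(Function<long[], Object> target, long[] arguments) {
            super(null, null, false, false); // no message, cause, or stack trace
            this.target = target;
            this.arguments = arguments;
        }
    }

    /** Keeps calling until a function returns normally instead of throwing. */
    static Object call(Function<long[], Object> target, long[] arguments) {
        while (true) {
            try {
                return target.apply(arguments);
            } catch (TailCall e) {
                // The callee asked us to make its final call on its behalf.
                target = e.target;
                arguments = e.arguments;
            }
        }
    }

    /** (countdown n): the recursive call is requested by throwing. */
    static Object countdown(long[] args) {
        long n = args[0];
        if (n < 1) {
            return 0L;
        }
        throw new TailCall(Trampoline::countdown, new long[] {n - 1});
    }

    public static void main(String[] args) {
        // A million iterations, yet countdown is never more than one frame deep.
        System.out.println(call(Trampoline::countdown, new long[] {1_000_000}));
    }
}
```

Each countdown frame is gone by the time the next one starts, so the depth of the Java stack stays constant no matter how large n is.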

Dispatching to the Right Function

If you recall when we implemented function calls, Truffle requires all CallTarget objects to be wrapped in a CallNode. Previously Mumbler was using the subclass DirectCallNode because it is faster, but it does have the limitation of only working for one function in one invocation node. That wasn’t really a limitation because one InvokeNode would only ever be linked to one function, but now that an InvokeNode has to handle tail calls the node may have to deal with other CallTarget objects. We can’t rely on one DirectCallNode anymore. We also don’t want to only use IndirectCallNode because it’s much slower. What do we do?

We implement a cache. We take the most common CallTarget objects and wrap them in DirectCallNode objects. To prevent an explosion of DirectCallNode creation I set a limit on the cache size. Once the limit is reached, we’ll fall back on IndirectCallNode for further function calls. This way we get the fast speeds of DirectCallNode for most functions, but everything still works in case the cache is full. All this logic is encapsulated in the DispatchNode children.

The dispatch node starts with UninitializedDispatchNode. The job of this node is to keep track of how big the dispatch cache is. If it exceeds the limit we fall back on IndirectCallNode. If not, we create a DirectCallNode for the current function and use it.

public final class UninitializedDispatchNode extends DispatchNode {
    @Override
    protected Object executeDispatch(VirtualFrame virtualFrame,
            CallTarget callTarget, Object[] arguments) {
        CompilerDirectives.transferToInterpreterAndInvalidate();

        Node cur = this;
        int depth = 0;
        while (cur.getParent() instanceof DispatchNode) {
            cur = cur.getParent();
            depth++;
        }
        InvokeNode invokeNode = (InvokeNode) cur.getParent();

        DispatchNode replacement;
        if (depth < INLINE_CACHE_SIZE) {
            // There's still room in the cache. Add a new DirectDispatchNode.
            DispatchNode next = new UninitializedDispatchNode();
            replacement = new DirectDispatchNode(next, callTarget);
            this.replace(replacement);
        } else {
            replacement = new GenericDispatchNode();
            invokeNode.dispatchNode.replace(replacement);
        }

        // Call function with newly created dispatch node.
        return replacement.executeDispatch(virtualFrame, callTarget, arguments);
    }
}

Starting from the top, the first thing we do is call CompilerDirectives.transferToInterpreterAndInvalidate. This node doesn’t do anything except create other nodes and alter the AST. Graal will need to re-optimize the tree and transferToInterpreterAndInvalidate is how we tell Graal to do that. I then walk back up the linked list of dispatch nodes to the InvokeNode, counting the nodes as I go. I then check if the max cache size is reached. If so, I switch to using GenericDispatchNode. If not, I create a new DirectDispatchNode. I stick a new UninitializedDispatchNode at the end so it can handle any future changes needed.

If the tail call has been caught there will be a cached DirectDispatchNode waiting for me to reuse. Graal can then inline it. Let’s see what DirectDispatchNode does.

public class DirectDispatchNode extends DispatchNode {
    private final CallTarget cachedCallTarget;

    @Child private DirectCallNode callCachedTargetNode;
    @Child private DispatchNode nextNode;

    public DirectDispatchNode(DispatchNode next, CallTarget callTarget) {
        this.cachedCallTarget = callTarget;
        this.callCachedTargetNode = Truffle.getRuntime().createDirectCallNode(
                this.cachedCallTarget);
        this.nextNode = next;
    }

    @Override
    protected Object executeDispatch(VirtualFrame frame, CallTarget callTarget,
            Object[] arguments) {
        if (this.cachedCallTarget == callTarget) {
            return this.callCachedTargetNode.call(frame, arguments);
        }
        return this.nextNode.executeDispatch(frame, callTarget, arguments);
    }
}

Pretty standard stuff for Truffle. We keep a reference to the CallTarget used to create the node, the DirectCallNode which I use to make the actual function call, and a reference to the next DispatchNode in the chain in case this isn’t the CallTarget we were looking for. The executeDispatch method couldn’t be simpler. We check if the CallTarget of the function is the same as the one used to create this node. If so, we call it. If not, we move on.

What if we get to the end of our cache and need to handle any CallTarget sent to us?

public class GenericDispatchNode extends DispatchNode {
    @Child private IndirectCallNode callNode = Truffle.getRuntime()
            .createIndirectCallNode();

    @Override
    protected Object executeDispatch(VirtualFrame virtualFrame,
            CallTarget callTarget, Object[] argumentValues) {
        return this.callNode.call(virtualFrame, callTarget, argumentValues);
    }
}

Just when you thought things couldn’t get simpler. We can reuse the same IndirectCallNode for all calls so we just pass in all the values to the IndirectCallNode.
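The caching policy itself (remember a few targets, then give up and go generic) can be sketched in plain Java. This is only a model of the bookkeeping: it uses a list instead of a chain of nodes, and in plain Java there is no speed difference between the “direct” and “generic” paths. The class name, the cache limit of 2, and the use of LongUnaryOperator as a stand-in for CallTarget are all assumptions for illustration.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.LongUnaryOperator;

// Model of a bounded inline cache with a generic fallback (not Truffle code).
final class CachedCallSite {
    static final int INLINE_CACHE_SIZE = 2; // illustrative limit

    private final List<LongUnaryOperator> cachedTargets = new ArrayList<>();
    private boolean generic = false;

    long call(LongUnaryOperator target, long argument) {
        if (!generic) {
            for (LongUnaryOperator cached : cachedTargets) {
                if (cached == target) {
                    // Cache hit: the role played by DirectDispatchNode.
                    return cached.applyAsLong(argument);
                }
            }
            if (cachedTargets.size() < INLINE_CACHE_SIZE) {
                cachedTargets.add(target); // still room: remember this target
            } else {
                // Cache is full: drop it and stay generic from now on.
                generic = true;
                cachedTargets.clear();
            }
        }
        // First sighting of a target, or the generic path.
        return target.applyAsLong(argument);
    }

    int cachedCount() {
        return cachedTargets.size();
    }

    boolean isGeneric() {
        return generic;
    }

    public static void main(String[] args) {
        LongUnaryOperator inc = x -> x + 1;
        LongUnaryOperator dbl = x -> x * 2;
        LongUnaryOperator neg = x -> -x;
        CachedCallSite site = new CachedCallSite();
        site.call(inc, 1);
        site.call(dbl, 1);
        System.out.println(site.isGeneric()); // false: two targets fit in the cache
        site.call(neg, 1);
        System.out.println(site.isGeneric()); // true: the third target overflowed it
    }
}
```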

In case the description above was confusing, here’s a flow chart of how one InvokeNode handles tail calls.

Flow diagram for a function call

Setting Nodes as Tail

I have all the plumbing in place, but I still haven’t said how to set nodes as tails. At first glance it’s pretty simple: take the last node in a lambda body and set it as tail. There’s only one little wrinkle with that strategy: control flow nodes. Namely, the if node itself isn’t the tail; its then/else nodes are. In other lisps like Scheme the list of control structures can be quite long, but Mumbler only has to worry about if.

So first we add the method isTail to MumblerNode.

public abstract class MumblerNode extends Node {
    @CompilationFinal
    private boolean isTail = false;

    public boolean isTail() {
        return this.isTail;
    }

    public void setIsTail() {
        this.isTail = true;
    }

    // rest of class...
}

Since any node can be in the tail position of a function we need the predicate for all nodes. We’ll set the default to false since most nodes won’t be the last. We need to update the IfNode to propagate its “tailness” to its then/else nodes.

public class IfNode extends MumblerNode {
    @Override
    public void setIsTail() {
        super.setIsTail();
        this.thenNode.setIsTail();
        this.elseNode.setIsTail();
    }

    // rest of class...
}

That was simple. Now all that’s left is to set the last node in a lambda as a tail node. We modify Reader. While we create our LambdaNode we have to call setIsTail before we return the object.

// creating LambdaNode...
List<MumblerNode> bodyNodes = new ArrayList<>();
for (Convertible bodyConv : this.list.subList(2, this.list.size())) {
    bodyNodes.add(bodyConv.convert());
}

bodyNodes.get(bodyNodes.size() - 1).setIsTail();

// finish creating LambdaNode...

We get the last element in the list of nodes of lambda and call setIsTail. With that, Mumbler will start the tail call flow whenever it encounters a function call in a tail position, even one embedded inside an if.
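Putting the three pieces together, here is a stripped-down model of the marking step that runs without Truffle. Node and If are minimal stand-ins for MumblerNode and IfNode that keep only the tail flag, and markLambdaBody is a hypothetical helper doing what the Reader snippet does. It shows that the test expression of an if stays unmarked while both branches are marked.

```java
import java.util.Arrays;
import java.util.List;

// Minimal model of tail marking (stand-ins for MumblerNode and IfNode).
final class TailMarking {
    static class Node {
        final String label;
        private boolean tail = false;

        Node(String label) {
            this.label = label;
        }

        boolean isTail() {
            return tail;
        }

        void setIsTail() {
            tail = true;
        }
    }

    static final class If extends Node {
        final Node test;
        final Node thenNode;
        final Node elseNode;

        If(Node test, Node thenNode, Node elseNode) {
            super("if");
            this.test = test;
            this.thenNode = thenNode;
            this.elseNode = elseNode;
        }

        @Override
        void setIsTail() {
            super.setIsTail();
            // Whichever branch runs is the last thing the function does.
            thenNode.setIsTail();
            elseNode.setIsTail();
        }
    }

    /** Marks the last expression of a lambda body as being in tail position. */
    static void markLambdaBody(List<Node> body) {
        body.get(body.size() - 1).setIsTail();
    }

    public static void main(String[] args) {
        Node test = new Node("(< n 1)");
        Node thenNode = new Node("0");
        Node elseNode = new Node("(countdown (- n 1))");
        markLambdaBody(Arrays.asList(new If(test, thenNode, elseNode)));
        System.out.println(test.isTail());     // false
        System.out.println(elseNode.isTail()); // true
    }
}
```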

Conclusion

Truffle is really starting to show its capabilities. Tail call optimization may have been a little more complicated than I had originally planned, but the dispatch cache is something that any non-trivial language would have to implement. I’m more surprised Mumbler was able to work without it. Adding arbitrary precision numbers was a cakewalk. I didn’t expect the fallback from long to BigInteger would require so little code. I’m wondering how much effort it would take to add a full number stack with floats. Ooh, or even rationals!

At this point, you can almost use Mumbler to write Real Code®. It probably needs strings and some more builtin functions. It could probably use the ability to create and call Java objects… stay tuned.

Update: Some Benchmark Numbers

I neglected to show how tail call optimization has affected the speed of Mumbler. First, though, keep in mind the goal wasn’t to make Mumbler faster, but to allow code like this.

(define countdown
  (lambda (n)
    (if (< n 1)
        0
        (countdown (- n 1)))))

This is obviously a contrived example since countdown doesn’t do anything except heat up your computer and return 0, but without tail call optimization this function would eventually throw an exception. That thankfully won’t happen anymore, and with arbitrary precision numbers you can heat up your CPU all you want!

Having said that, how did TCO affect execution speed? Well, I created a benchmark that does basically the same as the example above except it breaks up the execution into smaller recursive function calls so it won’t throw an exception on the non-TCO version of Mumbler. That way we can directly compare TCO-Mumbler with non-TCO-Mumbler.

Here’s what I get with the non-optimized version of Mumbler (median of 5 runs, times in milliseconds).

mumbler (no TCO)
======================
('computation-time 410)

Not so great. How about after I add all that TCO goodness?

mumbler (TCO)
==================
('computation-time 7)

That’s a helluva drop. I wouldn’t expect such a drop since the same amount of work is being done, and Truffle was already optimizing our function calls in the non-TCO version. I think this is mainly due to a bug that was fixed while implementing TCO that allows further optimizations.

If we run the non-TCO version but with the bug fix what do we get?

mumbler (no TCO, bug fixed)
==================
('computation-time 229)

Well it certainly helped, but it didn’t get near 7ms. It looks like Graal can make some excellent optimizations for control flow exceptions that remove a lot of the overhead of function calls and exception throwing. Bravo.