Play With Lua!

Iterators


Lua doesn’t have a lot of control structures. There’s the obvious if statement, the while and repeat/until loops, and the for loop, which comes in a numeric form and a generic form. Mostly, the generic for gets used to iterate over tables:

t = {'a','b','c','d','e'}
 
for i, v in ipairs(t) do
  print(i, v)
end

It’s annoying to have to remember to type ipairs every time. I’ve forgotten more than once. But, that minor annoyance is a good trade for the benefit of what the for statement actually does: generic iteration.

What’s an iterator?

The for statement isn’t just for looping across tables; it loops through any sort of sequence represented by an iterator. An iterator in Lua is any function that conforms to a certain interface, and a lot of problems get simpler when you write your own iterators instead of just using pairs and ipairs.

Let’s look at a standard loop:

t = {a=1, b=2, c=3, d=4}
for k, v in pairs(t) do
  print(k, v)
end

This is using the standard library iterator pairs. Let’s take a look at what that call to pairs actually returns:

pairs(t)
-- function: 0x418590    table: 0x141b600    nil

Not to keep you in suspense, the table it returns is t itself, and the function is the builtin function next. What next does is take a table and a key in the table, and return the “next” key / value. It’s guaranteed to go over all the keys, in no particular order:

next(t, nil)
-- c    3
next(t, 'c')
-- b    2
next(t, 'b')
-- a    1
next(t, 'a')
-- d    4
next(t, 'd')
-- nil

So let’s write that original for loop in a more explicit way:

for k, v in next, t, nil do
  print(k, v)
end

Try it, it does the same thing. In fact, if you just want to go over part of the table, you can pass in an initial argument for next:

for k, v in next, t, 'b' do
  print(k, v)
end

That will print out just the ‘a’ and ‘d’ keys (yours may vary; next iterates in an arbitrary order). That’s all the for statement does: it takes a function, an “invariant” first argument (the table), and an initial second argument. It calls the function repeatedly, passing the first return value of the previous call as the second argument of the next call, and stops as soon as that first return value is nil.
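To make that concrete, here’s a rough sketch of the generic for written out as an explicit while loop. The names f, s, control and seen are only illustrative; real Lua keeps these values internally rather than in visible variables.

```lua
local t = {a = 1, b = 2, c = 3, d = 4}

-- Roughly what `for k, v in pairs(t) do ... end` does:
local f, s, control = pairs(t)    -- iterator function, invariant state, initial control value
local seen = {}

while true do
  local k, v = f(s, control)      -- call the iterator with the state and the previous key
  if k == nil then break end      -- a nil first value ends the loop
  control = k                     -- the first return value feeds the next call
  seen[k] = v                     -- the loop body goes here
end
```

After the loop, seen holds the same four key / value pairs as t.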

Making an iterator

With that in mind, we can write our own iterators. Here’s one that loops over members of the triangular series:

function triangular(_, n)
    if not n then n = 0 end
    n = n + 1
    return n, n * (n+1) / 2
end
 
for n, v in triangular do
    print(v)
    if n == 10 then break end
end

Since this iterator will never naturally end, we insert a break statement after a while.
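If you’d rather have the iterator end on its own, one option is to put a limit in the invariant argument and return nil once you pass it. Here’s a minimal sketch; step and triangular_upto are names made up for this illustration.

```lua
-- Stateless iterator: `limit` is the invariant state, `n` is the control value.
local function step(limit, n)
    n = n + 1
    if n > limit then return nil end   -- returning nil ends the for loop
    return n, n * (n + 1) / 2
end

-- Factory returning the three values the for statement expects.
local function triangular_upto(limit)
    return step, limit, 0
end

local values = {}
for n, v in triangular_upto(5) do
    values[n] = v
end
-- values now holds the first five triangular numbers: 1, 3, 6, 10, 15
```

No break is needed here, because the loop stops as soon as step returns nil.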

Note how easy calling the iterator is compared to writing it, not that writing it is very hard. This is pretty common with iterators; you spend some effort writing one in order to make the rest of the code simpler.

A more complex example: depth-first traversal

So let’s write some iterators that do actually-useful things, like traversing a tree. What we’d like to be able to do is visit every node of a tree, and get the value of that node and the path down to it from the root. For example:

local inspect = require('inspect')  -- inspect.lua, a third-party library

tree = {a = {p = {p = {l = {e = {}}}},
             n = {t = {}}}}
 
for n, path in dft(tree) do
  print(n, inspect(path))
end

(I’m using inspect.lua to print the path, a really handy library).

This iterator is a little different from the other one: it has complex state, rather than just a single number, so we’ll have dft just return a function that keeps the state in a closure. You can do that; for doesn’t care if your invariant state is nil, it’ll still pass it in every time but you can just ignore it. So, here’s the code:

function dft(tree)
    local value_stack = {}
    local node_stack = {}
 
    return function()
        -- These represent the current node:
        local value, node = value_stack[#value_stack], node_stack[#node_stack]
 
        -- Now, to find the next node:
        if not next(node_stack) then
            -- Node stack empty, push the root node:
            table.insert(value_stack, (next(tree)))
            table.insert(node_stack, tree)
 
        elseif type(node[value]) == 'table' and next(node[value]) then
            -- Otherwise, if the current node has children, push them on to the stack:
            table.insert(value_stack, (next(node[value])))
            table.insert(node_stack, node[value])
 
        elseif next(node, value) then
            -- Otherwise, if there's a right sibling, alter the stack to show it:
            value_stack[#value_stack] = next(node, value)
 
        else
            -- Otherwise, pop the stack and find the next node of our parent
            while true do
                table.remove(node_stack)
                table.remove(value_stack)
 
                -- Must be the end of the tree:
                if not next(node_stack) then return nil end
 
                local value, node = value_stack[#value_stack], node_stack[#node_stack]
                if next(node, value) then
                    value_stack[#value_stack] = next(node, value)
                    break
                end
            end
        end
 
        -- Return the top of the value stack, and the current value stack
        return value_stack[#value_stack], value_stack
    end
end

This is pretty straightforward. We keep a stack of the node labels / values we’ve visited, going back up to the root. We start with empty stacks, and find the next node in the traversal:

  • If the stacks are empty, the next node is the root.
  • If the current node (top of the stacks) has children, the next node is the first child.
  • If the current node has no children but has a next sibling, the sibling is next.
  • Finally, if none of those are true, we pop back up to the parent and look for a next sibling there, repeating until we find one or the stacks are empty.

After all that, the value stack has a path to the current node, so we return its value, and the value stack itself. When I run it as above, I get this:

a	{ "a" }
n	{ "a", "n" }
t	{ "a", "n", "t" }
p	{ "a", "p" }
p	{ "a", "p", "p" }
l	{ "a", "p", "p", "l" }
e	{ "a", "p", "p", "l", "e" }
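The closure trick isn’t specific to trees. As a smaller illustration of the same pattern, here’s a sketch of a hypothetical chars iterator (not part of the original code) whose only hidden state is a counter:

```lua
-- Returns an iterator over the characters of s; the position lives in the closure.
local function chars(s)
    local i = 0
    return function()          -- ignores the state and control values that for passes in
        i = i + 1
        if i > #s then return nil end
        return i, s:sub(i, i)
    end
end

local out = {}
for i, c in chars('lua') do
    out[i] = c
end
-- out is now {'l', 'u', 'a'}
```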

Coroutine iterators

But, doing it iteratively like that is somewhat of a pain. It’s more natural to traverse a tree recursively. But how do we recurse in an iterator?

First, let’s write this as a coroutine. Forget about iterators for now, let’s just write a coroutine that will yield all the nodes / paths:

stack = {}
 
function traverse(node)
    for k, v in pairs(node) do
        table.insert(stack, k)
        coroutine.yield(k, stack)
        if type(v) == 'table' then
            traverse(v)
        end
        table.remove(stack)
    end
end
 
co = coroutine.create(traverse)

Then, we can call it like this:

repeat
    local _, node, path = coroutine.resume(co, tree)
    print(node, inspect(path))
until not node

(We pass in the same tree every time, but that’s okay: only the first resume hands its argument to traverse. On later resumes the argument just becomes the return value of coroutine.yield, which traverse ignores.)
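If the return values of coroutine.resume are unfamiliar, this small sketch shows them: the first is a success flag (the _ in the loop above), followed by whatever was yielded or returned. The coroutine body is a toy, made up just to show the mechanics.

```lua
local co = coroutine.create(function(first)
    -- The argument of the first resume arrives as a parameter...
    local second = coroutine.yield(first)
    -- ...and the argument of a later resume is what yield returns.
    return second
end)

local ok1, a = coroutine.resume(co, 'x')    -- true, 'x' (the yielded value)
local ok2, b = coroutine.resume(co, 'y')    -- true, 'y' (the returned value)
local status = coroutine.status(co)         -- 'dead': the function has finished
local ok3, err = coroutine.resume(co, 'z')  -- false, plus an error message
```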

So, let’s now turn this general pattern into an iterator:

function co_dft(tree)
    local stack = {}
 
    local function traverse(node)
        for k, v in pairs(node) do
            table.insert(stack, k)
            coroutine.yield(k, stack)
            if type(v) == 'table' then
                traverse(v)
            end
            table.remove(stack)
        end
    end
 
    local co = coroutine.create(function() traverse(tree) end)
 
    return function()
        -- resume returns a success flag first, then the values that were yielded
        local _, value, path = coroutine.resume(co)
        return value, path
    end
end

It’s a pretty simple transformation. Move everything into the local scope of the iterator, and return a function that’s a wrapper for coroutine.resume. We can do this to make an iterator out of any coroutine, actually. And now that it’s an iterator, we can call it just like the iterative version:

for node, path in co_dft(tree) do
    print(node, inspect(path))
end
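The standard library also has coroutine.wrap, which builds this kind of resume wrapper for you, and raises errors from the coroutine instead of returning them as values. Here’s a sketch of the same traversal using it; wrap_dft is a name made up for this example, and it joins each path with table.concat so it doesn’t need inspect.

```lua
local function wrap_dft(tree)
    local stack = {}

    local function traverse(node)
        for k, v in pairs(node) do
            stack[#stack + 1] = k
            coroutine.yield(k, stack)
            if type(v) == 'table' then
                traverse(v)
            end
            stack[#stack] = nil
        end
    end

    -- wrap returns a function that resumes the coroutine on each call;
    -- once traverse finishes it returns nothing, which ends the for loop.
    return coroutine.wrap(function() traverse(tree) end)
end

local tree = {a = {p = {p = {l = {e = {}}}},
              n = {t = {}}}}

local paths = {}
for node, path in wrap_dft(tree) do
    -- path is the iterator's live stack, so turn it into a string right away
    paths[#paths + 1] = table.concat(path, '/')
end
```

The paths come out in whatever order pairs visits the keys, so the list may be ordered differently from run to run.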

Iterators are powerful

So, that’s how iterators work. It’s more than just an awkward syntax for a for loop; it’s actually an incredibly powerful feature of Lua.

As always, the code for this is available on GitHub.

Written by randrews

May 16th, 2015 at 12:25 am
