Dec 27, 2022 - 8 minute read - programming

Choosing All Possibilities

There’s a chunk of code I’ve been playing with for the last two and a half years or so (judging by its git history) that I think is just really cool. I’ve been aware of the idea since reading SICP, and I wrote a version of it probably a decade ago, but that was for a very specific application rather than the simplified, generalized form I have here. I was reminded of it while reading about Prolog: solving these kinds of problems in a normal programming language would be simpler than in Prolog, if only there were something that handled the search. After thinking about it more, I realized it could also be used for formal verification, since what TLA+ does is search the state space of a system specified in TLA+ or PlusCal (because raw TLA+ is pretty terrible, though PlusCal isn’t great either). While I’ve not actually done formal verification with it yet, I do have code that uses it to solve mazes, sudoku and KenKen puzzles, find valid magic squares, generate power sets, and count in binary.

What is it? The amb operator. It turns out it admits a very simple implementation, which I’ll explain.

It answers the question: “how can I choose all possible combinations of a set of choices in an easy way?” Not that I’m above writing a normal recursive, depth-first search, but it turns out the choice-generation side can be made simple enough that you only need to worry about what to do with the choices, rather than figuring out how to compute them and handling the backtracking.

As a simple example, we’ll count in binary:

def binary_counter(c):
    print("%s %s %s" % (c.choose(2), c.choose(2), c.choose(2)))

if __name__ == "__main__":
    run(binary_counter)

So here, we’re just saying: choose a number from 0 to 1. You can think of the argument to choose as being the argument to range. In any event, this code, when run, produces:

0 0 0
0 0 1
0 1 0
0 1 1
1 0 0
1 0 1
1 1 0
1 1 1
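For comparison, when the number and size of the choices are fixed, as they are here, the standard library’s itertools.product enumerates the same eight lines in the same order. This is only a point of reference, not how the chooser works; the chooser earns its keep when later choices (how many there are, and how wide each one is) depend on earlier ones, which a fixed product can’t express.

```python
from itertools import product

# Each range(2) stands in for one c.choose(2) call in binary_counter.
lines = ["%s %s %s" % bits for bits in product(range(2), repeat=3)]
for line in lines:
    print(line)
```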

Or, for a more complex example, here’s how to solve a magic square:

def solve_magic_square(c):
    left = [1, 2, 3, 4, 5, 6, 7, 8, 9]
    square = []

    # pick the first row and check for sum == 15
    square.append(c.pick(left))  # 0
    square.append(c.pick(left))  # 1
    square.append(c.pick(left))  # 2
    if square[0] + square[1] + square[2] != 15:
        return

    # pick the next row and check for sum == 15
    square.append(c.pick(left))  # 3
    square.append(c.pick(left))  # 4
    square.append(c.pick(left))  # 5
    if square[3] + square[4] + square[5] != 15:
        return

    # pick the first of the third row
    square.append(c.pick(left))  # 6

    if (
        # check first column
        square[0] + square[3] + square[6] != 15
        # check diagonal top right to bottom left
        or square[2] + square[4] + square[6] != 15
    ):
        return

    # pick the second of the third row
    square.append(c.pick(left))  # 7
    # check the middle column
    if square[1] + square[4] + square[7] != 15:
        return

    # pick the third of the third row
    square.append(c.pick(left))  # 8
    # check across
    if square[6] + square[7] + square[8] != 15:
        return
    # check column
    elif square[2] + square[5] + square[8] != 15:
        return
    # check diagonal
    elif square[0] + square[4] + square[8] != 15:
        return

    # every row, column and diagonal sums to 15
    print(square)
    # c.stop()  # to stop at the first solution

Specifically, compared with Prolog, it’s much easier to say “stop searching from here, because we cannot validly proceed” (with a plain return) than it sometimes is in Prolog. It also uses a more conventional language, rather than one with Prolog’s oddities.

Here, we start with the list of numbers to consume, and then, one by one, we pick items from that list using a method called pick (defined below), which chooses an item from the passed-in list and removes it.

def pick(self, l):
    c = self.choose(len(l))
    ret = l[c]
    del l[c]
    return ret
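To see pick on its own, here is a self-contained sketch that enumerates every ordering of three items. It bundles a minimal Chooser and run that follow the approach this post walks through (the exact class layout is my assumption, not copied from the repo), and all_orderings is a hypothetical example function. Note that the list is created inside the function: pick mutates it, so every execution needs a fresh copy to replay its choices against.

```python
class Chooser:
    def __init__(self, prechosen):
        self.prechosen = prechosen  # choices to replay, in order
        self.index = 0
        self.newchoices = []        # fresh choices made in this execution
        self.newexecutions = []     # alternatives to try later

    def choose(self, n):
        # Replay the prechosen prefix first.
        if self.index < len(self.prechosen):
            ret = self.prechosen[self.index]
            self.index += 1
            return ret
        # Take 0 now; queue 1..n-1 as future executions with the same prefix.
        for i in range(1, n):
            self.newexecutions.append(self.prechosen + self.newchoices + [i])
        self.newchoices.append(0)
        return 0

    def pick(self, l):
        # Choose an index, then remove and return that element.
        c = self.choose(len(l))
        ret = l[c]
        del l[c]
        return ret


def run(f):
    executions = [[]]  # the first execution has nothing prechosen
    while executions:
        c = Chooser(executions.pop())
        f(c)
        executions.extend(c.newexecutions)


results = []


def all_orderings(c):
    items = [1, 2, 3]  # fresh list per execution, since pick mutates it
    results.append((c.pick(items), c.pick(items), c.pick(items)))


run(all_orderings)
print(results)
```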

But returning to solve_magic_square: it picks three numbers and checks that they add up to 15, picks another three and checks those, and then, over the course of picking the last three numbers, checks the columns and diagonals. If any line fails to add up to 15, we just return. If we make it all the way to the bottom, we’ve succeeded, and the result is printed. There are eight solutions, but the chooser also has a stop method that ends the whole search as soon as the function calls it.
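The post doesn’t show how stop works, so here is one possible sketch, under my own assumptions: stop raises a private exception (the hypothetical _StopSearch), which unwinds out of the user’s function immediately, and run catches it and abandons every queued execution. The search function first_pair_summing_to_5 is likewise just an illustration.

```python
class _StopSearch(Exception):
    """Hypothetical internal signal: unwinds out of f and out of run."""


class Chooser:
    def __init__(self, prechosen):
        self.prechosen = prechosen
        self.index = 0
        self.newchoices = []
        self.newexecutions = []

    def choose(self, n):
        # Replay the prechosen prefix first.
        if self.index < len(self.prechosen):
            ret = self.prechosen[self.index]
            self.index += 1
            return ret
        # Take 0 now; queue 1..n-1 as future executions with the same prefix.
        for i in range(1, n):
            self.newexecutions.append(self.prechosen + self.newchoices + [i])
        self.newchoices.append(0)
        return 0

    def stop(self):
        # Assumed behavior: end the whole search right away.
        raise _StopSearch()


def run(f):
    executions = [[]]
    while executions:
        c = Chooser(executions.pop())
        try:
            f(c)
        except _StopSearch:
            return  # drop every execution still waiting in the queue
        executions.extend(c.newexecutions)


solutions = []


def first_pair_summing_to_5(c):
    x = c.choose(4)
    y = c.choose(4)
    if x + y != 5:
        return
    solutions.append((x, y))
    c.stop()


run(first_pair_summing_to_5)
print(solutions)  # one pair, even though both (2, 3) and (3, 2) qualify
```

An exception fits “as soon as the function calls stop” better than a flag checked after the function returns, since nothing after the stop call runs.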

How would this work? Let’s start with the binary counter again. When run starts the first execution of binary_counter, it doesn’t know anything yet, so it can safely return zero for every call to c.choose; but something else is needed to reach the other answers. So for the first call to c.choose we’ll return zero, but there needs to be a later execution where it returns 1 instead, so let’s record that in a list. For the second call to c.choose we’ll still return zero, but there needs to be an execution where it returns zero for the first call and one for the second; that is, the first choice is the same, but the second one differs. Then for the third call, it will return zero, but we’ll need an execution where it returns one (while still returning zeros for the first two calls).

This implies two things:

  1. we need a way to prechoose things for parts of other “executions”
  2. we need a place to store these new pre-choices.

Once we’ve exhausted these “executions”, we should be done.
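Those two requirements can be made concrete by tracing them. The sketch below is self-contained: a minimal Chooser following this post’s description, plus a hypothetical run_traced that records, for each execution, the prefix it replayed and the new executions it queued. On a three-bit counter, the first execution (nothing prechosen) queues [1], [0, 1] and [0, 0, 1], which are exactly the three “points where it returns 1 instead” described above, and the search ends once the queue is empty.

```python
class Chooser:
    def __init__(self, prechosen):
        self.prechosen = prechosen
        self.index = 0
        self.newchoices = []
        self.newexecutions = []

    def choose(self, n):
        # Replay the prechosen prefix first.
        if self.index < len(self.prechosen):
            ret = self.prechosen[self.index]
            self.index += 1
            return ret
        # Take 0 now; queue 1..n-1 as future executions with the same prefix.
        for i in range(1, n):
            self.newexecutions.append(self.prechosen + self.newchoices + [i])
        self.newchoices.append(0)
        return 0


def run_traced(f):
    """Like run, but also returns (prechosen, queued) for each execution, in order."""
    trace = []
    executions = [[]]  # the first execution has nothing prechosen
    while executions:
        prechosen = executions.pop()
        c = Chooser(prechosen)
        f(c)
        trace.append((prechosen, c.newexecutions))
        executions.extend(c.newexecutions)
    return trace


def three_bits(c):
    return (c.choose(2), c.choose(2), c.choose(2))


trace = run_traced(three_bits)
for prechosen, queued in trace:
    print(prechosen, "->", queued)
```

Because executions is used as a stack and the deepest alternative is queued last, the replayed prefixes come out in the same 000 to 111 order as the binary counter’s output.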

The top level loop would look like this:

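def run(f):
    executions = [[]]  # the first execution has nothing pre-chosen
    while executions:
        prechosen = executions.pop()
        c = Chooser(prechosen)
        f(c)
        # queue up the alternatives this execution discovered but didn't take
        executions.extend(c.newexecutions)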

The Chooser

class Chooser:
    def __init__(self, prechosen):
        self.prechosen = prechosen
        self.index = 0
        self.newchoices = []
        self.newexecutions = []

And now the choose method admits a simple implementation:

    def choose(self, n):
        # if we're still making prechosen choices, return the next one
        if self.index < len(self.prechosen):
            ret = self.prechosen[self.index]
            self.index += 1
            return ret

        # We're going to return 0, but we need to set up pre-choices for the ones
        # we're not choosing right now, with the same choices up until now
        for i in range(1, n):
            self.newexecutions.append(self.prechosen + self.newchoices + [i])

        # record the 0 so later alternatives in this execution share this prefix
        self.newchoices.append(0)
        return 0
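Put together, run, Chooser, choose and pick make a complete module. The self-contained sketch below assembles them (choose records each fresh 0 in newchoices, so that later alternatives share the right prefix) and applies them to one of the uses mentioned at the top, generating a power set: each element gets its own c.choose(2), where 0 means “leave it out.” The power_set_of helper is hypothetical, not from the repo.

```python
class Chooser:
    def __init__(self, prechosen):
        self.prechosen = prechosen
        self.index = 0
        self.newchoices = []
        self.newexecutions = []

    def choose(self, n):
        # Replay the prechosen prefix first.
        if self.index < len(self.prechosen):
            ret = self.prechosen[self.index]
            self.index += 1
            return ret
        # Take 0 now; queue 1..n-1 as future executions with the same prefix.
        for i in range(1, n):
            self.newexecutions.append(self.prechosen + self.newchoices + [i])
        self.newchoices.append(0)
        return 0

    def pick(self, l):
        # Choose an index, then remove and return that element.
        c = self.choose(len(l))
        ret = l[c]
        del l[c]
        return ret


def run(f):
    executions = [[]]  # the first execution has nothing prechosen
    while executions:
        c = Chooser(executions.pop())
        f(c)
        executions.extend(c.newexecutions)


def power_set_of(items):
    """Hypothetical helper: collect every subset of items via the chooser."""
    subsets = []

    def one_subset(c):
        # One c.choose(2) per element, in order: 0 = exclude, 1 = include.
        subsets.append([x for x in items if c.choose(2)])

    run(one_subset)
    return subsets


subsets = power_set_of(["a", "b", "c"])
for s in subsets:
    print(s)
```

The subsets come out in the same order as the binary counter’s lines, since each inclusion decision is just one bit.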

Even though I’ve been fiddling with this algorithm for a while, I still marvel at what it can do. I’ve got variants in six languages: Common Lisp, Go, Java, Python, TypeScript and Rust, which are available here: https://github.com/drewcsillag/chooser. There are concurrent versions in Java and Rust that do at least benchmark faster than the single-threaded versions. Because the implementations are small, they’re meant to be copy-pasted rather than consumed from the repo as a package. The repo also has sudoku solvers, KenKen solvers, and a city travel example that implements something out of a Prolog book I have.

At some point, I’d like to write a multi-machine parallel version and, possibly as part of the same effort, implement something I have a TLA+ spec for (I have one in mind I might take a swing at), ideally one version with a bug and one with the bug fixed, so I can verify that my answers match.

And one last application: I found a previous Advent of Code problem, Day 8 from 2021, and solved the meat of it with this:

def solve(s):
    samp, digs = s.split(' | ')
    all10 = [set(i) for i in samp.split(' ')]
    num = digs.split(' ')

    s2 = [i for i in all10 if len(i) == 2][0] # Identify 1
    s3 = [i for i in all10 if len(i) == 3][0] # Identify 7
    s4 = [i for i in all10 if len(i) == 4][0] # Identify 4
    a = list(s3 - s2)[0] # The a segment is the top part of 7, distinct from 1

    def set_choose(ch, s, picked):
        # Subtract things we've picked already from the things to choose from
        # so we don't wind up choosing the same segment twice. Also allows us
        # to bail out if we have no valid choices for a segment -- it'll blow
        # an IndexError. Not needed for correctness, but lowers the iteration
        # from 168 per to 48 per solve. Not that 168 isn't plenty fast, but
        # more because, why not?
        s = s - picked
        l = list(s)
        c = ch.choose(len(l))
        r = l[c]
        picked.add(r)  # remember this wire so later segments skip it
        return r

    ans = [""] # a place to stow the answer generated in inner
    def inner(ch):
        try:
            picked = {a}
            # We could start checking results against the input set as soon as we
            # have a full number so to bail sooner, and the ordering below
            # could be optimized to fit that, but meh.
            b = set_choose(ch, s4 - s2, picked)
            c = set_choose(ch, s2, picked)
            d = set_choose(ch, {'a', 'b', 'c', 'd', 'e', 'f', 'g'}, picked)
            e = set_choose(ch, {'a', 'b', 'c', 'd', 'e', 'f', 'g'}, picked)
            f = set_choose(ch, s2, picked)
            g = set_choose(ch, {'a', 'b', 'c', 'd', 'e', 'f', 'g'}, picked)
            # wire sets for digits 0 through 9 under this candidate mapping
            newdigs = [{a,b,c,e,f,g}, {c,f}, {a,c,d,e,g}, {a,c,d,f,g}, {b,c,d,f},
                       {a,b,d,f,g}, {a,b,d,e,f,g}, {a,c,f}, {a,b,c,d,e,f,g},
                       {a,b,c,d,f,g}]
            # every digit must appear among the ten samples, or this mapping is wrong
            for nd in newdigs:
                if nd not in all10:
                    return
            nset = [set(i) for i in num]
            coded = [newdigs.index(i) for i in nset]
            ans[0] = ''.join([str(i) for i in coded])
            print("ANS " + ans[0])
        except IndexError: # set_choose input set only contains things we've picked already
            return

    run(inner)
    return ans[0]

Hat tip to this page; the above code is a port of their solution.