There’s a chunk of code I’ve been playing with for the last two and a half years or so (judging by its git history) that I think is just really cool. I’ve been aware of the idea since reading SICP, and I wrote a version of it probably a decade ago, but that was for a very specific application rather than the simpler, more general form I have here. I was reading about Prolog when I was reminded of it, and it struck me that solving things in a normal programming language would be simpler than in Prolog, if only there were something to handle the search. Thinking about it more, I realized it could also be used for formal verification, since what TLA+ does is search the state space of a system specified in TLA+ or PlusCal (because raw TLA+ is pretty terrible, though PlusCal isn’t great either). While I’ve not actually done formal verification with it yet, I do have code that uses it to solve mazes, sudoku and KenKen puzzles, find valid magic squares, generate power sets, and count in binary.
What is it? The `amb` operator (short for “ambiguous”). It turns out that it admits a very simple implementation, which I’ll explain.
It answers the question: “how can I enumerate all possible combinations of a set of choices in an easy way?” Not that I’m above writing a normal recursive, depth-first search, but it turns out that the choice-generation side of things can be made simple enough that you really only need to worry about using the choices, rather than figuring out how to compute them and handle any backtracking.
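For reference, this is the kind of normal recursive depth-first search I mean, written by hand for the three-bit binary counter in the next example. It’s a minimal sketch (`count_binary` is just an illustrative name): the recursion has to carry the choices made so far, and returning from each call is what does the backtracking.

```python
def count_binary(bits, prefix=()):
    """Return every `bits`-long tuple of 0s and 1s, in counting order."""
    if len(prefix) == bits:
        # every position has been chosen: this is one complete combination
        return [prefix]
    results = []
    for digit in range(2):
        # try 0 first, then 1; returning from the recursive call backtracks to here
        results.extend(count_binary(bits, prefix + (digit,)))
    return results


for combo in count_binary(3):
    print(*combo)
```

The search bookkeeping and the problem itself are tangled together here; the point of the chooser is to pull them apart.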
As a simple example, we’ll count in binary:
```python
def binary_counter(c):
    print("%s %s %s" % (c.choose(2), c.choose(2), c.choose(2)))


if __name__ == "__main__":
    run(binary_counter)
```
So here, we’re just saying: choose a number from 0 to 1. You can think of the argument to `choose` as being the argument to `range`: `choose(n)` returns some value from 0 to n-1.
In any event, when run, this code produces:
```
0 0 0
0 0 1
0 1 0
0 1 1
1 0 0
1 0 1
1 1 0
1 1 1
```
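Power sets, one of the applications mentioned at the top, come from the same include-or-exclude choice: one `choose(2)` per item. This is a sketch that bundles a condensed copy of the `Chooser` and `run` developed below so it runs on its own; `power_set` and `subsets` are illustrative names.

```python
class Chooser:
    # condensed copy of the Chooser described later in this post
    def __init__(self, prechosen):
        self.prechosen = prechosen
        self.index = 0
        self.newchoices = []
        self.newexecutions = []

    def choose(self, n):
        if self.index < len(self.prechosen):
            ret = self.prechosen[self.index]
            self.index += 1
            return ret
        for i in range(1, n):
            self.newexecutions.append(self.prechosen + self.newchoices + [i])
        self.newchoices.append(0)
        return 0


def run(f):
    executions = [[]]
    while executions:
        c = Chooser(executions.pop())
        f(c)
        executions.extend(c.newexecutions)


subsets = []


def power_set(c, items=("a", "b", "c")):
    # choose(2) is 0 or 1: leave the item out, or put it in
    subsets.append([item for item in items if c.choose(2)])


run(power_set)
print(subsets)
# [[], ['c'], ['b'], ['b', 'c'], ['a'], ['a', 'c'], ['a', 'b'], ['a', 'b', 'c']]
```

The subsets come out in the same order as the binary counter’s output, with each bit deciding whether one item is included.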
Or, for a more complex example, here’s how to solve a 3×3 magic square:
```python
def solve_magic_square(c):
    left = [1, 2, 3, 4, 5, 6, 7, 8, 9]
    square = []
    # pick the first row and check for sum == 15
    square.append(c.pick(left))  # 0
    square.append(c.pick(left))  # 1
    square.append(c.pick(left))  # 2
    if square[0] + square[1] + square[2] != 15:
        return
    # pick the next row and check for sum == 15
    square.append(c.pick(left))  # 3
    square.append(c.pick(left))  # 4
    square.append(c.pick(left))  # 5
    if square[3] + square[4] + square[5] != 15:
        return
    # pick the first of the third row
    square.append(c.pick(left))  # 6
    if (
        # check first column
        square[0] + square[3] + square[6] != 15
        # check diagonal top right to bottom left
        or square[2] + square[4] + square[6] != 15
    ):
        return
    # pick the second of the third row
    square.append(c.pick(left))  # 7
    # check the middle column
    if square[1] + square[4] + square[7] != 15:
        return
    # pick the third of the third row
    square.append(c.pick(left))  # 8
    # check across
    if square[6] + square[7] + square[8] != 15:
        return
    # check column
    elif square[2] + square[5] + square[8] != 15:
        return
    # check diagonal
    elif square[0] + square[4] + square[8] != 15:
        return
    print(square[0:3])
    print(square[3:6])
    print(square[6:9])
    print()
    # c.stop()  # to stop at first solution
```
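For contrast, here’s a sketch of the same search without the early `return`s: generate all 9! = 362,880 orderings of 1 through 9 with `itertools.permutations` and test each one only once it’s complete. It’s just for comparison (`LINES` and `is_magic` are illustrative names), and it also finds eight squares.

```python
from itertools import permutations

# Every row, column, and diagonal of a 3x3 grid stored row-major in a flat tuple.
LINES = [(0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
         (0, 3, 6), (1, 4, 7), (2, 5, 8),  # columns
         (0, 4, 8), (2, 4, 6)]             # diagonals


def is_magic(sq):
    return all(sq[a] + sq[b] + sq[c] == 15 for a, b, c in LINES)


# No pruning: every one of the 9! orderings is built before any check happens.
solutions = [sq for sq in permutations(range(1, 10)) if is_magic(sq)]
print(len(solutions))  # 8
```

The chooser version abandons an arrangement as soon as, say, its first row fails, so it never builds most of those 362,880 orderings.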
As it relates to Prolog specifically, it’s much easier to say “stop searching from here, because we cannot validly proceed” with a plain `return` than it sometimes is in Prolog. It also uses a conventional language, rather than one with Prolog’s oddities.
Here, we start with the list of numbers to consume, and then pick items from that list one by one using a method called `pick`, defined below, which chooses an item from the passed-in list and removes it.
```python
def pick(self, items):
    # choose an index into items, then remove and return that item
    idx = self.choose(len(items))
    ret = items[idx]
    del items[idx]
    return ret
```
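Since `run` calls the function from scratch for every execution, `left` is rebuilt each time, so it’s safe for `pick` to mutate it. As a small sketch (again with a condensed `Chooser` and `run`; here `pick` uses `list.pop`, which removes and returns the item at an index just like the `del` version; `permutations_of` and `orderings` are illustrative names), picking every item of a three-element list enumerates all 3! = 6 orderings:

```python
class Chooser:
    # condensed copy of the Chooser described below, plus pick()
    def __init__(self, prechosen):
        self.prechosen = prechosen
        self.index = 0
        self.newchoices = []
        self.newexecutions = []

    def choose(self, n):
        if self.index < len(self.prechosen):
            ret = self.prechosen[self.index]
            self.index += 1
            return ret
        for i in range(1, n):
            self.newexecutions.append(self.prechosen + self.newchoices + [i])
        self.newchoices.append(0)
        return 0

    def pick(self, items):
        # same behavior as the del-based version: remove and return one item
        return items.pop(self.choose(len(items)))


def run(f):
    executions = [[]]
    while executions:
        c = Chooser(executions.pop())
        f(c)
        executions.extend(c.newexecutions)


orderings = []


def permutations_of(c):
    left = [1, 2, 3]  # rebuilt on every execution, so mutating it is fine
    orderings.append([c.pick(left) for _ in range(3)])


run(permutations_of)
print(orderings)
# [[1, 2, 3], [1, 3, 2], [3, 1, 2], [3, 2, 1], [2, 1, 3], [2, 3, 1]]
```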
But returning to `solve_magic_square`: it picks three numbers and checks whether they add up to 15, picks another three and checks those, and then, over the course of picking the last three numbers, checks the columns and diagonals. If a sum fails to come to 15 at any point, we just return. If we make it all the way to the bottom, we’ve succeeded, and the result is printed. There are eight answers, but the chooser also has a `stop` method that halts the search as soon as the function calls it.
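Here’s one possible way to build `stop`, as a hedged sketch rather than necessarily the version in the repo: `stop` sets a flag (the `stopped` attribute is an assumption), and `run` checks it after each execution and drops everything still queued. With this design the function should `return` right after calling `stop`, as `solve_magic_square` does. `subset_sum` and `found` are illustrative names.

```python
class Chooser:
    # condensed copy of the Chooser described below, plus a stop flag
    def __init__(self, prechosen):
        self.prechosen = prechosen
        self.index = 0
        self.newchoices = []
        self.newexecutions = []
        self.stopped = False  # assumption: set by stop(), read by run()

    def choose(self, n):
        if self.index < len(self.prechosen):
            ret = self.prechosen[self.index]
            self.index += 1
            return ret
        for i in range(1, n):
            self.newexecutions.append(self.prechosen + self.newchoices + [i])
        self.newchoices.append(0)
        return 0

    def stop(self):
        # assumption: ask run() not to start any further executions
        self.stopped = True


def run(f):
    executions = [[]]
    while executions:
        c = Chooser(executions.pop())
        f(c)
        if c.stopped:
            break  # abandon every execution still queued
        executions.extend(c.newexecutions)


found = []


def subset_sum(c, items=(1, 2, 3, 4), target=5):
    chosen = [x for x in items if c.choose(2)]
    if sum(chosen) != target:
        return  # not a solution; keep searching
    found.append(chosen)
    c.stop()  # the first solution is enough


run(subset_sum)
print(found)  # [[2, 3]]; without c.stop() it would also find [1, 4]
```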
How does this work? Let’s start with the binary counter again. Whatever `run` does, at the beginning of the first execution of `binary_counter` it doesn’t know anything yet. It can safely return zero for every call to `c.choose`, but something else is needed. So here’s what we can do: the first call to `c.choose` returns zero, but there needs to be a later execution where it returns one instead, so let’s put that into a list. The second call to `c.choose` also returns zero, but there needs to be an execution where the first call returns zero and the second returns one; that is, the first choice is the same, but the second one differs. The third call likewise returns zero, but we need an execution where it returns one (while the first two calls still return zeros).
This implies two things:
- we need a way to pre-choose values for the leading part of other “executions”
- we need a place to store these new pre-choices
Once we’ve exhausted these “executions”, we’re done.
The top-level loop looks like this:
```python
def run(f):
    executions = [[]]  # the first execution has nothing pre-chosen
    while executions:
        prechosen = executions.pop()
        c = Chooser(prechosen)
        f(c)
        executions.extend(c.newexecutions)
```
The `Chooser` class holds the state for one execution:
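```python
class Chooser:
    def __init__(self, prechosen):
        self.prechosen = prechosen   # choices to replay at the start of this execution
        self.index = 0               # how many pre-chosen values have been replayed
        self.newchoices = []         # choices made after the pre-chosen ones ran out
        self.newexecutions = []      # pre-choice lists for executions still to run
```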
And now the `choose` method admits a simple implementation:
```python
def choose(self, n):
    # if we're still making prechosen choices, return the next one
    if self.index < len(self.prechosen):
        ret = self.prechosen[self.index]
        self.index += 1
        return ret
    # We're going to return 0, but we need to set up pre-choices for the ones
    # we're not choosing right now but with the same choices up until now
    for i in range(1, n):
        self.newexecutions.append(self.prechosen + self.newchoices + [i])
    self.newchoices.append(0)
    return 0
```
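To watch this work, here’s a variant of `run` (the name `run_traced` and the returning version of `binary_counter` are just for illustration) that records the pre-chosen prefix each execution starts from. Only the first execution starts empty; every later one replays a prefix queued by an earlier `choose` call, ending in the nonzero value that call deferred.

```python
class Chooser:
    # same Chooser as above, gathered into one class so this runs on its own
    def __init__(self, prechosen):
        self.prechosen = prechosen
        self.index = 0
        self.newchoices = []
        self.newexecutions = []

    def choose(self, n):
        if self.index < len(self.prechosen):
            ret = self.prechosen[self.index]
            self.index += 1
            return ret
        for i in range(1, n):
            self.newexecutions.append(self.prechosen + self.newchoices + [i])
        self.newchoices.append(0)
        return 0


def run_traced(f):
    """Like run(), but returns (prechosen, result) for every execution."""
    trace = []
    executions = [[]]
    while executions:
        prechosen = executions.pop()  # used as a stack, so the search is depth-first
        c = Chooser(prechosen)
        trace.append((prechosen, f(c)))
        executions.extend(c.newexecutions)
    return trace


def binary_counter(c):
    # returns the choices rather than printing them, so the trace can show them
    return (c.choose(2), c.choose(2), c.choose(2))


for prechosen, bits in run_traced(binary_counter):
    print(prechosen, "->", bits)
```

The prefixes come out as `[]`, `[0, 0, 1]`, `[0, 1]`, `[0, 1, 1]`, `[1]`, `[1, 0, 1]`, `[1, 1]`, `[1, 1, 1]`. Because `executions` is a stack, the deepest pending alternative is always tried next, which is why the results arrive in counting order.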
Even though I’ve been fiddling with this algorithm for a while, I still marvel that it’s capable of doing what it does. I’ve got variants in six languages (Common Lisp, Go, Java, Python, TypeScript, and Rust), which are available here: https://github.com/drewcsillag/chooser. There are also concurrent versions in Java and Rust that, at least in benchmarks, are faster than the single-threaded versions. Because the implementations are small, they’re meant more to be copy-pasted than consumed from the repo or as a package. Additionally, there are sudoku solvers, KenKen solvers, and a city travel example that implements something out of a Prolog book I have.
At some point, I’d like to write a multi-machine parallel version and, possibly as part of the same effort, implement something I have a TLA+ spec for (I have one in mind I might take a swing at): ideally one version with a bug and one with the bug fixed, so I can verify that my answers match.
One last application: I found an old Advent of Code problem (Day 8 of 2021) and solved the meat of it with this:
```python
def solve(s):
    samp, digs = s.split(' | ')
    all10 = [set(i) for i in samp.split(' ')]
    num = digs.split(' ')
    print(s)
    s2 = [i for i in all10 if len(i) == 2][0]  # Identify 1
    s3 = [i for i in all10 if len(i) == 3][0]  # Identify 7
    s4 = [i for i in all10 if len(i) == 4][0]  # Identify 4
    a = list(s3 - s2)[0]  # The a segment is the top part of 7, distinct from 1

    def set_choose(ch, s, picked):
        # Subtract things we've picked already from the things to choose from
        # so we don't wind up choosing the same segment twice. This also lets
        # us bail out if we have no valid choices left for a segment, since
        # it'll raise an IndexError. Not needed for correctness, but it lowers
        # the iterations from 168 to 48 per solve. Not that 168 isn't plenty
        # fast, but more because, why not?
        s = s - picked
        l = list(s)
        l.sort()
        c = ch.choose(len(l))
        r = l[c]
        picked.add(r)
        return r

    ans = [""]  # a place to stow the answer generated in inner

    def inner(ch):
        try:
            picked = {a}
            # We could start checking results against the input set as soon as
            # we have a full number so to bail sooner, and the ordering below
            # could be optimized to fit that, but meh.
            b = set_choose(ch, s4 - s2, picked)
            c = set_choose(ch, s2, picked)
            d = set_choose(ch, {'a', 'b', 'c', 'd', 'e', 'f', 'g'}, picked)
            e = set_choose(ch, {'a', 'b', 'c', 'd', 'e', 'f', 'g'}, picked)
            f = set_choose(ch, s2, picked)
            g = set_choose(ch, {'a', 'b', 'c', 'd', 'e', 'f', 'g'}, picked)
            # segment sets for the digits 0 through 9, in order
            newdigs = [{a, b, c, e, f, g}, {c, f}, {a, c, d, e, g}, {a, c, d, f, g},
                       {b, c, d, f}, {a, b, d, f, g}, {a, b, d, e, f, g}, {a, c, f},
                       {a, b, c, d, e, f, g}, {a, b, c, d, f, g}]
            # this wiring is only valid if it reproduces all ten observed patterns
            for nd in newdigs:
                if nd not in all10:
                    return
            nset = [set(i) for i in num]
            coded = [newdigs.index(i) for i in nset]
            ans[0] = ''.join([str(i) for i in coded])
            print("ANS " + ans[0])
        except IndexError:  # set_choose input set only contains things we've picked already
            return

    run(inner)
    return ans[0]
```
Hat tip to this page; the code above is a port of theirs.
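The pruning point in the `set_choose` comment is easy to see on a toy problem. In this sketch (`naive`, `pruned`, `count_executions`, and the three-letter alphabet are all made up for illustration, with a condensed `Chooser`), we fill three slots with distinct letters from `x`, `y`, `z`: once by choosing from the full alphabet and rejecting duplicates at the end, and once by subtracting what’s already been picked, the way `set_choose` does.

```python
class Chooser:
    # condensed copy of the Chooser from this post
    def __init__(self, prechosen):
        self.prechosen = prechosen
        self.index = 0
        self.newchoices = []
        self.newexecutions = []

    def choose(self, n):
        if self.index < len(self.prechosen):
            ret = self.prechosen[self.index]
            self.index += 1
            return ret
        for i in range(1, n):
            self.newexecutions.append(self.prechosen + self.newchoices + [i])
        self.newchoices.append(0)
        return 0


def count_executions(f):
    """Like run(), but returns how many executions the search took."""
    count = 0
    executions = [[]]
    while executions:
        c = Chooser(executions.pop())
        f(c)
        count += 1
        executions.extend(c.newexecutions)
    return count


LETTERS = {"x", "y", "z"}
solutions = {"naive": [], "pruned": []}


def naive(c):
    # every slot chooses from the full alphabet; duplicates are rejected at the end
    options = sorted(LETTERS)
    picks = [options[c.choose(len(options))] for _ in range(3)]
    if len(set(picks)) == 3:
        solutions["naive"].append(picks)


def pruned(c):
    picked, picks = set(), []
    for _ in range(3):
        options = sorted(LETTERS - picked)  # like set_choose: skip letters already used
        choice = options[c.choose(len(options))]
        picked.add(choice)
        picks.append(choice)
    solutions["pruned"].append(picks)


naive_runs = count_executions(naive)
pruned_runs = count_executions(pruned)
print(naive_runs, pruned_runs)  # 27 6
```

Both find the same six assignments; the pruned version just never starts the 21 executions that were doomed to repeat a letter. If a slot’s options ever ran out, the list would be empty, `choose(0)` would return 0, and indexing it would raise the `IndexError` that `inner` catches.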