A co-worker recently gave a talk on shuffle sharding, and a thought came to me: we know that for
n things taken k at a time there are nCk combinations, and we know what that evaluates to given n and k.
But how do we map a number between 1 and nCk to a unique set of selections? And vice versa:
given the set of selections, can we get back to the original number?
Enumeration, Part 1
It seemed evident at the time that the main problem was enumerating the combinations. However, at least in
my education and experience, I’d never had to do it before. My first, and wrong, approach was to think:
it’s easy enough to generate the binary numbers between 0 and 2^n and just filter out those that
have a popcount other than k. Fortunately, I quickly realized this was a dead end. It examines all 2^n candidates to keep only nCk of them, and it still gives no direct way to jump from a number to its combination.
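For the record, the dead end looks roughly like the sketch below (popcount_enumerate is a made-up name for illustration). It does find every combination. If you walk downward from 2^n - 1 and read the most significant bit as position 0, it even reproduces the row order of the 6C3 table further down. But it examines all 2^n candidates to keep only nCk of them.

```python
def popcount_enumerate(n, k):
    """Brute force: scan every n-bit number and keep those with exactly k ones."""
    rows = []
    # walk downward so the surviving rows come out in the table's order
    for v in range(2**n - 1, -1, -1):
        if bin(v).count("1") != k:
            continue  # wrong popcount: this candidate was wasted work
        # expand v into n bits, most significant bit first (position 0)
        rows.append([(v >> (n - 1 - i)) & 1 for i in range(n)])
    return rows


rows = popcount_enumerate(6, 3)
print(len(rows))  # 20 rows kept out of 2**6 == 64 candidates
print(rows[0])    # [1, 1, 1, 0, 0, 0]
print(rows[-1])   # [0, 0, 0, 1, 1, 1]
```

For 6C3 that means 64 candidates for 20 results. For n == 2048 the loop would never finish.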
Out came the pen and paper, and a number of attempts later I had a table of the combinations
for 6C3. Why 6 and 3? In my first attempts to code things, I
had used 5 and 2. My solution happened to work for those values,
but it would fail for k > 2. So I decided to kick it up to make that kind of accident less likely going forward. It was also convenient because 6C3 is only 20 and easily fits on a notebook page.
I had intuitively figured out how to generate all the combinations of 6C3, but the path from intuition to working code can be longer than one would think. And did the order in which I’d intuitively written down the enumeration make sense when it came to coding it? I figured I was probably close. I had started with the ones all the way on the left, and the ones marched across the row as you went down the page.
The solution I arrived at in the end seemed obvious in retrospect, but I wasn’t there yet, and had to stare at my page and try to see what was going on.
Pascal’s Triangle
It was during this that I remembered that combinations are directly related to binomial coefficients, so I figured I’d draw up Pascal’s triangle in case it helped:
n \ k   0   1   2   3   4   5   6
0       1
1       1   1
2       1   2   1
3       1   3   3   1
4       1   4   6   4   1
5       1   5  10  10   5   1
6       1   6  15  20  15   6   1
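As an aside, every entry in the triangle is the sum of the two entries above it (Pascal’s rule). The sketch below builds the rows that way and checks them against Python’s math.comb, which computes nCk directly. pascal_rows is a made-up helper name.

```python
from math import comb


def pascal_rows(max_n):
    """Build rows 0..max_n of Pascal's triangle using only additions."""
    rows = [[1]]
    for n in range(1, max_n + 1):
        prev = rows[-1]
        # each interior entry is the sum of the two entries above it
        rows.append([1] + [prev[k - 1] + prev[k] for k in range(1, n)] + [1])
    return rows


rows = pascal_rows(6)
for n, row in enumerate(rows):
    # every entry also matches the closed form nCk == math.comb(n, k)
    assert row == [comb(n, k) for k in range(n + 1)]
    print(n, row)
```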
I then went back to my table of combinations to see what I could figure out.
01 : [1, 1, 1, 0, 0, 0]
02 : [1, 1, 0, 1, 0, 0]
03 : [1, 1, 0, 0, 1, 0]
04 : [1, 1, 0, 0, 0, 1]
05 : [1, 0, 1, 1, 0, 0]
06 : [1, 0, 1, 0, 1, 0]
07 : [1, 0, 1, 0, 0, 1]
08 : [1, 0, 0, 1, 1, 0]
09 : [1, 0, 0, 1, 0, 1]
10 : [1, 0, 0, 0, 1, 1]
11 : [0, 1, 1, 1, 0, 0]
12 : [0, 1, 1, 0, 1, 0]
13 : [0, 1, 1, 0, 0, 1]
14 : [0, 1, 0, 1, 1, 0]
15 : [0, 1, 0, 1, 0, 1]
16 : [0, 1, 0, 0, 1, 1]
17 : [0, 0, 1, 1, 1, 0]
18 : [0, 0, 1, 1, 0, 1]
19 : [0, 0, 1, 0, 1, 1]
20 : [0, 0, 0, 1, 1, 1]
First I noticed that the first bit was on for half of the table, and thought “binary!”, but then looking at the second bit, it became clear that wasn’t right. So looking at the first bit again and considering Pascal’s triangle, I theorized:
01 : [1, 1, 1, 0, 0, 0] --+
02 : [1, 1, 0, 1, 0, 0]   |
03 : [1, 1, 0, 0, 1, 0]   |
04 : [1, 1, 0, 0, 0, 1]   |
05 : [1, 0, 1, 1, 0, 0]   |
06 : [1, 0, 1, 0, 1, 0]   | This is 5C2 or 5C3
07 : [1, 0, 1, 0, 0, 1]   |
08 : [1, 0, 0, 1, 1, 0]   |
09 : [1, 0, 0, 1, 0, 1]   |
10 : [1, 0, 0, 0, 1, 1] --+
11 : [0, 1, 1, 1, 0, 0] -----+
12 : [0, 1, 1, 0, 1, 0]      |
13 : [0, 1, 1, 0, 0, 1]      |
14 : [0, 1, 0, 1, 1, 0]      | This is 5C2 or 5C3
15 : [0, 1, 0, 1, 0, 1]      |
16 : [0, 1, 0, 0, 1, 1]      |
17 : [0, 0, 1, 1, 1, 0]      |
18 : [0, 0, 1, 1, 0, 1]      |
19 : [0, 0, 1, 0, 1, 1]      |
20 : [0, 0, 0, 1, 1, 1] -----+
Then the second bit:
01 : [1, 1, 1, 0, 0, 0] --+----+
02 : [1, 1, 0, 1, 0, 0]   |    | This looks like 4C1?
03 : [1, 1, 0, 0, 1, 0]   |    |
04 : [1, 1, 0, 0, 0, 1]   |----+
05 : [1, 0, 1, 1, 0, 0]   |--------------------+
06 : [1, 0, 1, 0, 1, 0]   | This is 5C2        |
07 : [1, 0, 1, 0, 0, 1]   |                    | This looks like 4C2?
08 : [1, 0, 0, 1, 1, 0]   |                    | 
09 : [1, 0, 0, 1, 0, 1]   |                    | 
10 : [1, 0, 0, 0, 1, 1] --+--------------------+
11 : [0, 1, 1, 1, 0, 0] -----+-------------------+
12 : [0, 1, 1, 0, 1, 0]      |                   |
13 : [0, 1, 1, 0, 0, 1]      |                   | This looks like 4C2?
14 : [0, 1, 0, 1, 1, 0]      | This is 5C3       |
15 : [0, 1, 0, 1, 0, 1]      |                   |
16 : [0, 1, 0, 0, 1, 1]      |-------------------+
17 : [0, 0, 1, 1, 1, 0]      |-------------------+
18 : [0, 0, 1, 1, 0, 1]      |                   | This looks like 4C3
19 : [0, 0, 1, 0, 1, 1]      |                   |
20 : [0, 0, 0, 1, 1, 1] -----+-------------------+
I then had the Eureka! moment. Of course it would work this way! In the first half of the table, starting with
6C3, the first element had been selected, so we’d decrement both n and k, which yields 5C2 == 10.
In the second half of the table, we didn’t select the first element, so we only decrement n, which yields
5C3, which also == 10. This was a data point, but since both halves were the same size, well, I couldn’t
be sure which was which. Had I known then what I do now, I probably would have gone with 6C4. When n is exactly twice k, as with 6C3, the first element is selected in exactly half of the rows, so the two halves look alike. 6C4 would instead have split into 5C3 == 10 rows with the first element and 5C4 == 5 rows without it.
But going to the second bit, things became clearer. In the first batch, from 01 to 04, the second bit has been
selected. If the theory held, we should decrement both n and k, which yields 4C1 == 4,
and that’s exactly what we see. For 05 to 10 the second bit is a zero, so from 5C2 we only decrement n, giving 4C2 == 6, which also matches. For elements 11 to 16, we start with 5C3 and pick the
second element, which again should decrement both n and k, giving 4C2 == 6. Again, this
matches. In the case of 17 to 20, we didn’t take the element (i.e. it’s a zero), so we only decrement n and not k, which gives 4C3 == 4. I then applied this all the way down and it all held. Woot!
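Written out, the pattern is Pascal’s rule. Of the nCk combinations, the rows that take an element account for (n-1)C(k-1), and the rows that skip it account for (n-1)Ck. Applied to the 6C3 table:

```latex
\begin{aligned}
{}_nC_k &= {}_{n-1}C_{k-1} + {}_{n-1}C_k \\
{}_6C_3 &= {}_5C_2 + {}_5C_3 = 10 + 10 = 20 && \text{rows 01 to 10, then 11 to 20} \\
{}_5C_2 &= {}_4C_1 + {}_4C_2 = 4 + 6 = 10 && \text{rows 01 to 04, then 05 to 10} \\
{}_5C_3 &= {}_4C_2 + {}_4C_3 = 6 + 4 = 10 && \text{rows 11 to 16, then 17 to 20}
\end{aligned}
```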
Enumeration, Part 2
Now that I knew the structure, the trick was to use the relationships I’d found to generate the combinations.
from math import comb

def choose(x, n, k, l):
    """Append to l the 0/1 selection list for the zero-based index x among nCk."""
    k -= 1
    n -= 1
    # while we still have items to choose from
    while n >= 0:
        # if we've chosen all k elements, we take no more elements
        if k < 0:
            l.append(0)
        else:
            # how big is the span where we take the next element?
            lim = comb(n, k)
            if x < lim: # if x is in the span [0, lim), take the element
                l.append(1)
                k -= 1  # one less to choose, since we took one
            else: # x is over that limit, don't take the element
                l.append(0)
                x -= lim  # shift x into the remaining span
        # one less item to choose from
        n -= 1
    return l
And this is the code that I used to generate the tables above. Note that x is zero-based, so row 01 corresponds to x = 0 and row 20 to x = 19.
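Here is a minimal driver that regenerates the 6C3 table with choose. The loop and the formatting are my own sketch, not necessarily the exact code used. choose is repeated, with its import, so the snippet runs on its own.

```python
from math import comb


def choose(x, n, k, l):
    # same algorithm as above, repeated so this snippet is self-contained
    k -= 1
    n -= 1
    while n >= 0:
        if k < 0:
            l.append(0)
        else:
            lim = comb(n, k)
            if x < lim:
                l.append(1)
                k -= 1
            else:
                l.append(0)
                x -= lim
        n -= 1
    return l


# x is zero-based, so the printed label is x + 1
table = [choose(x, 6, 3, []) for x in range(comb(6, 3))]
for x, row in enumerate(table):
    print(f"{x + 1:02} : {row}")
```

The first line printed is `01 : [1, 1, 1, 0, 0, 0]` and the last is `20 : [0, 0, 0, 1, 1, 1]`, matching the table.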
Reversal
Once I could go forward, could I go back? That is, given the list of selected elements, can we get back to the number that generated it? It turns out that once you understand the forward direction, the backward one is pretty easy:
from math import comb

def reverse(l, n, k):
    """Return the zero-based index x that choose(x, n, k, []) maps to l."""
    n -= 1
    k -= 1
    v = 0 # our accumulator
    for val in l:
        if k < 0: # if we've selected everything there is to select, we can bail
            break
        lim = comb(n, k) # what's the size of the span if we take the element?
        if val == 1:  # we did take the element, so v is in the span, so don't add to it
            k -= 1    # note that we took the element
        else:         # we didn't take it, so add the limit to put v past the span
            v += lim
        n -= 1        # we used up a slot to put things in
    return v
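One way to gain confidence in reverse is to check it against an independent ordering. Python’s itertools.combinations emits index tuples in lexicographic order, which is the same order as the tables above. So the position of each tuple should equal what reverse returns. Here is a sketch; check_reverse is a hypothetical helper name.

```python
from itertools import combinations
from math import comb


def reverse(l, n, k):
    # same algorithm as above, repeated so this snippet is self-contained
    n -= 1
    k -= 1
    v = 0
    for val in l:
        if k < 0:
            break
        lim = comb(n, k)
        if val == 1:
            k -= 1
        else:
            v += lim
        n -= 1
    return v


def check_reverse(n, k):
    # combinations() yields tuples in lexicographic order, i.e. table order,
    # so each tuple's position in that sequence is its index x
    for rank, picked in enumerate(combinations(range(n), k)):
        bits = [1 if i in picked else 0 for i in range(n)]
        assert reverse(bits, n, k) == rank, (n, k, picked)
    return True


print(check_reverse(6, 3), check_reverse(8, 4), check_reverse(10, 2))
```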
Optimization
With something like this, it occurred to me that computing comb over and over seems pretty expensive if we were doing this in a production environment. I imagine a
sane version of comb doesn’t explode out all the factorials, but it can still be expensive to run many times,
especially in the case that inspired this whole endeavor, where n == 2048. Can we do better? Well, this is
written in Python for ease of expression, not for speed, so we’ll ignore that for the moment.
I decided to do some math. At each step we go from nCk to either (n-1)Ck, when the element isn’t taken,
or (n-1)C(k-1), when it is. The latter is really both steps: decrement k, then decrement n.
So I figured that formulas to go from nCk to (n-1)Ck or to nC(k-1) should look pretty reasonable. It turns out they do:
$$ {}_nC_k = \frac{n!}{k!\,(n-k)!} $$
$$ {}_nC_{k-1} = {}_nC_k \cdot \frac{k}{n-k+1} $$
$$ {}_{n-1}C_k = {}_nC_k \cdot \frac{n-k}{n} $$
I imagine these are posted elsewhere on the internet, but it was a lazy weekend day spent working most of this out on paper, so I wound up deriving them myself.
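Both step formulas are easy to spot-check against math.comb. In this sketch, step_n and step_k are illustrative names. Multiplying before the integer division keeps everything exact: nCk * (n - k) is always a multiple of n, and nCk * k is always a multiple of n - k + 1. So no floating-point rounding is involved.

```python
from math import comb


def step_n(n, k, v):
    """Given v == nCk, return (n-1)Ck exactly."""
    return v * (n - k) // n


def step_k(n, k, v):
    """Given v == nCk, return nC(k-1) exactly."""
    return v * k // (n - k + 1)


# check both identities for every 1 <= k <= n < 40
for n in range(1, 40):
    for k in range(1, n + 1):
        v = comb(n, k)
        assert step_n(n, k, v) == comb(n - 1, k)
        assert step_k(n, k, v) == comb(n, k - 1)
print("step formulas agree with math.comb")
```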
So I can now compute comb exactly once, and subsequently use the other two to handle stepping things down along the way.
from math import comb

def decn(n, k, v):
    "from nCk, yield (n-1)Ck"
    if n == 0:
        return v
    # multiply first: v * (n - k) is an exact multiple of n, so integer
    # division stays exact where float division could round
    return v * (n - k) // n
def deck(n, k, v):
    "from nCk, yield nC(k-1)"
    if k == 0:
        return v
    # exact for the same reason: nCk * k == nC(k-1) * (n - k + 1)
    return v * k // (n - k + 1)
def lesscomb_choose(x, n, k, l):
    "A variant where we only call comb once"
    k -= 1
    n -= 1
    # how many combinations have a 1 in the next position?
    lim = comb(n, k)
    while n >= 0:
        # no more to choose
        if k < 0:
            l.append(0)
        else:
            if x < lim:
                l.append(1)
                lim = deck(n, k, lim) # step nCk -> nC(k-1)
                k -= 1  # one less to choose
            else:
                l.append(0)
                x -= lim  # scale leftover
        lim = decn(n, k, lim) # step nCk -> (n-1)Ck
        n -= 1
    return l
I thought of a few other optimization techniques that could easily be applied. If n doesn’t change,
you can precompute a list with one entry per position. Each entry holds the index of the first combination whose first 1 is at that position. With that,
given an input number, you could binary search to find which range to continue from, and thus
skip a run of leading zeros in one step. The worst case is still O(n), but on average it should do well. And when n gets
sufficiently small, you could just have a lookup table for the various k values. In the shuffle
sharding case k started at 4, so the number of k values is quite limited, and you can cover
a fairly large number of values of n.
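Here is a sketch of the binary-search idea, based on my own reading of it. For a given (n, k), offsets[p] holds the index of the first combination whose first 1 is at position p, and bisect finds the block that contains x. first_one_offsets and choose_positions_bisect are hypothetical names, and the function returns the selected positions rather than a 0/1 list. Caching a table for every (n, k) it visits can cost O(n²k) space in the worst case, so treat this as an illustration rather than a tuned implementation.

```python
from bisect import bisect_right
from functools import lru_cache
from math import comb


@lru_cache(maxsize=None)
def first_one_offsets(n, k):
    """offsets[p]: index of the first combination (in table order)
    whose first selected element is at position p."""
    offsets, total = [], 0
    for p in range(n - k + 1):
        offsets.append(total)
        # combinations whose first 1 is at p pick k-1 of the n-1-p later slots
        total += comb(n - 1 - p, k - 1)
    return tuple(offsets)


def choose_positions_bisect(x, n, k):
    """Return the sorted positions of the k selected elements for index x."""
    positions, base = [], 0  # base: absolute position of the current slot 0
    while k > 0:
        offsets = first_one_offsets(n, k)
        # binary search for the block containing x, skipping leading zeros
        p = bisect_right(offsets, x) - 1
        positions.append(base + p)
        x -= offsets[p]
        # continue with the positions after the one just selected
        base += p + 1
        n -= p + 1
        k -= 1
    return positions


print(choose_positions_bisect(4, 6, 3))   # [0, 2, 3], i.e. row 05
print(choose_positions_bisect(19, 6, 3))  # [3, 4, 5], i.e. row 20
```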
From there, you could look at removing branches from the code to avoid the cost of branch
mispredictions. For an example of what that sort of thing might look like, see the code golfing section
below. Python doesn’t quite give you what you need to remove them all (e.g. the second assignment to lim),
but you get the gist.
Even more, since the ranges of both n and k are known, you could easily precompute all values of comb(n, k)
into a table and avoid the computation at runtime entirely. Again assuming n == 2048 and k == 4, this would only cost
about 64KB (2048 * 4 * 8 bytes), assuming 64-bit (8-byte) integers. Depending on the processor, that may fit in the L1 data cache,
though many x86 cores have only 32KB or 48KB of it. There’s also likely a low-cost compression trick available to get the table
even smaller if necessary.
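Here is a sketch of that precomputed table. It assumes the cache is keyed by (n, k), the way the cache-based code in the golfing section indexes it. build_comb_cache and build_comb_array are hypothetical names. The second function packs the same values into 8-byte unsigned integers with the standard array module, so the 64KB figure can be checked.

```python
from array import array
from math import comb


def build_comb_cache(n, k):
    """Dict from (n', k') to comb(n', k') for 0 <= n' < n and 0 <= k' < k."""
    return {(i, j): comb(i, j) for i in range(n) for j in range(k)}


def build_comb_array(n, k):
    """Same values packed row-major as unsigned 64-bit ints:
    entry (i, j) lives at index i * k + j."""
    return array("Q", (comb(i, j) for i in range(n) for j in range(k)))


cache = build_comb_cache(2048, 4)
packed = build_comb_array(2048, 4)
print(len(cache))                       # 8192 entries
print(len(packed) * packed.itemsize)    # 65536 bytes (64KB) when "Q" is 8 bytes
print(packed[5 * 4 + 2] == cache[(5, 2)])  # True
```

For these parameters the largest entry is comb(2047, 3) = 1,427,465,215, which would also fit in an unsigned 32-bit integer.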
It also occurs to me that you could have choose return an array of the k selected indices rather than a list of
n booleans saying whether each item was selected. For the conversion from int to selection, it probably
wouldn’t affect runtime much beyond saving some insertions, a saving that branch mispredictions might eat back up, and it would still be O(n). However, for the selection -> int conversion, again assuming n is fixed, it could probably save a lot. This is especially true if you precompute what you’d need to add to v for each input value at each position in the input list, assuming the indices are sorted. That is, the runtime goes from O(n) to O(k), at the cost of O(nk) space.
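Here is a sketch of the O(k) selection -> int direction. It uses a subtractive form of the precomputed table rather than an additive one; this is my own derivation, not necessarily what was meant above. A combination comes after positions c_0 < c_1 < ... < c_{k-1} in the table’s order exactly when, at the first slot i where the two differ, its position is larger. For each i there are comb(n - 1 - c_i, k - i) such combinations. Subtracting those counts from the last index, comb(n, k) - 1, gives x. build_rank_table and rank_from_positions are hypothetical names.

```python
from math import comb


def build_rank_table(n, k):
    """table[i][c] = comb(n - 1 - c, k - i): the number of combinations that
    come after any combination whose slot i holds position c and differs there."""
    return [[comb(n - 1 - c, k - i) for c in range(n)] for i in range(k)]


def rank_from_positions(positions, n, k, table):
    """O(k) conversion from sorted selected positions back to the index x."""
    v = comb(n, k) - 1  # index of the last row; could be precomputed too
    for i, c in enumerate(positions):
        v -= table[i][c]
    return v


table = build_rank_table(6, 3)
print(rank_from_positions([0, 1, 2], 6, 3, table))  # 0, row 01
print(rank_from_positions([0, 2, 3], 6, 3, table))  # 4, row 05
print(rank_from_positions([3, 4, 5], 6, 3, table))  # 19, row 20
```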
In any case, I’m sure there is endless fun to this exercise, which I’ll leave to the reader. Hope you enjoyed the journey!
Code Golfing, Because Why Not?
Applying some code golfing techniques:
from math import comb

def lesscomb_choose(x, nx, k, l):
    lim = comb(nx, k - 1)
    for n in range(nx, 0, -1):
        lim = lim * (n - k + 1) // n  # exact integer step to (n-1)C(k-1)
        l.append(k and x < lim and 1 or 0)
        lim = k - 1 > 0 and l[-1] * lim * (k - 1) // (n - k + 1) or lim
        k -= l[-1]
        x += (l[-1] - 1) * lim
    return l
Using a comb cache, you get something simpler:
def xnocomb_choose(x, n, k, cache, l):
    k, n = k - 1, n - 1
    for n in range(n, -1, -1):
        l.append(k >= 0 and x < cache[(n, k)] and 1 or 0)
        x, k = x - (k >= 0 and (1 - l[-1]) * cache[(n, k)] or 0), k - l[-1]
    return l
If I changed the control structure from a for loop to a while loop, I could get
rid of the k >= 0 checks you see, but I’d have to manage n myself. That’s not a
big deal.
Using the comb cache and outputting only the indices of the k selected items:
def xchoose_idxs(x, n, k, cache, l):
    k, ok, n, on = k - 1, k - 1, n - 1, n - 1
    while n >= 0 and k >= 0:
        l[ok - k], dodec = on-n, x < cache[(n,k)] and 1 or 0
        k, x, n = k - dodec, x - (1 - dodec) * cache[(n,k)], n - 1
    return l
We cheat here and continually overwrite the current slot in the list, which must be preallocated with k entries,
so as to avoid branching. Since we know that we just step n from n-1 down to 0, assuming
n is fixed, something else that could be tried would be fully unrolling
the loop. But then we’d need a mechanism for exiting early once k < 0.
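Here is a small usage example for xchoose_idxs, repeated so it runs standalone. It assumes the cache is a dict mapping (n, k) to comb(n, k). The list l has to be preallocated with k slots, since the function writes into it by index rather than appending.

```python
from math import comb


def xchoose_idxs(x, n, k, cache, l):
    # copied from above; l must already hold k slots
    k, ok, n, on = k - 1, k - 1, n - 1, n - 1
    while n >= 0 and k >= 0:
        l[ok - k], dodec = on - n, x < cache[(n, k)] and 1 or 0
        k, x, n = k - dodec, x - (1 - dodec) * cache[(n, k)], n - 1
    return l


n, k = 6, 3
# assumed cache layout: a dict from (n', k') to comb(n', k')
cache = {(i, j): comb(i, j) for i in range(n) for j in range(k)}
for x in range(comb(n, k)):
    print(f"{x + 1:02} : {xchoose_idxs(x, n, k, cache, [0] * k)}")
```

Row 01 prints as `[0, 1, 2]` and row 20 as `[3, 4, 5]`: these are the positions of the ones in the earlier 0/1 table.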