
The Balanced Partition Problem

March 7, 2011


The other night, at about 11pm, a friend of mine started asking me for help on his homework that was due that night.  I normally just provide general hints rather than do the problems myself (it's his homework, after all), but when I see an interesting problem I can play around with in Python, it's tough to ignore.

The problem was essentially: 'you have N buckets, each with a different weight; give an algorithm that separates them into two piles as close to the same weight as possible.'  This is one way to state the Balanced Partition Problem, and Python is a great language for quickly throwing together and testing algorithms.

First things first: build the different sized ‘buckets’:

import random

buckets = []
for x in xrange(0, 40):
    y = random.randint(2, 500)
    while y in buckets:  # keep the weights distinct
        y = random.randint(2, 500)
    buckets.append(y)

Attempt #1:

The first idea that occurred to me was to solve it like you would a simple root solver.  This was the code I came up with:

import time
from itertools import izip, count

t = time.time()
##### method 1
suba = list(buckets)
subb = list()
suma = sum(suba)
sumb = sum(subb)
diff = abs(suma-sumb)

it = 0
while min(suba)*2 < diff:
    # find the element of suba whose move best closes the gap
    (v,i) = min(izip((abs(x*2-diff) for x in suba), count()))
    v = suba[i]
    suba.remove(v)
    subb.append(v)
    suma -= v
    sumb += v
    diff = suma-sumb
    it += 1

print 'meth1: {0}s'.format(time.time()-t)
print '{0} iterations.'.format(it)
print 'suma: {0}\nsumb: {1}'.format(suma,sumb)
print ' diff: {0}, min ele: {1}'.format(abs(suma-sumb),min(buckets))

The code starts with all buckets in 'set a', and on each iteration it looks for the element in 'set a' whose move will come closest to reducing the difference to zero, and moves it to 'set b'.  This stops when the difference is too small to be fixed by the smallest element in 'set a'.

This algorithm is very fast (each pass is a single linear scan, and only a handful of passes are needed), and gives two sets that are pretty close to equal. Unfortunately it doesn't always produce the best result: there may be cases where the difference between a single element in one set and a subset of elements in the other set is smaller than the smallest element in either set.  A greedy method that only moves one element at a time can't find those combinations, but for a simple and (very) quick method, it gets pretty close to equal.
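As a hypothetical illustration (this example is mine, not from the homework): with weights 7, 6, 5, 4 the greedy method moves 7 and then stops at a difference of 8, even though the split {7, 4} vs {6, 5} is perfectly balanced.  A condensed version of method 1, written to run under both Python 2 and 3:

```python
# Greedy failure case: buckets = [7, 6, 5, 4], total 22.
# Moving 7 best halves the initial difference of 22, leaving
# suba=[6, 5, 4] (sum 15) vs subb=[7] (sum 7): difference 8.
# The loop then stops because min(suba)*2 == 8 is not < 8,
# but the optimal split {7, 4} vs {6, 5} has difference 0.
buckets = [7, 6, 5, 4]
suba, subb = list(buckets), []
diff = sum(suba) - sum(subb)
while min(suba) * 2 < diff:
    v = min(suba, key=lambda x: abs(x * 2 - diff))
    suba.remove(v)
    subb.append(v)
    diff = sum(suba) - sum(subb)
print(abs(diff))  # 8
```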

Attempt #2: Brute Force

How close is method 1?  Well, I guess we have to find the best possible solution to figure that out, and if you can't think of a clever way to do it, there is always the brute force method.  Since order doesn't matter in a sum, we just need to iterate over every possible subset and pick the one whose sum is closest to half the sum of the full set. Since we're using Python, we can use itertools.combinations (available since 2.6), which will improve our speed a bit:

from itertools import combinations
import time

t = time.time()
##### method 2
best = sum(buckets)
target = best/2
bests = ()
for i in xrange(0, len(buckets)):
    for s in combinations(buckets,i):
        x = sum(s)
        if abs(x-target) < abs(best-target):
            best = x
            bests = s

suba = sorted(bests)
subb = sorted(set(buckets)-set(bests))
suma = sum(suba)
sumb = sum(subb)
print 'meth2: {0}s'.format(time.time()-t)
print 'suba: {0}\nsubb: {1}\nsuma: {2}\nsumb: {3}'.format(sorted(suba),sorted(subb),suma,sumb)
print ' diff: {0}, min ele: {1}'.format(abs(suma-sumb),min(buckets))

Now, I was originally using a set size of 20, which this method handles in a couple of seconds, but because the number of subsets grows exponentially, bumping that size up to just 30 takes method 2 over 12 minutes on my desktop (method 1 takes 1.5e-5s). Making the size any larger makes this method infeasible.
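The blow-up is easy to quantify (a back-of-the-envelope check of mine, not from the original timings): the two loops together enumerate essentially every subset of the buckets, and n buckets have 2**n subsets:

```python
# The brute force enumerates essentially all 2**n subsets of n buckets.
for n in (20, 30, 40):
    print(n, 2 ** n)
# 20 -> ~1e6 subsets, 30 -> ~1e9, 40 -> ~1e12
```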

Attempt #3: A Better Solution

Before asking me for help, my friend found this clever solution to the problem that’s much better than the brute force one:

http://people.csail.mit.edu/bdean/6.046/dp/dp_4.swf

Pretty nice, but then what do you expect from MIT?

Looking at this method, the author notes that checking P(i,j) for all possible i,j \in (0..n,0..nk) will take O(n^2k) time.  This is true, but why check them all?  Start from the recursive equation he gives:

P(i,j) = \max(P(i-1,j), P(i-1,j-A_i))

and the initial condition:

P(0, 0) = 1, P(0, j\neq0) = 0

The recursive formula can be used to generate only the cases where P(i,j) = 1.  Additionally, because I am using Python, I decided to make P a dictionary (actually a dictionary of dictionaries) so it can also keep track of the subsets themselves.  Using this method, we need to generate considerably fewer than n^2k subsets.  The code is:

import time

t = time.time()
##### method 3
best = sum(buckets)
target = best/2
bests = ()
P = {}
P[-1] = {0:()}
for i in xrange(len(buckets)):
    x = buckets[i]
    P[i] = {}
    for j in P[i-1].keys():
        P[i][j] = P[i-1][j]
        P[i][j+x] = P[i-1][j] + (x,)

        if abs((j+x)-target)<abs(best-target):
            best = j+x
            bests = P[i][j+x]

suba = sorted(bests)
subb = sorted(set(buckets)-set(suba))
suma = sum(suba)
sumb = sum(subb)
print 'meth3: {0}s'.format(time.time()-t)
print 'suma: {0}\nsumb: {1}'.format(suma,sumb)
print ' diff: {0}, min ele: {1}'.format(abs(suma-sumb),min(buckets))

This method yields the best answer very fast, and the Python code could probably be optimized to improve the speed further.  On my desktop it takes ~0.1s for 30 'buckets' (vs. 12+ minutes for method 2), and still only ~20s for 200 'buckets'. It is interesting to note that method 1 produces very close to the best result (usually a single-digit difference for 200 buckets, if not exactly zero) and takes only 5e-3s for 200 buckets, so as an approximation it is quite good.

The final improvement comes from realizing that since the target is half the sum of the full set, one of the two subsets must sum to at most that target, so we can skip computing and checking sums above it.  This two-line change cuts the execution time in half:

t = time.time()
##### method 3, with the target cutoff
best = sum(buckets)
target = best/2
bests = ()
P = {}
P[-1] = {0:()}
for i in xrange(len(buckets)):
    x = buckets[i]
    P[i] = {}
    for j in P[i-1].keys():
        P[i][j] = P[i-1][j]
        if j+x > target:
            continue
        P[i][j+x] = P[i-1][j] + (x,)

        if abs((j+x)-target)<abs(best-target):
            best = j+x
            bests = P[i][j+x]

suba = sorted(bests)
subb = sorted(set(buckets)-set(suba))
suma = sum(suba)
sumb = sum(subb)
print 'meth3: {0}s'.format(time.time()-t)
print 'suma: {0}\nsumb: {1}'.format(suma,sumb)
print ' diff: {0}, min ele: {1}'.format(abs(suma-sumb),min(buckets))

OK, I know I said that was the final improvement, but then I thought to myself, 'what if I check whether the target is reached exactly, and bail out early?' This requires a little refactoring of the code into a function so that I can use return to break out of it early:

t = time.time()
##### method 3, with early exit
best = sum(buckets)
bests = ()
def m3():
    global best, bests
    target = best/2
    P = {}
    P[-1] = {0:()}
    for i in xrange(len(buckets)):
        x = buckets[i]
        P[i] = {}
        for j in P[i-1].keys():
            P[i][j] = P[i-1][j]
            if j+x > target:
                continue
            P[i][j+x] = P[i-1][j] + (x,)

            if abs((j+x)-target)<abs(best-target):
                best = j+x
                bests = P[i][j+x]
                if best == target:
                    return

m3()

suba = sorted(bests)
subb = sorted(set(buckets)-set(suba))
suma = sum(suba)
sumb = sum(subb)
print 'meth3: {0}s'.format(time.time()-t)
print 'suma: {0}\nsumb: {1}'.format(suma,sumb)
print ' diff: {0}, min ele: {1}'.format(abs(suma-sumb),min(buckets))

As it happens, after running several tests, it seems that the function reaches the target early quite often, and this change cuts the average runtime down to about a third of what it was.

That's all I've done with it so far.  It can probably be improved further, but for a just-for-fun little exercise with Python, it's pretty interesting (to me at least).
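One likely further improvement (my sketch, not something from the post above): method 3 keeps every row P[i] alive, but only the previous row is ever read, so a single dictionary of reachable sums works just as well and uses far less memory.  A possible version, written to run under both Python 2 and 3:

```python
def balance(buckets):
    # One dict mapping each reachable sum (<= target) to a subset achieving it.
    target = sum(buckets) // 2
    reach = {0: ()}
    for x in buckets:
        # snapshot items() so x isn't added to the same subset twice
        for j, subset in list(reach.items()):
            s = j + x
            if s <= target and s not in reach:
                reach[s] = subset + (x,)
    # the best subset sum <= target is simply the largest reachable one
    return reach[max(reach)]

suba = balance([7, 6, 5, 4])
print(sum(suba))  # 11 -- a perfectly balanced split of total 22
```

The final difference is just total - 2*max(reach), and the cutoff and early-exit tricks from above could be layered on in the same way.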