## Saturday, December 22, 2007

### Integer multiplication and dynamic programming (Part 2 of 2)

Alright, now with code. Unlike previous code samples, this one is actually debugged :-).

First, let's define a few data structures. We need a way to represent an operation, and also the decomposition: a way of breaking multiplication by a number into a sequence of the following operations (reg holds the running product, reg2 holds a copy of the original value):
```cpp
// Possible operations
// F1:  shl reg, #iconst        ; * 2**iconst
// F2:  lea reg, [reg + 2*reg]  ; * 3
// F3:  lea reg, [reg + 4*reg]  ; * 5
// F4:  lea reg, [reg + 8*reg]  ; * 9
// F5:  add reg, reg2           ; +1
// F6:  sub reg, reg2           ; -1
// F7:  lea reg, [reg + 2*reg2] ; +2
// F8:  lea reg, [reg + 4*reg2] ; +4
// F9:  lea reg, [reg + 8*reg2] ; +8
// F10: lea reg, [reg2 + 2*reg] ; *2 + 1
// F11: lea reg, [reg2 + 4*reg] ; *4 + 1
// F12: lea reg, [reg2 + 8*reg] ; *8 + 1
```
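To see what each entry does to the multiplier, here is a minimal sketch that applies the operations to concrete 32-bit register values, assuming eax holds the running product and ecx holds a copy of the original value (the register names the printing code uses). The `Op` enum and the `Apply` function are hypothetical illustrations, not part of the program; `std::uint32_t` stands in for a wrapping 32-bit register.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical grouping of F1-F12 by instruction shape.
enum class Op { SHL, LEA_SELF, ADD_REG, SUB_REG, LEA_REG, LEA_INC };

// Applies one operation to concrete register values.
// eax = running product (m * x), ecx = copy of the original x.
// k is the shift count (F1) or the lea scale 2, 4 or 8; ignored otherwise.
std::uint32_t Apply(Op op, int k, std::uint32_t eax, std::uint32_t ecx) {
  assert(k >= 0 && k < 32);  // keeps the shift well defined
  const std::uint32_t uk = static_cast<std::uint32_t>(k);
  switch (op) {
    case Op::SHL:      return eax << uk;        // F1: * 2**k
    case Op::LEA_SELF: return eax + uk * eax;   // F2-F4: * (k + 1)
    case Op::ADD_REG:  return eax + ecx;        // F5: +1
    case Op::SUB_REG:  return eax - ecx;        // F6: -1
    case Op::LEA_REG:  return eax + uk * ecx;   // F7-F9: +k
    case Op::LEA_INC:  return ecx + uk * eax;   // F10-F12: *k + 1
  }
  return 0;
}
```

For x = 7, applying F3 and then F4 gives 35 and then 315, which is 45 * 7; applying F1 with a shift of 3 and then F6 gives 56 - 7 = 49, which is 7 * 7.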

Here's my definition for the operation:
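```cpp
// Headers needed by the full program
#include <stdio.h>   // printf, fgets
#include <stdlib.h>  // atoi
#include <string.h>  // memset, strchr, strcmp

enum OPTYPE {
  NONE,     // Just a sentinel
  SHIFT,    // F1
  MULTIPLY, // F2-F4
  ADDCONST, // F5-F9
  INCRMUL   // F10-F12
};

struct OP {
  OPTYPE type;
  int argument;
};
```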

As is customary in dynamic programming, for every multiplier we're trying to decompose, we will actually decompose a whole bunch of smaller numbers that the original multiplier breaks into. Many of them repeat, so we will store all the intermediate results in a hash table to avoid recalculating them over and over and over again.
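The idea is easiest to see in the case without an extra register. Shifts and lea by 3, 5 or 9 only ever divide the multiplier when undone, so every sub-problem is smaller than the one before and the recursion cannot loop. Below is a minimal sketch of that memoized recursion, assuming only F1-F4 and using `std::unordered_map` in place of a hand-rolled hash; `MinCostNoExtraReg` is a hypothetical name, and it only counts instructions rather than building a chain.

```cpp
#include <cassert>
#include <unordered_map>

// Fewest shl/lea instructions (F1-F4 only) that multiply by n,
// or -1 if n cannot be built from 1 with *2**k, *3, *5 and *9.
int MinCostNoExtraReg(int n, std::unordered_map<int, int> &memo) {
  assert(n >= 1);
  if (n == 1) return 0;                      // multiplying by 1 needs nothing
  auto it = memo.find(n);
  if (it != memo.end()) return it->second;   // reuse a stored sub-result

  int best = -1;
  auto consider = [&](int sub) {             // sub = F^-1(n), always < n
    int c = MinCostNoExtraReg(sub, memo);
    if (c >= 0 && (best < 0 || c + 1 < best)) best = c + 1;
  };
  // F1: every shift that divides n exactly
  for (int k = 1; k < 31 && n % (1 << k) == 0; ++k) consider(n >> k);
  // F2-F4: lea by 3, 5 and 9
  const int factors[] = {3, 5, 9};
  for (int f : factors)
    if (n % f == 0) consider(n / f);

  memo[n] = best;                            // failures are cached as -1
  return best;
}
```

With F5-F12 the "always smaller" property no longer holds: undoing F6 (sub) turns n into n + 1, which is why the full program limits its recursion depth.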

Therefore our decomposition structure will be a member of two linked lists: `next` points to the decomposition of the number that remains once this node's operation is undone, and `next_in_bucket` links all the elements in a hash bucket.
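```cpp
struct Decomposition {
  int number;                    // the multiplier this node produces
  int cost;                      // instructions in this chain (without the mov)
  OP operation;                  // the last operation applied to reach number
  Decomposition *next;           // decomposition of F^-1(number); NULL if that is 1
  Decomposition *next_in_bucket; // next element in the same hash bucket
};
```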
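The direction of `next` is easy to get backwards: the head of the chain holds the operation executed last. Here is a minimal sketch, using stripped-down copies of the types and a hypothetical `Replay` helper, that walks a hand-built chain from its tail (multiplier 1) back to the head and recomputes the multiplier.

```cpp
#include <cassert>

enum OPTYPE { NONE, SHIFT, MULTIPLY, ADDCONST, INCRMUL };
struct OP { OPTYPE type; int argument; };
struct Decomposition {
  int number;
  int cost;
  OP operation;
  Decomposition *next;
  Decomposition *next_in_bucket;
};

// Hypothetical helper: returns the multiplier the chain produces,
// starting from x * 1 at the end of the list.
int Replay(const Decomposition *dc) {
  if (!dc) return 1;                 // end of chain: eax still holds x * 1
  int m = Replay(dc->next);          // multiplier before this operation
  int a = dc->operation.argument;
  switch (dc->operation.type) {
    case SHIFT:    return m << a;
    case MULTIPLY: return m * (a + 1);
    case ADDCONST: return m + a;
    case INCRMUL:  return m * a + 1;
    default:       assert(false && "NONE is only a sentinel"); return 0;
  }
}
```

For 45, the head node is `{45, 2, {MULTIPLY, 8}, &five, NULL}` (lea by 9, executed last) and its `next` is `{5, 1, {MULTIPLY, 4}, NULL, NULL}` (lea by 5, executed first).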

Let's now implement our hash table, to get it out of the way. As you remember, there are two cases. When we have an extra register, we can use operations F5-F12, but one extra MOV is needed to copy the original value into that register before anything else is computed, so the cost is +1 (once for the entire decomposition). If the extra register is not used, we are limited to operations F1-F4.

It turns out that the easiest way to deal with these cases is to do two decompositions in parallel, one with an extra register and one without, and then compare the costs. Hence there will be two hash tables.
```cpp
// These hashes must have identical sizes
#define DECOMPOSITION_HASH_SIZE 511

Decomposition *hash_with_extra_reg[DECOMPOSITION_HASH_SIZE];
Decomposition *hash_no_extra_reg[DECOMPOSITION_HASH_SIZE];

Decomposition *FindInHash(int num, bool has_extra_reg) {
  Decomposition **hash = has_extra_reg ?
      hash_with_extra_reg :
      hash_no_extra_reg;
  Decomposition *runner = hash[num % DECOMPOSITION_HASH_SIZE];
  while (runner && runner->number != num)
    runner = runner->next_in_bucket;
  return runner;
}

void AddToHash(Decomposition *dc, bool has_extra_reg) {
  Decomposition **hash = has_extra_reg ?
      hash_with_extra_reg :
      hash_no_extra_reg;
  int bucket = dc->number % DECOMPOSITION_HASH_SIZE;
  dc->next_in_bucket = hash[bucket];
  hash[bucket] = dc;
}
```

Simple, huh? Now let's write a few service functions that compute the result of an operation. Remember that operation F can be the last step in decomposing n if and only if n == F(F⁻¹(n)), where F⁻¹ is computed with truncating integer arithmetic, so we will need functions to compute both F and F⁻¹:
```cpp
int Compute(int num, OP *op) {
  switch (op->type) {
    case SHIFT:
      return num << op->argument;
    case MULTIPLY:
      return num * (op->argument + 1);
    case ADDCONST:
      return num + op->argument;
    case INCRMUL:
      return num * op->argument + 1;
  }
  return 0;
}

int ComputeInverse(int num, OP *op) {
  int val = 0;
  switch (op->type) {
    case SHIFT:
      val = num >> op->argument;
      break;
    case MULTIPLY:
      val = num / (op->argument + 1);
      break;
    case ADDCONST:
      val = num - op->argument;
      break;
    case INCRMUL:
      val = (num - 1) / op->argument;
      break;
  }
  // The result can never be negative or 0.
  // These values never have inverses.
  if (val <= 0)
    return 0;
  // The operation must have the inverse...
  // It is not always the case.
  if (Compute(val, op) == num)
    return val;
  return 0;
}
```
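The round trip in ComputeInverse filters out truncated divisions: 46 / 9 is 5 in integer arithmetic, but 5 * 9 = 45, not 46, so 46 has no F4 inverse. For the multiplicative operations the same check can be written as a divisibility test. A minimal sketch, with hypothetical helper names, returning 0 for "no inverse" as ComputeInverse does:

```cpp
#include <cassert>

// MULTIPLY with argument a (2, 4 or 8): n = m * (a + 1).
int InverseMultiply(int n, int a) {
  assert(a == 2 || a == 4 || a == 8);
  return (n > 0 && n % (a + 1) == 0) ? n / (a + 1) : 0;
}

// INCRMUL with argument a (2, 4 or 8): n = m * a + 1.
// n <= 1 would give m <= 0, which never has an inverse.
int InverseIncrMul(int n, int a) {
  assert(a == 2 || a == 4 || a == 8);
  if (n <= 1 || (n - 1) % a != 0) return 0;
  return (n - 1) / a;
}
```

The round-trip version has the advantage of working unchanged for every operation type, which is presumably why it is used above.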

And while we're at it, something to actually print the result:
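```cpp
// String literals are const in C++, so return const char *.
const char *GetOpFormat(OP *op) {
  switch (op->type) {
    case SHIFT:
      return "shl eax, %d";
    case MULTIPLY:
      return "lea eax, [eax + %d*eax]";
    case ADDCONST:
      if (op->argument == 1)
        return "add eax, ecx";
      else if (op->argument == -1)
        return "sub eax, ecx";
      else
        return "lea eax, [eax + %d*ecx]";
    case INCRMUL:
      return "lea eax, [ecx + %d*eax]";
  }
  return NULL;
}

// The operation executed first sits at the end of the chain,
// so recurse before printing this node's operation.
void ReversePrint(Decomposition *dc) {
  if (!dc)
    return;
  ReversePrint(dc->next);
  // The add/sub formats have no %d; printf ignores the extra argument.
  printf(GetOpFormat(&dc->operation),
      dc->operation.argument);
  printf("\n");
}
```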
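Passing a function's return value as the printf format works here because GetOpFormat only ever returns literals, but compilers may warn about non-literal format strings. One alternative, sketched with a hypothetical `FormatOp` that builds the instruction text directly with `std::string`:

```cpp
#include <cassert>
#include <string>

enum OPTYPE { NONE, SHIFT, MULTIPLY, ADDCONST, INCRMUL };
struct OP { OPTYPE type; int argument; };

// Builds the instruction text without using printf formats.
std::string FormatOp(const OP &op) {
  const std::string a = std::to_string(op.argument);
  switch (op.type) {
    case SHIFT:    return "shl eax, " + a;
    case MULTIPLY: return "lea eax, [eax + " + a + "*eax]";
    case ADDCONST:
      if (op.argument == 1)  return "add eax, ecx";
      if (op.argument == -1) return "sub eax, ecx";
      return "lea eax, [eax + " + a + "*ecx]";
    case INCRMUL:  return "lea eax, [ecx + " + a + "*eax]";
    default:
      assert(op.type == NONE);  // NONE is only a sentinel
      return "";
  }
}
```

ReversePrint would then print each line with `printf("%s\n", FormatOp(dc->operation).c_str())`.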

And now the actual dynamic part. We need to enumerate the operations and, within each operation, its arguments. Let's write these functions first (for bloggish simplicity, they are not iterators).
```cpp
bool NextOp(OP *op, bool has_extra_reg) {
  if (op->type == NONE)
    op->type = SHIFT;
  else if (op->type == SHIFT)
    op->type = MULTIPLY;
  else if (op->type == MULTIPLY && has_extra_reg)
    op->type = ADDCONST;
  else if (op->type == ADDCONST && has_extra_reg)
    op->type = INCRMUL;
  else
    return false;
  op->argument = 0;
  return true;
}

bool NextArg(OP *op) {
  if (op->type == SHIFT) {
    op->argument++;
  } else if (op->type == MULTIPLY || op->type == INCRMUL) {
    if (op->argument == 0)
      op->argument = 2;
    else if (op->argument == 2)
      op->argument = 4;
    else if (op->argument == 4)
      op->argument = 8;
    else
      return false;
  } else if (op->type == ADDCONST) {
    if (op->argument == 0)
      op->argument = -1;
    else if (op->argument == -1)
      op->argument = 1;
    else if (op->argument == 1)
      op->argument = 2;
    else if (op->argument == 2)
      op->argument = 4;
    else if (op->argument == 4)
      op->argument = 8;
    else
      return false;
  } else {
    return false;
  }
  return true;
}
```
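NextOp and NextArg form a small state machine. As an alternative (an assumption, not the post's code), the fixed-argument operations can live in a table visited in the same order, with only the open-ended shift handled separately. A minimal sketch with hypothetical names:

```cpp
#include <cassert>
#include <cstddef>

enum OPTYPE { NONE, SHIFT, MULTIPLY, ADDCONST, INCRMUL };
struct OP { OPTYPE type; int argument; };

// Fixed-argument operations in the order NextOp/NextArg visit them.
// SHIFT (F1) is left out: its argument is open-ended.
const OP kFixedOps[] = {
  {MULTIPLY, 2}, {MULTIPLY, 4}, {MULTIPLY, 8},   // F2-F4
  {ADDCONST, -1}, {ADDCONST, 1}, {ADDCONST, 2},
  {ADDCONST, 4}, {ADDCONST, 8},                  // F5-F9
  {INCRMUL, 2}, {INCRMUL, 4}, {INCRMUL, 8},      // F10-F12
};
const std::size_t kNumFixedOps = sizeof(kFixedOps) / sizeof(kFixedOps[0]);

// F5-F12 read the copy of the original value held in ecx.
bool NeedsEcx(const OP &op) {
  return op.type == ADDCONST || op.type == INCRMUL;
}

// Bounds-checked access to the table.
const OP &FixedOp(std::size_t i) {
  assert(i < kNumFixedOps);
  return kFixedOps[i];
}

// How many table entries a search may try in the given mode.
std::size_t CountUsable(bool has_extra_reg) {
  std::size_t count = 0;
  for (std::size_t i = 0; i < kNumFixedOps; ++i)
    if (has_extra_reg || !NeedsEcx(kFixedOps[i]))
      ++count;
  return count;
}
```

A search loop would try the shifts first and then iterate over the table, skipping entries for which `NeedsEcx` is true when no extra register is available.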

Now the main function. Note the extensive use of sentinels.
```cpp
#define MAX_DEPTH 10

Decomposition *Decompose(int num,
    bool has_extra_reg,
    int recursion_depth) {
  Decomposition *d = FindInHash(num, has_extra_reg);
  if (d) // We already have it!
    return d;

  // Limit level of recursion - this allows us to avoid loops
  // (undoing F6 turns n into n + 1, so the search could cycle).
  if (recursion_depth > MAX_DEPTH)
    return NULL;

  Decomposition temp_d;             // best decomposition found so far
  memset(&temp_d, 0, sizeof(temp_d));
  temp_d.number = num;
  temp_d.operation.type = NONE;     // sentinel: nothing found yet

  OP temp_op;
  temp_op.type = NONE;              // sentinel: before the first operation
  temp_op.argument = 0;

  for ( ; ; ) {
    if (!NextOp(&temp_op, has_extra_reg))
      break;
    for ( ; ; ) {
      if (!NextArg(&temp_op))
        break;
      int next_val = ComputeInverse(num, &temp_op);
      if (next_val == 0) {
        // No inverse. For SHIFT, larger shifts cannot work either.
        if (temp_op.type == SHIFT)
          break;
        continue;
      }
      if (next_val == 1) { // Done! One operation suffices.
        temp_d.next = NULL;
        temp_d.cost = 1;
        temp_d.operation = temp_op;
        break;
      }
      Decomposition *next = Decompose(
          next_val, has_extra_reg,
          recursion_depth + 1);
      if (!next)
        continue;
      if (temp_d.operation.type == NONE
          || next->cost + 1 < temp_d.cost) {
        temp_d.next = next;
        temp_d.cost = next->cost + 1;
        temp_d.operation = temp_op;
      }
    }
  }

  if (temp_d.operation.type == NONE)
    return NULL;

  d = new Decomposition;
  *d = temp_d;
  AddToHash(d, has_extra_reg);
  return d;
}
```
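Because Decompose caps recursion at MAX_DEPTH and caches whatever it finds, a number first reached deep in the recursion may be stored with a cost that is not the best possible. A breadth-first search over multiplier values is a simple, independent way to cross-check short results; this is a different technique from the post's, sketched under the assumption that intermediate multipliers stay below a caller-chosen cap. `BfsCost` and `limit` are hypothetical names.

```cpp
#include <cassert>
#include <queue>
#include <vector>

// Every instruction costs 1, so the first time BFS reaches `target` it has
// used the fewest instructions among sequences whose intermediate
// multipliers stay within [1, limit]. Returns the count (without the
// mov), or -1 if target is unreachable under the cap.
int BfsCost(int target, bool has_extra_reg, int limit) {
  assert(target >= 1 && limit >= target);
  std::vector<int> dist(limit + 1, -1);
  std::queue<int> q;
  dist[1] = 0;
  q.push(1);
  while (!q.empty()) {
    int m = q.front();
    q.pop();
    if (m == target) return dist[m];
    std::vector<long long> succ;
    for (int k = 1; k < 31 && (static_cast<long long>(m) << k) <= limit; ++k)
      succ.push_back(static_cast<long long>(m) << k);         // F1
    succ.push_back(3LL * m);                                  // F2
    succ.push_back(5LL * m);                                  // F3
    succ.push_back(9LL * m);                                  // F4
    if (has_extra_reg) {
      const int adds[] = {-1, 1, 2, 4, 8};                    // F5-F9
      for (int a : adds) succ.push_back(static_cast<long long>(m) + a);
      const int muls[] = {2, 4, 8};                           // F10-F12
      for (int a : muls) succ.push_back(static_cast<long long>(m) * a + 1);
    }
    for (long long s : succ) {
      if (s < 1 || s > limit || dist[s] != -1) continue;
      dist[s] = dist[m] + 1;
      q.push(static_cast<int>(s));
    }
  }
  return -1;
}
```

For example, `BfsCost(7, true, 100)` finds 2 instructions (shl eax, 3 then sub eax, ecx), to which the wrapper's +1 for the mov would be added, while `BfsCost(7, false, 100)` reports that 7 cannot be built from shifts and lea alone.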

We are practically done. We just need something that calls Decompose with and without the extra register and compares the results.
```cpp
Decomposition *Decompose(int num) {
  if (num <= 1)
    return NULL;
  Decomposition *with_reg = Decompose(num, true, 0);
  Decomposition *no_reg = Decompose(num, false, 0);
  if (!with_reg)
    return no_reg;
  if (!no_reg)
    return with_reg;
  // +1 is for mov operation to
  // copy to extra register
  return (with_reg->cost + 1 < no_reg->cost) ?
      with_reg : no_reg;
}
```

Alright! We're done. Let's write a trivial main to collect user input and print out the results.
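```cpp
// True if any step in the chain reads ecx (F5-F12), in which case
// the original value has to be copied there first.
bool NeedsExtraReg(Decomposition *dc) {
  for ( ; dc; dc = dc->next)
    if (dc->operation.type == ADDCONST ||
        dc->operation.type == INCRMUL)
      return true;
  return false;
}

int main(int argc, char **argv) {
  for (;;) {
    printf("Input number: ");
    char input[128];
    if (!fgets(input, sizeof(input), stdin))
      break;  // end of input
    char *nl = strchr(input, '\n');
    if (nl)
      *nl = '\0';
    if (strcmp(input, "exit") == 0)
      break;
    int num = atoi(input);
    Decomposition *dc = Decompose(num);
    if (!dc) {
      printf("Could not compute!\n");
      continue;
    }
    if (NeedsExtraReg(dc))
      printf("mov ecx, eax\n");
    ReversePrint(dc);
  }
  return 0;
}
```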

Now you can copy all this into an editor of your choice, and enjoy multiplication without multiplication :-).