/*
 3r_phasing-in.c
 (C) 2003 Yann GUIDON <whygee@f-cpu.org>

 Recursive Range Reduction
  example source code with phasing-in codes
 created Sun Aug 24 22:16:02 CEST 2003 by whygee@f-cpu.org
 version Sat Oct 25 03:32:59 CEST 2003 : cleanup for DDJ
 version Wed Dec 31 06:52:00 CET 2003 : implemented PI-codes
*/

#include <stdio.h>   /* for printf */
#include <stdlib.h>  /* for malloc */
#include <string.h>  /* for memset */

#ifndef OPT
#include "phasing-in.c"
#else
#include "pi_opt.c"
#endif

/*
 * encode_3R - encode list[] with Recursive Range Reduction (3R), writing the
 * code stream through pi_put_bits().
 *
 * nb_elements : number of values in list[]; must be a power of 2 and at
 *               least 2, otherwise an error is printed and the program exits.
 * list        : the values to encode (read only).
 *
 * First pass: build a binary tree of partial sums. Tree positions are
 * numbered 1 .. 2*nb_elements-1; odd positions are the leaves list[i >> 1],
 * even positions are internal nodes whose sums are kept in t[i >> 1].
 * Second pass: send the bit width of the root sum, the root sum without its
 * implicit MSB, then, for every left child, its value coded by pi_put_bits()
 * against the bound "parent + 1". Right children are never sent (the decoder
 * derives them as parent - left). Nothing below a node whose sum is 0 is sent.
 *
 * Relies on GCC nested functions (a GNU C extension).
 */
void encode_3R(int nb_elements, unsigned long *list) {
  int range = 0;
  unsigned long int *t, u, v, mask;

  /* First pass: return the sum of the subtree at position `index` (spanning
     `size` positions) and record every internal node's sum in t[]. */
  unsigned long int recursive_sum(int index, int size) {
    unsigned long int w, x;
    int diff = size >> 2;   /* distance from a node to its children */

    if (index & 1)
      return list[index >> 1]; /* a leaf was reached */
    else {
      /* further explore the tree */
      w = recursive_sum(index + diff, size >> 1);
      x = recursive_sum(index - diff, size >> 1);
      w += x;
      t[index >> 1] = w;
      return w;
    }
  }

  /* Second pass: send the node at `index` when it is a left child
     (direction == 1) using `range` bits and the bound `last_value`
     (parent + 1), then recurse into its children.
     `range` is always >= 1 when this function is entered. */
  void send_stream(int index, int size, int range,
        int direction, unsigned long int last_value) {
    unsigned long int w, mask;
    int diff = size >> 2;

    if (index & 1) {
      /* odd index: we reached a leaf */
      if (direction == 1) {
        w = list[index >> 1];
        pi_put_bits(w, range, last_value);
      }
    }
    else {
      w = t[index >> 1];

      if (direction == 1)
        pi_put_bits(w, range, last_value);

      /* range reduction: the children's sums are <= w, so they fit in the
         bit width of w. The mask has the same type as w. */
      mask = 1UL << (range - 1);
      while ((w & mask) == 0) {
        range--;
        if (range == 0) {
          /* w == 0: nothing below this node is sent */
          printf("[the remaining zeroes are skipped]\n");
          return;
        }
        mask >>= 1;
      }

      send_stream(index - diff, size >> 1, range, 1, w + 1);
      send_stream(index + diff, size >> 1, range, 0, w + 1);
    }
  }

  /**** the function's body : ****/

  /* With fewer than 2 elements the root is not an internal node: t[] would
     be read unset and the tree walks would recurse without bound. */
  if (nb_elements < 2) {
    printf("Error : %d elements, at least 2 are required\n", nb_elements);
    exit(-1);
  }
  if ((nb_elements - 1) & nb_elements) {
    printf("Error : %d is not a power of 2\n", nb_elements);
    exit(-1);
  }

  /* one slot per internal node (indices 1 .. nb_elements-1; t[0] unused) */
  t = malloc(nb_elements * sizeof(*t));
  if (t == NULL) {
    perror("malloc");
    exit(-1);
  }

  /* build the sum tree */
  recursive_sum(nb_elements, nb_elements << 1);

  /* size of the first number */
  u = v = t[nb_elements >> 1]; /* middle of t: the root, i.e. the total sum */
  while (u > 0) {
    u >>= 1;
    range++;
  }

  /* 5 bits can encode a number from 0 to 24: a wider root cannot be sent */
  if (range > 24) {
    printf("Error : the sum of the values needs %d bits, at most 24 can be encoded\n",
           range);
    free(t);
    exit(-1);
  }
  pi_put_bits(range, 5, 25);

  if (range > 0) {
    /* encode the first word without the implicit MSB */
    mask = 1UL << (range - 1);
    pi_put_bits(v & ~mask, range - 1, mask);

    /* explore the tree again */
    send_stream(     nb_elements  >> 1, nb_elements, range, 1, v + 1);
    send_stream((3 * nb_elements) >> 1, nb_elements, range, 0, v + 1);
  }

  free(t);
}

/*
 * decode_3R - read a 3R stream produced by encode_3R through pi_get_bits(),
 * store the nb_elements decoded values in list[] and print each of them.
 *
 * nb_elements : must match the encoder: a power of 2 and at least 2;
 *               otherwise an error is printed and the program exits.
 * list        : output array of nb_elements values; every entry is written,
 *               including when the encoded data is all zeros.
 *
 * Relies on GCC nested functions (a GNU C extension).
 */
void decode_3R(int nb_elements, unsigned long int *list) {
  int range;
  unsigned long int u, mask;

  /* Decode the subtree at position `index` (spanning `size` positions).
     A left child (direction == 1) is read from the stream with `range` bits
     and the bound `difference` (parent + 1). A right child (direction == 0)
     is given directly in `difference` (parent - left sibling).
     Returns the node's value. */
  unsigned long int branch(int index, int size,
      int direction, int range, unsigned long int difference){

    unsigned long int w, t;
    int diff = size >> 2;   /* distance from a node to its children */

    if (direction == 1)
      w = pi_get_bits(range, difference);
    else
      w = difference;

    if (index & 1) {
      /* reaching a leaf of the tree */
      printf("  [%2d] = %04lX\n", index >> 1, w);
      list[index >> 1] = w;
    }
    else if (w == 0) {
      /* The encoder sends nothing below a zero node: every descendant is 0.
         Walk the subtree without reading the stream so list[] is filled. */
      branch(index - diff, size >> 1, 0, 0, 0);
      branch(index + diff, size >> 1, 0, 0, 0);
    }
    else {
      /* range reduction: test range before building the mask so the shift
         amount can never be negative */
      while ((range > 0) && ((w & (1UL << (range - 1))) == 0))
        range--;

      t = branch(index - diff, size >> 1, 1, range, w + 1);
          branch(index + diff, size >> 1, 0, range, w - t);
    }
    return w;
  }

  /**** the function's body ****/

  /* same constraints as the encoder: fewer than 2 elements would make the
     tree walk recurse without bound */
  if (nb_elements < 2) {
    printf("Error : %d elements, at least 2 are required\n", nb_elements);
    exit(-1);
  }
  if ((nb_elements - 1) & nb_elements) {
    printf("Error : %d is not a power of 2\n", nb_elements);
    exit(-1);
  }

  range = pi_get_bits(5, 25);
  printf("\nrange = %d\n", range);

  if (range > 0) {
    printf("Initial number (without MSB): ");
    /* restore the MSB of the root's value */
    mask = 1UL << (range - 1);
    /* get the first word and restore the implicit MSB */
    u = pi_get_bits(range - 1, mask) | mask;
    printf("root = %lX\n", u);

    u -= branch(   nb_elements  >> 1, nb_elements, 1, range, u + 1);
         branch((3*nb_elements) >> 1, nb_elements, 0, range, u);
  }
  else {
    /* All-zero input: the encoder sent only the range. Fill list[] with
       zeros without reading anything else from the stream. */
    branch(   nb_elements  >> 1, nb_elements, 0, 0, 0);
    branch((3*nb_elements) >> 1, nb_elements, 0, 0, 0);
  }
}

