#!/usr/bin/env perl

# Copyright 1997-2015 Vincent Lefevre <vincent@vinc17.net>
# Work done at:
#   [1997-2000] LIP, ENS-Lyon, France
#   [2000-2005] LORIA / INRIA Lorraine, France
#   [2015]      LIP, ENS-Lyon, France

# Since one needs a license, let's give one.  But note that this program
# has been written mainly for research and testing purpose; the goal was
# simplicity, not efficiency.  The main interest is the algorithm itself,
# which is not patented.  It is described in my papers available from the
# following web page:
#
#   https://www.vinc17.net/research/mulbyconst/
#   http://www.vinc17.org/research/mulbyconst/

# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2, or (at your option)
# any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.

use 5.006;
use strict;

# Program name: the first word following "Id: " in the version-control
# keyword string below.  It prefixes the usage text and the error message
# emitted when the exchange with the external "bernstein" program fails.
my ($proc) = '$Id: patterns 81727 2015-08-16 02:39:49Z vinc17/zira $'
  =~ /Id: (\S+)/ or die;

# Usage text.  Every argument check in this script reports a failure with
# "die $Usage", so this text is what the user sees on a bad command line.
my $Usage = <<EOF;
$proc <mode> <algo> <args...>

Mode a:
  Arg1... output level.
  Arg2... odd binary number (or decimal with a 'd', e.g., d17 or 17d).
  Args... other numbers (optional), for algo g only.
  Out.... SAS-chain.

Mode b:
  Arg1... number of digits (> 1).
  Arg2... number of output cases for which there is a difference (optional).
  Out.... comparisons with Bernstein's algorithm.

Mode m:
  Arg1... number of digits (> 1).
  Arg2... number of threads (> 0).
  Out.... maximum and average number of operations.

Mode r:
  Arg1... number of digits (> 1).
  Arg2... number of tested constants (0 to test the rand function).
  Arg3... seed (for rand) -- optional.
  Out.... average number of operations on random values.

Mode w:
  Arg1... number of digits (> 1).
  Arg2... number of output worst cases (optional, default 32).
  Out.... worst cases.

Algo s: simple algorithm (as described in the thesis)
Algo g: global algorithm (common subpatterns)
Algo h: like algo g, but all equivalent Booth codes are tested
EOF

# A mode letter and an algorithm letter are mandatory; both are removed from
# @ARGV so that only the mode-specific arguments remain.
die $Usage unless @ARGV > 1;
my ($mode, $algo) = splice @ARGV, 0, 2;

# State shared with the subroutines below:
#   @sign  - operator text, indexed by a boolean (false: '+', true: '-');
#   $level - output level (only mode a changes it);
#   @sas   - lines of the operation chain recorded by recsas.
my @sign = qw(+ -);
my $level = 0;
my @sas;

# Cost routine associated with each algorithm letter; any other letter is a
# usage error.
my %algo_sub = ('s' => \&recsas, 'g' => \&gsas, 'h' => \&hsas);
my $subsas = $algo_sub{$algo} or die $Usage;

# Mode dispatch.  Each branch checks its own arguments and dies with the
# usage text when they are invalid.
if ($mode eq 'a')
  {
    # Mode a: compute the cost of the constant(s) given on the command line
    # and print it, followed by the contents of @sas when the output level
    # is nonzero.
    $level = shift;
    # Only algo g accepts more than one number.
    $algo eq 'g' || @ARGV == 1 or die $Usage;
    my @b;
    # $_ aliases the elements of @ARGV here, so the conversions below
    # rewrite @ARGV in place; the "Cost(...)" line printed at the end
    # therefore shows the recoded forms, not the original arguments.
    foreach (@ARGV)
      {
        # Decimal input (d17 or 17d): rebuild it as a binary string, most
        # significant bit first.
        if (/^(d\d+|\d+d)$/)
          {
            my ($d) = /(\d+)/ or die;
            $_ = '';
            while ($d)
              { $_ = ($d & 1 ? '1' : '0').$_;
                $d >>= 1; }
          }
        # Odd binary number: recode it into P/N/0 digits.
        if (/^[01]*1$/)
          {
            $_ = &booth($_);
            $level > 1 and print "Booth recoding: $_.\n";
          }
        # A code starting with N has its P and N digits swapped so that it
        # starts with P.
        /^N/ and tr/NP/PN/;
        # Accept only "P", or a code starting with P in which consecutive
        # nonzero digits are separated by at least one 0; trailing zeros
        # are dropped.  Anything else is a usage error.
        s/^(P|P0([NP]?0)*[NP])0*$/$1/ or die $Usage;
        push @b, $_;
      }
    my $q = &$subsas(@b);
    $level > 1 and print "\n";
    print "Cost(", join(',', @ARGV), ") = $q\n";
    $level and print @sas;
  }
elsif ($mode eq 'm' || $mode eq 'b')
  {
    # Modes m and b: run the selected algorithm on every odd constant of
    # exactly $n binary digits.  Mode m reports the maximum and the sum /
    # average of the costs; mode b additionally compares each cost with the
    # one reported by an external "bernstein" process.
    my $modem = $mode eq 'm';
    @ARGV == 1 || @ARGV == 2 or die $Usage;
    my $n = $ARGV[0];
    $n =~ /^\d+$/ && $n > 1 or die $Usage;
    # Bernstein statistics: maximum cost, sum of costs, constant reaching
    # the maximum.
    my ($bmax,$bsum,$bn) = (0,0);
    # Same statistics for this program, then $d (number of constants whose
    # two costs differ) and %d (histogram of the cost differences).
    my ($pmax,$psum,$pn,$d,%d);

    # modemb($iter): consume constants from the iterator $iter (a code
    # reference returning the next constant, or undef when exhausted) and
    # return (maximum cost, sum of costs, first constant reaching the
    # maximum).  It reads and updates the lexicals declared above.
    sub modemb ($) {
      my ($iter) = @_;
      my ($tmax,$tsum,$tn) = (0,0);
      while (my $ti = &$iter)
        {
          @sas = ();
          my $a = sprintf '%b', $ti;
          $a =~ s/^0*//;
          my $q = &$subsas(&booth($a));
          $tsum += $q;
          $q > $tmax and $tmax = $q, $tn = $ti;
          next if $modem;
          # Mode b only: send the constant to the bernstein process and
          # read back its answer, which must have the form "Cost(<n>) = <q>".
          print WR "$ti\n";
          <RD> =~ /^Cost\($ti\) = (\d+)$/
            or die "$proc: error with bernstein";
          $bsum += $1;
          $1 > $bmax and $bmax = $1, $bn = $ti;
          next if $q == $1;
          $d{$q-$1}++;
          # Print only the first $ARGV[1] differing cases.  When Arg2 is
          # omitted, $ARGV[1] is undef and compares as 0, so none is printed.
          next if ++$d > $ARGV[1];
          print "$ti (Bernstein: $1 - Patterns: $q)\n";
        }
      return ($tmax,$tsum,$tn);
    }

    if ($modem && @ARGV == 2)
      {
        # Mode m with a thread count: the constants are handed out by one
        # iterator used by all the threads.
        my $fn = $ARGV[1];
        $fn =~ /^\d+$/ && $fn > 0 or die $Usage;
        require threads;
        import threads;
        require threads::shared;
        import threads::shared;
        # Iterator over the odd numbers from 2^($n-1)+1 up to 2^$n-1; the
        # counter is a shared variable and is locked while it is advanced.
        my $iter = do {
          my $i :shared = (1<<($n-1))-1;
          sub { lock($i); ($i += 2) < 1<<$n ? $i : undef; };
        };
        my @thr;
        foreach my $t (0..$fn-1)
          {
            # The thread is created in list context because, according to
            # the threads documentation, the creation context decides the
            # context of the value returned by join(); three values are
            # expected below.
            ($thr[$t]) = threads->create(\&modemb, $iter);
          }
        # Merge the per-thread results: add the sums, keep the largest
        # maximum and, among equal maxima, the smallest constant.
        foreach my $t (0..$fn-1)
          {
            my ($tmax,$tsum,$tn) = $thr[$t]->join();
            $psum += $tsum;
            if ($tmax >= $pmax)
              {
                if ($tmax > $pmax)
                  {
                    $pmax = $tmax;
                    $pn = $tn;
                  }
                elsif ($tn < $pn)
                  {
                    $pn = $tn;
                  }
              }
          }
      }
    else
      {
        # Single-threaded run (mode b, or mode m without a thread count):
        # same sequence of constants, without sharing or locking.
        my $iter = do {
          my $i = (1<<($n-1))-1;
          sub { ($i += 2) < 1<<$n ? $i : undef; };
        };
        if (!$modem)
          {
            # Start the external program; modemb writes to WR and reads the
            # answers from RD.
            use IPC::Open2;
            open2(\*RD, \*WR, "bernstein 0");
            # Note: this program must make sure that it does not take too
            # much memory when caching information about the constants!
          }
        ($pmax,$psum,$pn) = modemb $iter;
        if (!$modem)
          {
            close WR;
            close RD;
            # Blank line after the list of differing cases, if any was
            # printed.
            print "\n" if $d && $ARGV[1] > 0;
            print "Bernstein:\n";
            print "qmax = $bmax ($bn)\n";
            print "qav = $bsum / 2^", $n-2, " = ", $bsum/(1<<($n-2)), "\n\n";
            print "Patterns:\n";
          }
      }
    # The loop above visits 2^($n-2) constants, hence the divisor.
    print "qmax = $pmax ($pn)\n";
    print "qav = $psum / 2^", $n-2, " = ", $psum/(1<<($n-2)), "\n";
    if (%d)
      {
        # Histogram of (cost of this program - Bernstein cost), sorted by
        # difference.
        print "\nDifferences:\n";
        foreach (sort { $a <=> $b } keys %d)
          { printf "%+d: %5d\n", $_, $d{$_} }
      }
  }
elsif ($mode eq 'w')
  {
    # Mode w: print the maximum cost over all odd $n-digit constants and
    # the first $wmax constants (in binary) that reach it.
    my ($i,$n,@w);

    ($n = $ARGV[0]) =~ /^\d+$/ && $n > 1 or die $Usage;
    my $wmax = $ARGV[1] || 32;
    my $qmax = 0;
    for ($i = (1<<($n-1))+1; $i < 1<<$n; $i += 2)
      {
        @sas = ();
        my $a = sprintf '%b', $i;
        $a =~ s/^0*//;
        my $q = &$subsas(&booth($a));
        # Same cost as the current maximum: record it while there is room.
        $q == $qmax && @w < $wmax and push @w, $a;
        $q > $qmax or next;
        # New maximum: restart the list with this constant.
        $qmax = $q;
        @w = ($a);
      }
    print "qmax = $qmax\n";
    foreach (@w)
      { print "$_\n" }
  }
elsif ($mode eq 'r')
  {
    # Mode r: statistics on $p random odd $n-digit constants, or, when $p
    # is 0, a single random constant printed in binary.
    my ($n,$p,$seed) = @ARGV;
    $n =~ /^\d+$/ && $n > 1 && $p =~ /^\d+$/ or die $Usage;
    # NOTE(review): when Arg3 is omitted, $seed is undef and srand is still
    # called with an explicit argument -- confirm that this gives the
    # intended seeding (as opposed to calling srand without argument).
    srand $seed;
    if ($p)
      {
        my $qmax = 0;
        my $t = 0;
        for (my $i = 0; $i < $p; $i++)
          { @sas = ();
            my $q = &$subsas(&booth(&getrand($n)));
            $q > $qmax and $qmax = $q;
            $t += $q; }
        $t /= $p;
        print <<EOF;
Number of digits:		$n
Number of tested constants:	$p
Largest number of operations:	$qmax
Average number of operations:	$t
EOF
      }
    else
      { print &getrand($n), "\n"; }
  }
else
  { die $Usage; }

# booth($bin): recode a binary string (most significant bit first) into a
# string of signed digits: P (+1), N (-1) and 0.  The input is consumed from
# its right end, one run "0 1...1" at a time, and the result is returned
# without leading zeros.
sub booth
  { my $todo = '00'.$_[0];  # two leading zeros, so the top run has a 0 before it
    my $code = '';
    while (my ($head,$zeros,$ones) = $todo =~ /^(.*?)(0*)0(1*)1$/)
      {
        if (!length $ones)
          { # Isolated 1: it is kept as a single P.
            $code = $zeros.'0P'.$code;
            $todo = $head;
            next; }
        # Run of several ones: its lowest bit becomes N and the others 0.
        my $tail = ('0' x length $ones).'N'.$code;
        if (length $zeros)
          { # At least two zeros precede the run: the nearest one becomes P.
            $todo = $head;
            $code = $zeros.'P'.$tail; }
        else
          { # Only one zero precedes the run: it is turned into a 1 that is
            # left in the part still to be processed.
            $todo = $head.'1';
            $code = $tail; }
      }
    $code =~ s/^0*//;
    return $code; }

# recsas($s): recursive cost routine selected by algo s.  $s is a string of
# signed digits (P, N and 0); digit positions are counted from its right end.
# The routine picks the distance between two nonzero digits that yields the
# largest number of non-overlapping occurrences, splits $s into the pattern
# part (u) and the remaining digits (t), and recurses on both.  Each
# operation is appended to the file-level @sas as a text line "uK = ...".
# Returns the index K of the @sas line holding the result, or 0 when $s has
# a single digit (no line is added in that case).
sub recsas
  { my $s = $_[0];
    $s =~ s/0*$//;
    my $l = length($s) - 1;
    # Single digit: nothing to compute.
    $l or return 0;
    my ($i,$j,$c,$d,$w,$t,$u,$f,$g,@p,@n,%w);
    # Positions (from the right) of the P digits and of the N digits.
    for($i=0;$i<=$l;$i++)
      { my $d = substr $s, $l-$i, 1;
        $d eq 'P' and push @p, $i;
        $d eq 'N' and push @n, $i; }
    print "\nrecsas: $s.\n  P: @p\n  N: @n\n" if ($level > 2);
    # %w counts the digit pairs per distance: a positive key for two digits
    # of the same sign, a negative key (minus the distance) for a P/N pair.
    # Pairs may share digits, so each count is only an upper bound on the
    # number of usable occurrences.
    foreach $i (1..$#p) { foreach $j (0..$i-1) { $w{$p[$i]-$p[$j]}++ } }
    foreach $i (1..$#n) { foreach $j (0..$i-1) { $w{$n[$i]-$n[$j]}++ } }
    foreach $i (0..$#p) { foreach $j (0..$#n) { $w{-abs($p[$i]-$n[$j])}++ } }
    # Ascending order, so that pop yields the largest count first.
    my @d = sort { $w{$a} <=> $w{$b} || $a <=> $b } keys %w;
    if ($level > 3)
      { my $i = $#d;
        do
          { printf "%3d, w = %d\n", $d[$i], $w{$d[$i]} }
        while (--$i >= 0); }
    # Try the candidates while their upper bound exceeds the best count
    # found so far ($w).  The occurrences are counted on a copy of $s held
    # in $_: every substitution replaces the left digit of a pair by 0 and
    # the right digit by a lower-case marker, so a digit is used only once.
    while ($w{$d = pop @d} > $w)
      { $_ = $s;
        my $x = 0;
        my $z = abs($d) - 1;
        if ($d > 0)
          {
            # Two digits of the same sign.
            while (s/P(.{$z})P/0${1}p/) { $x++ }
            while (s/N(.{$z})N/0${1}n/) { $x++ }
          }
        else
          {
            # Two digits of opposite signs; the marker is the lower-case
            # form of the left digit.
            my ($y,@y);
            while (($y) = /(P.{$z}N|N.{$z}P)/)
              { @y = $y =~ /^(.)(.{$z})/ or die;
                s/$y/0$y[1]\L$y[0]\E/;
                $x++; }
          }
        # Keep the best candidate: count, distance and marked string.
        $x > $w and $w = $x, $c = $d, $t = $_; }
    # Split the marked string: $t keeps the digits that were not used by the
    # pattern, $u keeps only the markers, turned back into digits.
    $u = $t;
    $t =~ tr/pn/00/;
    $u =~ tr/pnPN/PN00/;
    # $f: the pattern combines opposite signs (operator '-'); $c becomes the
    # shift count.
    $f = $c < 0 and $c = -$c;
    # $g: $u started with N and has been negated, so it will be subtracted.
    $g = $u =~ /^0*N/ and $u =~ tr/PN/NP/;
    print <<EOF if ($level > 2);
    t = $t
    u = $u
  --> t $sign[$g] (u << $c $sign[$f] u)
EOF
    $u =~ s/^0*//;
    # Compute u first, then the pattern value "u << c +/- u".
    $i = &recsas($u);
    push @sas, "u".(@sas+1)." = u$i << $c $sign[$f] u$i\n";
    $i = @sas;
    $t =~ s/^0*//;
    # No remaining digit: the pattern value is the result.
    $t or return $i;
    # If $t starts with N, it is negated and the last operation becomes
    # "u - t" ($g == 2).
    substr($t,0,1) eq 'P' or $g = 2, $t =~ tr/PN/NP/;
    $j = &recsas($t);
    # The trailing zeros of $u and $t become explicit shifts in the last
    # operation.
    $u =~ /(0*)$/;
    $u = "u$i";
    length($1) and $u .= ' << '.length($1);
    $t =~ /(0*)$/;
    $t = "u$j";
    length($1) and $t .= ' << '.length($1);
    push @sas, "u".(@sas+1).' = '.("$t + $u","$t - $u","$u - $t")[$g]."\n";
    return scalar @sas; }

# gsas(@codes): cost routine selected by algo g (and called by hsas).  It
# takes one or more strings of signed digits (P, N and 0) and repeatedly
# extracts a pattern made of two nonzero digits at a given distance and with
# a given relative sign, the two digits being taken either from the same
# string or from two different strings of the working list.  Returns the
# accumulated cost as a number; this routine does not touch @sas.
sub gsas
  { my $cost = 0;
    # Working list: one record per string, as built by makeind ('s' is the
    # string, 'd' the positions of its nonzero digits, from the right).
    my @x = map &makeind($_), @_;
    # Relative sign of a pair of digits: '+' when equal, '-' otherwise.
    my %sign = ('PP' => '+', 'PN' => '-', 'NP' => '-', 'NN' => '+');
    while (@x)
      {
        $level > 2
          and print "\ngsas: ", join('+', map { $$_{'s'} } @x),
                    "\n      ", join(' ', map { &ind($$_{'s'}) } @x),
                    "\n";
        # For each key "i1,i2,distance sign": %w counts the digit pairs and
        # %v lists the position of the digit taken from string i2.  Within
        # a single string (i1 == i2), only pairs with $j1 > $j2 are counted.
        my (%v,%w);
        foreach (0 .. $#x)
          {
            my ($s1,$d1) = @{$x[$_]}{'s','d'};
            my $key1 = $_;
            foreach ($_ .. $#x)
              {
                my ($s2,$d2) = @{$x[$_]}{'s','d'};
                my $same = $_ == $key1;
                my $key2 = "$key1,$_,";
                my $j1;
                foreach $j1 (@$d1)
                  {
                    my $b1 = substr($s1, -($j1+1), 1);
                    my $j2;
                    foreach $j2 (@$d2)
                     {
                       $same && $j1 <= $j2 and next;
                       my $b2 = substr($s2, -($j2+1), 1);
                       my $key = $key2.($j1-$j2).$sign{$b1.$b2};
                       push @{$v{$key}}, $j2;  # --> already sorted
                       $w{$key}++;
                     }
                  }
              }
          }
        # Ascending order of the counts (ties in reverse string order), so
        # that pop yields the largest count first.
        my @d = sort { $w{$a} <=> $w{$b} || -($a cmp $b) } keys %w;
        if ($level > 3)
          { my $i = $#d;
            print "Bounds:\n";
            do
              {
                my $key = $d[$i];
                my @a = $key =~ /^(\d+),(\d+),(-?\d+)([-+])$/ or die;
                printf "%3d, %3d, %3d, %s, w = $w{$key} (@{$v{$key}})\n", @a;
              }
            while (--$i >= 0);
            print "Tries:\n"; }
        # Choose the key $c with the largest usable count $w.  The counts in
        # %w are upper bounds, hence the loop stops as soon as the next
        # bound does not exceed the best count found so far.
        my $c;
        my $w = 0;
        while ($w{my $k = pop @d} > $w)
          { my ($i1,$i2,$d) = $k =~ /^(\d+),(\d+),(-?\d+)/ or die;
            # Pattern across two different strings: taken at once, with its
            # full count.
            $i1 != $i2 and ($c,$w) = ($k,$w{$k}), last;
            # Pattern inside one string: drop every occurrence that shares a
            # digit (position $i or $i+$d) with an occurrence already kept.
            my $v = $v{$k};
            my $j = 0;
            my %h;
            while ($j <= $#$v)
              {
                my $i = $v->[$j];
                if ($h{$i} || $h{$i+$d})
                  { splice @$v, $j, 1; }
                else
                  { $h{$i} = 1; $h{$i+$d} = 1; $j++; }
              }
            my $x = @$v;
            $level > 3
              and printf "%3d, %3d, %3d  -> w = $x (@$v)\n", $i1, $i2, $d;
            $x > $w and ($c,$w) = ($k,$x); }
        # Stop when no pattern occurs at least twice.
        $w > 1 or last;
        my ($i1,$i2,$d,$o) = $c =~ /^(\d+),(\d+),(-?\d+)([-+])$/ or die;
        my $v = $v{$c};
        if ($level > 2)
          {
            print "Chosen:\n" if ($level > 3);
            printf "%3d, %3d, %3d, $o, w = $w (@$v)\n", $i1, $i2, $d;
          }
        # Build the string $t of the pattern occurrences: the digits of
        # string i2 at the positions in @$v, zeros elsewhere; then remove
        # the leading and trailing zeros and make the first digit a P.
        my $s = $x[$i2]->{'s'};
        my $t = '0' x length($s);
        foreach (@$v)
          { substr($t, -($_+1), 1) = substr($s, -($_+1), 1); }
        $t =~ s/^0*([PN])(.*?)0*$/$1$2/ or die;
        $1 eq 'P' or $t =~ tr/PN/NP/;
        print "    t  = $t\n" if ($level > 2);
        # Remove the digits used by the pattern: positions $_+$d in string
        # i1 and positions $_ in string i2 (clear updates the records in
        # place and returns the trimmed strings).
        my $u1 = &clear($x[$i1], $v, $d);
        print "    u1 = $u1\n" if ($level > 2);
        my $u2 = &clear($x[$i2], $v, 0);
        print "    u2 = $u2\n" if ($level > 2);
        push @x, &makeind($t);
        # $i2 >= $i1, so entry i2 is replaced (or removed) first: this
        # cannot shift entry i1.  A remainder of at most one digit is
        # removed from the list; otherwise its record is rebuilt.
        $i2 >= $i1 or die;
        splice @x, $i2, 1, length($u2) > 1 ? &makeind($x[$i2]->{'s'}) : ();
        splice @x, $i1, 1, length($u1) > 1 ? &makeind($x[$i1]->{'s'}) : ()
          unless ($i1 == $i2);
        # One more operation for each non-empty remainder.
        $cost++ if (length($u2) >= 1);
        $cost++ if (length($u1) >= 1);
        print "\nCurrent cost: $cost\n" if ($level > 2);
      }
    # Each string still in the list adds its number of nonzero digits
    # minus one.
    foreach (@x)
      { $cost += $#{$_->{'d'}}; }
    return $cost; }

# makeind($s): build the record used by gsas for the digit string $s.
# Trailing zeros are removed first.  Returns a hash reference with
#   's' => the trimmed string,
#   'd' => the positions (counted from the right, starting at 0) of the
#          digits that are not '0', in increasing order.
sub makeind
  { (my $digits = $_[0]) =~ s/0*$//;
    my @nonzero = grep { substr($digits, -($_+1), 1) ne '0' }
                       0 .. length($digits) - 1;
    print "\nmakeind: $digits (@nonzero).\n" if $level > 3;
    return { 's' => $digits, 'd' => \@nonzero }; }

# ind($s): ruler for the trace output of gsas.  Returns a string as long as
# $s whose characters are the last decimal digit of each position, the
# positions being counted from the right (e.g. "3210" for four characters).
# The result is undefined when $s is empty.
sub ind
  { my $width = length $_[0];
    my $ruler;
    $ruler .= $_ % 10 for reverse 0 .. $width - 1;
    return $ruler; }

# clear($href, $v, $d): in the string $href->{'s'}, replace by '0' the digit
# at position $_+$d (counted from the right) for each $_ in @$v, then remove
# the leading and trailing zeros.  The string is modified in place, and its
# new value is also returned.
sub clear
  { my ($entry,$positions,$offset) = @_;
    substr($entry->{'s'}, -($_+$offset+1), 1, '0') foreach @$positions;
    $entry->{'s'} =~ s/^0*(.*?)0*$/$1/;
    return $1; }

# getrand($n): return a random odd binary number of exactly $n digits, as a
# string: the first and last digits are 1, and each of the $n-2 digits in
# between comes from one call to rand (so that a given seed always yields
# the same sequence of constants).
# For $n == 1 the result is '1'.  Dies when $n is undefined or less than 1;
# the previous loop, "while ($n--)" with a negative start value, never
# terminated for $n < 2.
sub getrand
  { my $digits = $_[0];
    defined $digits && $digits >= 1
      or die "getrand: the number of digits must be at least 1\n";
    $digits == 1 and return '1';
    my $r = '1';
    foreach (1 .. $digits - 2)
      { $r .= int rand 2 }
    return $r.'1'; }

# hsas(@codes): cost routine selected by algo h.  Starting from the given
# signed-digit strings, it generates every string reachable by rewriting
# "P0N" into "0PP" or "N0P" into "0NN" (both sides denote the same value),
# runs gsas on each of them with the output level forced to 0, and returns
# the smallest cost found.  The initial bound is the length of the first
# argument.  Each candidate and its cost are printed when the output level
# exceeds 1, and the best candidate when the level is nonzero.
#
# The work list is held in lexical variables; the loop no longer assigns to
# the package variable $b, which belongs to sort.
sub hsas
  { my @todo = @_;
    my %seen;
    $seen{$_} = 1 foreach @todo;
    my $oldlevel = $level;
    my $qopt = length $_[0];
    my $bopt;
    # Rewriting rules, applied in this order at every matching position.
    my @rules = (['P0N','0PP'], ['N0P','0NN']);
    while (my $code = pop @todo)
      { foreach my $rule (@rules)
          { my ($from,$to) = @$rule;
            my $i = -1;
            while (($i = index $code, $from, $i + 1) >= 0)
              { my $c = $code;
                substr($c, $i, 3) = $to;
                # Queue each new string only once.
                $seen{$c}++ or push @todo, $c; } }
        @sas = ();
        # Silence gsas while it evaluates this candidate.
        $level = 0;
        my $q = &gsas($code);
        $level = $oldlevel;
        $level > 1 and print "$code: cost = $q\n";
        $q < $qopt and $qopt = $q, $bopt = $code; }
    $level and print "Optimal cost for $bopt.\n";
    return $qopt; }