Upgrade GMP from 4.3.2 to 5.0.2 on the vendor branch
[dragonfly.git] / contrib / gmp / mpn / generic / mul_fft.c
1 /* Schoenhage's fast multiplication modulo 2^N+1.
2
3    Contributed by Paul Zimmermann.
4
5    THE FUNCTIONS IN THIS FILE ARE INTERNAL WITH MUTABLE INTERFACES.  IT IS ONLY
6    SAFE TO REACH THEM THROUGH DOCUMENTED INTERFACES.  IN FACT, IT IS ALMOST
7    GUARANTEED THAT THEY WILL CHANGE OR DISAPPEAR IN A FUTURE GNU MP RELEASE.
8
9 Copyright 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008,
10 2009, 2010 Free Software Foundation, Inc.
11
12 This file is part of the GNU MP Library.
13
14 The GNU MP Library is free software; you can redistribute it and/or modify
15 it under the terms of the GNU Lesser General Public License as published by
16 the Free Software Foundation; either version 3 of the License, or (at your
17 option) any later version.
18
19 The GNU MP Library is distributed in the hope that it will be useful, but
20 WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
21 or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
22 License for more details.
23
24 You should have received a copy of the GNU Lesser General Public License
25 along with the GNU MP Library.  If not, see http://www.gnu.org/licenses/.  */
26
27
28 /* References:
29
30    Schnelle Multiplikation grosser Zahlen, by Arnold Schoenhage and Volker
31    Strassen, Computing 7, p. 281-292, 1971.
32
33    Asymptotically fast algorithms for the numerical multiplication and division
34    of polynomials with complex coefficients, by Arnold Schoenhage, Computer
35    Algebra, EUROCAM'82, LNCS 144, p. 3-15, 1982.
36
37    Tapes versus Pointers, a study in implementing fast algorithms, by Arnold
38    Schoenhage, Bulletin of the EATCS, 30, p. 23-32, 1986.
39
40    TODO:
41
42    Implement some of the tricks published at ISSAC'2007 by Gaudry, Kruppa, and
43    Zimmermann.
44
45    It might be possible to avoid a small number of MPN_COPYs by using a
46    rotating temporary or two.
47
48    Cleanup and simplify the code!
49 */
50
51 #ifdef TRACE
52 #undef TRACE
53 #define TRACE(x) x
54 #include <stdio.h>
55 #else
56 #define TRACE(x)
57 #endif
58
59 #include "gmp.h"
60 #include "gmp-impl.h"
61
62 #ifdef WANT_ADDSUB
63 #include "generic/add_n_sub_n.c"
64 #define HAVE_NATIVE_mpn_add_n_sub_n 1
65 #endif
66
67 static mp_limb_t mpn_mul_fft_internal
68 __GMP_PROTO ((mp_ptr, mp_size_t, int, mp_ptr *, mp_ptr *,
69               mp_ptr, mp_ptr, mp_size_t, mp_size_t, mp_size_t, int **, mp_ptr, int));
70 static void mpn_mul_fft_decompose
71 __GMP_PROTO ((mp_ptr, mp_ptr *, int, int, mp_srcptr, mp_size_t, int, int, mp_ptr));
72
73
74 /* Find the best k to use for a mod 2^(m*GMP_NUMB_BITS)+1 FFT for m >= n.
75    We have sqr=0 if for a multiply, sqr=1 for a square.
76    There are three generations of this code; we keep the old ones as long as
77    some gmp-mparam.h is not updated.  */
78
79
80 /*****************************************************************************/
81
82 #if TUNE_PROGRAM_BUILD || (defined (MUL_FFT_TABLE3) && defined (SQR_FFT_TABLE3))
83
84 #ifndef FFT_TABLE3_SIZE         /* When tuning, this is define in gmp-impl.h */
85 #if defined (MUL_FFT_TABLE3_SIZE) && defined (SQR_FFT_TABLE3_SIZE)
86 #if MUL_FFT_TABLE3_SIZE > SQR_FFT_TABLE3_SIZE
87 #define FFT_TABLE3_SIZE MUL_FFT_TABLE3_SIZE
88 #else
89 #define FFT_TABLE3_SIZE SQR_FFT_TABLE3_SIZE
90 #endif
91 #endif
92 #endif
93
94 #ifndef FFT_TABLE3_SIZE
95 #define FFT_TABLE3_SIZE 200
96 #endif
97
98 FFT_TABLE_ATTRS struct fft_table_nk mpn_fft_table3[2][FFT_TABLE3_SIZE] =
99 {
100   MUL_FFT_TABLE3,
101   SQR_FFT_TABLE3
102 };
103
104 int
105 mpn_fft_best_k (mp_size_t n, int sqr)
106 {
107   FFT_TABLE_ATTRS struct fft_table_nk *fft_tab, *tab;
108   mp_size_t tab_n, thres;
109   int last_k;
110
111   fft_tab = mpn_fft_table3[sqr];
112   last_k = fft_tab->k;
113   for (tab = fft_tab + 1; ; tab++)
114     {
115       tab_n = tab->n;
116       thres = tab_n << last_k;
117       if (n <= thres)
118         break;
119       last_k = tab->k;
120     }
121   return last_k;
122 }
123
124 #define MPN_FFT_BEST_READY 1
125 #endif
126
127 /*****************************************************************************/
128
129 #if ! defined (MPN_FFT_BEST_READY)
130 FFT_TABLE_ATTRS mp_size_t mpn_fft_table[2][MPN_FFT_TABLE_SIZE] =
131 {
132   MUL_FFT_TABLE,
133   SQR_FFT_TABLE
134 };
135
136 int
137 mpn_fft_best_k (mp_size_t n, int sqr)
138 {
139   int i;
140
141   for (i = 0; mpn_fft_table[sqr][i] != 0; i++)
142     if (n < mpn_fft_table[sqr][i])
143       return i + FFT_FIRST_K;
144
145   /* treat 4*last as one further entry */
146   if (i == 0 || n < 4 * mpn_fft_table[sqr][i - 1])
147     return i + FFT_FIRST_K;
148   else
149     return i + FFT_FIRST_K + 1;
150 }
151 #endif
152
153 /*****************************************************************************/
154
155
156 /* Returns smallest possible number of limbs >= pl for a fft of size 2^k,
157    i.e. smallest multiple of 2^k >= pl.
158
159    Don't declare static: needed by tuneup.
160 */
161
162 mp_size_t
163 mpn_fft_next_size (mp_size_t pl, int k)
164 {
165   pl = 1 + ((pl - 1) >> k); /* ceil (pl/2^k) */
166   return pl << k;
167 }
168
169
170 /* Initialize l[i][j] with bitrev(j) */
171 static void
172 mpn_fft_initl (int **l, int k)
173 {
174   int i, j, K;
175   int *li;
176
177   l[0][0] = 0;
178   for (i = 1, K = 1; i <= k; i++, K *= 2)
179     {
180       li = l[i];
181       for (j = 0; j < K; j++)
182         {
183           li[j] = 2 * l[i - 1][j];
184           li[K + j] = 1 + li[j];
185         }
186     }
187 }
188
189
190 /* r <- a*2^d mod 2^(n*GMP_NUMB_BITS)+1 with a = {a, n+1}
191    Assumes a is semi-normalized, i.e. a[n] <= 1.
192    r and a must have n+1 limbs, and not overlap.
193 */
194 static void
195 mpn_fft_mul_2exp_modF (mp_ptr r, mp_srcptr a, unsigned int d, mp_size_t n)
196 {
197   int sh;
198   mp_limb_t cc, rd;
199
200   sh = d % GMP_NUMB_BITS;
201   d /= GMP_NUMB_BITS;
202
203   if (d >= n)                   /* negate */
204     {
205       /* r[0..d-1]  <-- lshift(a[n-d]..a[n-1], sh)
206          r[d..n-1]  <-- -lshift(a[0]..a[n-d-1],  sh) */
207
208       d -= n;
209       if (sh != 0)
210         {
211           /* no out shift below since a[n] <= 1 */
212           mpn_lshift (r, a + n - d, d + 1, sh);
213           rd = r[d];
214           cc = mpn_lshiftc (r + d, a, n - d, sh);
215         }
216       else
217         {
218           MPN_COPY (r, a + n - d, d);
219           rd = a[n];
220           mpn_com (r + d, a, n - d);
221           cc = 0;
222         }
223
224       /* add cc to r[0], and add rd to r[d] */
225
226       /* now add 1 in r[d], subtract 1 in r[n], i.e. add 1 in r[0] */
227
228       r[n] = 0;
229       /* cc < 2^sh <= 2^(GMP_NUMB_BITS-1) thus no overflow here */
230       cc++;
231       mpn_incr_u (r, cc);
232
233       rd++;
234       /* rd might overflow when sh=GMP_NUMB_BITS-1 */
235       cc = (rd == 0) ? 1 : rd;
236       r = r + d + (rd == 0);
237       mpn_incr_u (r, cc);
238     }
239   else
240     {
241       /* r[0..d-1]  <-- -lshift(a[n-d]..a[n-1], sh)
242          r[d..n-1]  <-- lshift(a[0]..a[n-d-1],  sh)  */
243       if (sh != 0)
244         {
245           /* no out bits below since a[n] <= 1 */
246           mpn_lshiftc (r, a + n - d, d + 1, sh);
247           rd = ~r[d];
248           /* {r, d+1} = {a+n-d, d+1} << sh */
249           cc = mpn_lshift (r + d, a, n - d, sh); /* {r+d, n-d} = {a, n-d}<<sh */
250         }
251       else
252         {
253           /* r[d] is not used below, but we save a test for d=0 */
254           mpn_com (r, a + n - d, d + 1);
255           rd = a[n];
256           MPN_COPY (r + d, a, n - d);
257           cc = 0;
258         }
259
260       /* now complement {r, d}, subtract cc from r[0], subtract rd from r[d] */
261
262       /* if d=0 we just have r[0]=a[n] << sh */
263       if (d != 0)
264         {
265           /* now add 1 in r[0], subtract 1 in r[d] */
266           if (cc-- == 0) /* then add 1 to r[0] */
267             cc = mpn_add_1 (r, r, n, CNST_LIMB(1));
268           cc = mpn_sub_1 (r, r, d, cc) + 1;
269           /* add 1 to cc instead of rd since rd might overflow */
270         }
271
272       /* now subtract cc and rd from r[d..n] */
273
274       r[n] = -mpn_sub_1 (r + d, r + d, n - d, cc);
275       r[n] -= mpn_sub_1 (r + d, r + d, n - d, rd);
276       if (r[n] & GMP_LIMB_HIGHBIT)
277         r[n] = mpn_add_1 (r, r, n, CNST_LIMB(1));
278     }
279 }
280
281
282 /* r <- a+b mod 2^(n*GMP_NUMB_BITS)+1.
283    Assumes a and b are semi-normalized.
284 */
285 static inline void
286 mpn_fft_add_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, int n)
287 {
288   mp_limb_t c, x;
289
290   c = a[n] + b[n] + mpn_add_n (r, a, b, n);
291   /* 0 <= c <= 3 */
292
293 #if 1
294   /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
295      result is slower code, of course.  But the following outsmarts GCC.  */
296   x = (c - 1) & -(c != 0);
297   r[n] = c - x;
298   MPN_DECR_U (r, n + 1, x);
299 #endif
300 #if 0
301   if (c > 1)
302     {
303       r[n] = 1;                       /* r[n] - c = 1 */
304       MPN_DECR_U (r, n + 1, c - 1);
305     }
306   else
307     {
308       r[n] = c;
309     }
310 #endif
311 }
312
313 /* r <- a-b mod 2^(n*GMP_NUMB_BITS)+1.
314    Assumes a and b are semi-normalized.
315 */
316 static inline void
317 mpn_fft_sub_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, int n)
318 {
319   mp_limb_t c, x;
320
321   c = a[n] - b[n] - mpn_sub_n (r, a, b, n);
322   /* -2 <= c <= 1 */
323
324 #if 1
325   /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
326      result is slower code, of course.  But the following outsmarts GCC.  */
327   x = (-c) & -((c & GMP_LIMB_HIGHBIT) != 0);
328   r[n] = x + c;
329   MPN_INCR_U (r, n + 1, x);
330 #endif
331 #if 0
332   if ((c & GMP_LIMB_HIGHBIT) != 0)
333     {
334       r[n] = 0;
335       MPN_INCR_U (r, n + 1, -c);
336     }
337   else
338     {
339       r[n] = c;
340     }
341 #endif
342 }
343
344 /* input: A[0] ... A[inc*(K-1)] are residues mod 2^N+1 where
345           N=n*GMP_NUMB_BITS, and 2^omega is a primitive root mod 2^N+1
346    output: A[inc*l[k][i]] <- \sum (2^omega)^(ij) A[inc*j] mod 2^N+1 */
347
348 static void
349 mpn_fft_fft (mp_ptr *Ap, mp_size_t K, int **ll,
350              mp_size_t omega, mp_size_t n, mp_size_t inc, mp_ptr tp)
351 {
352   if (K == 2)
353     {
354       mp_limb_t cy;
355 #if HAVE_NATIVE_mpn_add_n_sub_n
356       cy = mpn_add_n_sub_n (Ap[0], Ap[inc], Ap[0], Ap[inc], n + 1) & 1;
357 #else
358       MPN_COPY (tp, Ap[0], n + 1);
359       mpn_add_n (Ap[0], Ap[0], Ap[inc], n + 1);
360       cy = mpn_sub_n (Ap[inc], tp, Ap[inc], n + 1);
361 #endif
362       if (Ap[0][n] > 1) /* can be 2 or 3 */
363         Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
364       if (cy) /* Ap[inc][n] can be -1 or -2 */
365         Ap[inc][n] = mpn_add_1 (Ap[inc], Ap[inc], n, ~Ap[inc][n] + 1);
366     }
367   else
368     {
369       int j;
370       int *lk = *ll;
371
372       mpn_fft_fft (Ap,     K >> 1, ll-1, 2 * omega, n, inc * 2, tp);
373       mpn_fft_fft (Ap+inc, K >> 1, ll-1, 2 * omega, n, inc * 2, tp);
374       /* A[2*j*inc]   <- A[2*j*inc] + omega^l[k][2*j*inc] A[(2j+1)inc]
375          A[(2j+1)inc] <- A[2*j*inc] + omega^l[k][(2j+1)inc] A[(2j+1)inc] */
376       for (j = 0; j < (K >> 1); j++, lk += 2, Ap += 2 * inc)
377         {
378           /* Ap[inc] <- Ap[0] + Ap[inc] * 2^(lk[1] * omega)
379              Ap[0]   <- Ap[0] + Ap[inc] * 2^(lk[0] * omega) */
380           mpn_fft_mul_2exp_modF (tp, Ap[inc], lk[0] * omega, n);
381           mpn_fft_sub_modF (Ap[inc], Ap[0], tp, n);
382           mpn_fft_add_modF (Ap[0],   Ap[0], tp, n);
383         }
384     }
385 }
386
387 /* input: A[0] ... A[inc*(K-1)] are residues mod 2^N+1 where
388           N=n*GMP_NUMB_BITS, and 2^omega is a primitive root mod 2^N+1
389    output: A[inc*l[k][i]] <- \sum (2^omega)^(ij) A[inc*j] mod 2^N+1
390    tp must have space for 2*(n+1) limbs.
391 */
392
393
394 /* Given ap[0..n] with ap[n]<=1, reduce it modulo 2^(n*GMP_NUMB_BITS)+1,
395    by subtracting that modulus if necessary.
396
397    If ap[0..n] is exactly 2^(n*GMP_NUMB_BITS) then mpn_sub_1 produces a
398    borrow and the limbs must be zeroed out again.  This will occur very
399    infrequently.  */
400
401 static inline void
402 mpn_fft_normalize (mp_ptr ap, mp_size_t n)
403 {
404   if (ap[n] != 0)
405     {
406       MPN_DECR_U (ap, n + 1, CNST_LIMB(1));
407       if (ap[n] == 0)
408         {
409           /* This happens with very low probability; we have yet to trigger it,
410              and thereby make sure this code is correct.  */
411           MPN_ZERO (ap, n);
412           ap[n] = 1;
413         }
414       else
415         ap[n] = 0;
416     }
417 }
418
419 /* a[i] <- a[i]*b[i] mod 2^(n*GMP_NUMB_BITS)+1 for 0 <= i < K */
420 static void
421 mpn_fft_mul_modF_K (mp_ptr *ap, mp_ptr *bp, mp_size_t n, int K)
422 {
423   int i;
424   int sqr = (ap == bp);
425   TMP_DECL;
426
427   TMP_MARK;
428
429   if (n >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
430     {
431       int k, K2, nprime2, Nprime2, M2, maxLK, l, Mp2;
432       int **fft_l;
433       mp_ptr *Ap, *Bp, A, B, T;
434
435       k = mpn_fft_best_k (n, sqr);
436       K2 = 1 << k;
437       ASSERT_ALWAYS((n & (K2 - 1)) == 0);
438       maxLK = (K2 > GMP_NUMB_BITS) ? K2 : GMP_NUMB_BITS;
439       M2 = n * GMP_NUMB_BITS >> k;
440       l = n >> k;
441       Nprime2 = ((2 * M2 + k + 2 + maxLK) / maxLK) * maxLK;
442       /* Nprime2 = ceil((2*M2+k+3)/maxLK)*maxLK*/
443       nprime2 = Nprime2 / GMP_NUMB_BITS;
444
445       /* we should ensure that nprime2 is a multiple of the next K */
446       if (nprime2 >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
447         {
448           unsigned long K3;
449           for (;;)
450             {
451               K3 = 1L << mpn_fft_best_k (nprime2, sqr);
452               if ((nprime2 & (K3 - 1)) == 0)
453                 break;
454               nprime2 = (nprime2 + K3 - 1) & -K3;
455               Nprime2 = nprime2 * GMP_LIMB_BITS;
456               /* warning: since nprime2 changed, K3 may change too! */
457             }
458         }
459       ASSERT_ALWAYS(nprime2 < n); /* otherwise we'll loop */
460
461       Mp2 = Nprime2 >> k;
462
463       Ap = TMP_ALLOC_MP_PTRS (K2);
464       Bp = TMP_ALLOC_MP_PTRS (K2);
465       A = TMP_ALLOC_LIMBS (2 * (nprime2 + 1) << k);
466       T = TMP_ALLOC_LIMBS (2 * (nprime2 + 1));
467       B = A + ((nprime2 + 1) << k);
468       fft_l = TMP_ALLOC_TYPE (k + 1, int *);
469       for (i = 0; i <= k; i++)
470         fft_l[i] = TMP_ALLOC_TYPE (1<<i, int);
471       mpn_fft_initl (fft_l, k);
472
473       TRACE (printf ("recurse: %ldx%ld limbs -> %d times %dx%d (%1.2f)\n", n,
474                     n, K2, nprime2, nprime2, 2.0*(double)n/nprime2/K2));
475       for (i = 0; i < K; i++, ap++, bp++)
476         {
477           mp_limb_t cy;
478           mpn_fft_normalize (*ap, n);
479           if (!sqr)
480             mpn_fft_normalize (*bp, n);
481
482           mpn_mul_fft_decompose (A, Ap, K2, nprime2, *ap, (l << k) + 1, l, Mp2, T);
483           if (!sqr)
484             mpn_mul_fft_decompose (B, Bp, K2, nprime2, *bp, (l << k) + 1, l, Mp2, T);
485
486           cy = mpn_mul_fft_internal (*ap, n, k, Ap, Bp, A, B, nprime2,
487                                      l, Mp2, fft_l, T, sqr);
488           (*ap)[n] = cy;
489         }
490     }
491   else
492     {
493       mp_ptr a, b, tp, tpn;
494       mp_limb_t cc;
495       int n2 = 2 * n;
496       tp = TMP_ALLOC_LIMBS (n2);
497       tpn = tp + n;
498       TRACE (printf ("  mpn_mul_n %d of %ld limbs\n", K, n));
499       for (i = 0; i < K; i++)
500         {
501           a = *ap++;
502           b = *bp++;
503           if (sqr)
504             mpn_sqr (tp, a, n);
505           else
506             mpn_mul_n (tp, b, a, n);
507           if (a[n] != 0)
508             cc = mpn_add_n (tpn, tpn, b, n);
509           else
510             cc = 0;
511           if (b[n] != 0)
512             cc += mpn_add_n (tpn, tpn, a, n) + a[n];
513           if (cc != 0)
514             {
515               /* FIXME: use MPN_INCR_U here, since carry is not expected.  */
516               cc = mpn_add_1 (tp, tp, n2, cc);
517               ASSERT (cc == 0);
518             }
519           a[n] = mpn_sub_n (a, tp, tpn, n) && mpn_add_1 (a, a, n, CNST_LIMB(1));
520         }
521     }
522   TMP_FREE;
523 }
524
525
526 /* input: A^[l[k][0]] A^[l[k][1]] ... A^[l[k][K-1]]
527    output: K*A[0] K*A[K-1] ... K*A[1].
528    Assumes the Ap[] are pseudo-normalized, i.e. 0 <= Ap[][n] <= 1.
529    This condition is also fulfilled at exit.
530 */
531 static void
532 mpn_fft_fftinv (mp_ptr *Ap, int K, mp_size_t omega, mp_size_t n, mp_ptr tp)
533 {
534   if (K == 2)
535     {
536       mp_limb_t cy;
537 #if HAVE_NATIVE_mpn_add_n_sub_n
538       cy = mpn_add_n_sub_n (Ap[0], Ap[1], Ap[0], Ap[1], n + 1) & 1;
539 #else
540       MPN_COPY (tp, Ap[0], n + 1);
541       mpn_add_n (Ap[0], Ap[0], Ap[1], n + 1);
542       cy = mpn_sub_n (Ap[1], tp, Ap[1], n + 1);
543 #endif
544       if (Ap[0][n] > 1) /* can be 2 or 3 */
545         Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
546       if (cy) /* Ap[1][n] can be -1 or -2 */
547         Ap[1][n] = mpn_add_1 (Ap[1], Ap[1], n, ~Ap[1][n] + 1);
548     }
549   else
550     {
551       int j, K2 = K >> 1;
552
553       mpn_fft_fftinv (Ap,      K2, 2 * omega, n, tp);
554       mpn_fft_fftinv (Ap + K2, K2, 2 * omega, n, tp);
555       /* A[j]     <- A[j] + omega^j A[j+K/2]
556          A[j+K/2] <- A[j] + omega^(j+K/2) A[j+K/2] */
557       for (j = 0; j < K2; j++, Ap++)
558         {
559           /* Ap[K2] <- Ap[0] + Ap[K2] * 2^((j + K2) * omega)
560              Ap[0]  <- Ap[0] + Ap[K2] * 2^(j * omega) */
561           mpn_fft_mul_2exp_modF (tp, Ap[K2], j * omega, n);
562           mpn_fft_sub_modF (Ap[K2], Ap[0], tp, n);
563           mpn_fft_add_modF (Ap[0],  Ap[0], tp, n);
564         }
565     }
566 }
567
568
569 /* R <- A/2^k mod 2^(n*GMP_NUMB_BITS)+1 */
570 static void
571 mpn_fft_div_2exp_modF (mp_ptr r, mp_srcptr a, int k, mp_size_t n)
572 {
573   int i;
574
575   ASSERT (r != a);
576   i = 2 * n * GMP_NUMB_BITS - k;
577   mpn_fft_mul_2exp_modF (r, a, i, n);
578   /* 1/2^k = 2^(2nL-k) mod 2^(n*GMP_NUMB_BITS)+1 */
579   /* normalize so that R < 2^(n*GMP_NUMB_BITS)+1 */
580   mpn_fft_normalize (r, n);
581 }
582
583
584 /* {rp,n} <- {ap,an} mod 2^(n*GMP_NUMB_BITS)+1, n <= an <= 3*n.
585    Returns carry out, i.e. 1 iff {ap,an} = -1 mod 2^(n*GMP_NUMB_BITS)+1,
586    then {rp,n}=0.
587 */
588 static int
589 mpn_fft_norm_modF (mp_ptr rp, mp_size_t n, mp_ptr ap, mp_size_t an)
590 {
591   mp_size_t l;
592   long int m;
593   mp_limb_t cc;
594   int rpn;
595
596   ASSERT ((n <= an) && (an <= 3 * n));
597   m = an - 2 * n;
598   if (m > 0)
599     {
600       l = n;
601       /* add {ap, m} and {ap+2n, m} in {rp, m} */
602       cc = mpn_add_n (rp, ap, ap + 2 * n, m);
603       /* copy {ap+m, n-m} to {rp+m, n-m} */
604       rpn = mpn_add_1 (rp + m, ap + m, n - m, cc);
605     }
606   else
607     {
608       l = an - n; /* l <= n */
609       MPN_COPY (rp, ap, n);
610       rpn = 0;
611     }
612
613   /* remains to subtract {ap+n, l} from {rp, n+1} */
614   cc = mpn_sub_n (rp, rp, ap + n, l);
615   rpn -= mpn_sub_1 (rp + l, rp + l, n - l, cc);
616   if (rpn < 0) /* necessarily rpn = -1 */
617     rpn = mpn_add_1 (rp, rp, n, CNST_LIMB(1));
618   return rpn;
619 }
620
621 /* store in A[0..nprime] the first M bits from {n, nl},
622    in A[nprime+1..] the following M bits, ...
623    Assumes M is a multiple of GMP_NUMB_BITS (M = l * GMP_NUMB_BITS).
624    T must have space for at least (nprime + 1) limbs.
625    We must have nl <= 2*K*l.
626 */
627 static void
628 mpn_mul_fft_decompose (mp_ptr A, mp_ptr *Ap, int K, int nprime, mp_srcptr n,
629                        mp_size_t nl, int l, int Mp, mp_ptr T)
630 {
631   int i, j;
632   mp_ptr tmp;
633   mp_size_t Kl = K * l;
634   TMP_DECL;
635   TMP_MARK;
636
637   if (nl > Kl) /* normalize {n, nl} mod 2^(Kl*GMP_NUMB_BITS)+1 */
638     {
639       mp_size_t dif = nl - Kl;
640       mp_limb_signed_t cy;
641
642       tmp = TMP_ALLOC_LIMBS(Kl + 1);
643
644       if (dif > Kl)
645         {
646           int subp = 0;
647
648           cy = mpn_sub_n (tmp, n, n + Kl, Kl);
649           n += 2 * Kl;
650           dif -= Kl;
651
652           /* now dif > 0 */
653           while (dif > Kl)
654             {
655               if (subp)
656                 cy += mpn_sub_n (tmp, tmp, n, Kl);
657               else
658                 cy -= mpn_add_n (tmp, tmp, n, Kl);
659               subp ^= 1;
660               n += Kl;
661               dif -= Kl;
662             }
663           /* now dif <= Kl */
664           if (subp)
665             cy += mpn_sub (tmp, tmp, Kl, n, dif);
666           else
667             cy -= mpn_add (tmp, tmp, Kl, n, dif);
668           if (cy >= 0)
669             cy = mpn_add_1 (tmp, tmp, Kl, cy);
670           else
671             cy = mpn_sub_1 (tmp, tmp, Kl, -cy);
672         }
673       else /* dif <= Kl, i.e. nl <= 2 * Kl */
674         {
675           cy = mpn_sub (tmp, n, Kl, n + Kl, dif);
676           cy = mpn_add_1 (tmp, tmp, Kl, cy);
677         }
678       tmp[Kl] = cy;
679       nl = Kl + 1;
680       n = tmp;
681     }
682   for (i = 0; i < K; i++)
683     {
684       Ap[i] = A;
685       /* store the next M bits of n into A[0..nprime] */
686       if (nl > 0) /* nl is the number of remaining limbs */
687         {
688           j = (l <= nl && i < K - 1) ? l : nl; /* store j next limbs */
689           nl -= j;
690           MPN_COPY (T, n, j);
691           MPN_ZERO (T + j, nprime + 1 - j);
692           n += l;
693           mpn_fft_mul_2exp_modF (A, T, i * Mp, nprime);
694         }
695       else
696         MPN_ZERO (A, nprime + 1);
697       A += nprime + 1;
698     }
699   ASSERT_ALWAYS (nl == 0);
700   TMP_FREE;
701 }
702
703 /* op <- n*m mod 2^N+1 with fft of size 2^k where N=pl*GMP_NUMB_BITS
704    op is pl limbs, its high bit is returned.
705    One must have pl = mpn_fft_next_size (pl, k).
706    T must have space for 2 * (nprime + 1) limbs.
707 */
708
709 static mp_limb_t
710 mpn_mul_fft_internal (mp_ptr op, mp_size_t pl, int k,
711                       mp_ptr *Ap, mp_ptr *Bp, mp_ptr A, mp_ptr B,
712                       mp_size_t nprime, mp_size_t l, mp_size_t Mp,
713                       int **fft_l, mp_ptr T, int sqr)
714 {
715   int K, i, pla, lo, sh, j;
716   mp_ptr p;
717   mp_limb_t cc;
718
719   K = 1 << k;
720
721   /* direct fft's */
722   mpn_fft_fft (Ap, K, fft_l + k, 2 * Mp, nprime, 1, T);
723   if (!sqr)
724     mpn_fft_fft (Bp, K, fft_l + k, 2 * Mp, nprime, 1, T);
725
726   /* term to term multiplications */
727   mpn_fft_mul_modF_K (Ap, sqr ? Ap : Bp, nprime, K);
728
729   /* inverse fft's */
730   mpn_fft_fftinv (Ap, K, 2 * Mp, nprime, T);
731
732   /* division of terms after inverse fft */
733   Bp[0] = T + nprime + 1;
734   mpn_fft_div_2exp_modF (Bp[0], Ap[0], k, nprime);
735   for (i = 1; i < K; i++)
736     {
737       Bp[i] = Ap[i - 1];
738       mpn_fft_div_2exp_modF (Bp[i], Ap[i], k + (K - i) * Mp, nprime);
739     }
740
741   /* addition of terms in result p */
742   MPN_ZERO (T, nprime + 1);
743   pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */
744   p = B; /* B has K*(n' + 1) limbs, which is >= pla, i.e. enough */
745   MPN_ZERO (p, pla);
746   cc = 0; /* will accumulate the (signed) carry at p[pla] */
747   for (i = K - 1, lo = l * i + nprime,sh = l * i; i >= 0; i--,lo -= l,sh -= l)
748     {
749       mp_ptr n = p + sh;
750
751       j = (K - i) & (K - 1);
752
753       if (mpn_add_n (n, n, Bp[j], nprime + 1))
754         cc += mpn_add_1 (n + nprime + 1, n + nprime + 1,
755                           pla - sh - nprime - 1, CNST_LIMB(1));
756       T[2 * l] = i + 1; /* T = (i + 1)*2^(2*M) */
757       if (mpn_cmp (Bp[j], T, nprime + 1) > 0)
758         { /* subtract 2^N'+1 */
759           cc -= mpn_sub_1 (n, n, pla - sh, CNST_LIMB(1));
760           cc -= mpn_sub_1 (p + lo, p + lo, pla - lo, CNST_LIMB(1));
761         }
762     }
763   if (cc == -CNST_LIMB(1))
764     {
765       if ((cc = mpn_add_1 (p + pla - pl, p + pla - pl, pl, CNST_LIMB(1))))
766         {
767           /* p[pla-pl]...p[pla-1] are all zero */
768           mpn_sub_1 (p + pla - pl - 1, p + pla - pl - 1, pl + 1, CNST_LIMB(1));
769           mpn_sub_1 (p + pla - 1, p + pla - 1, 1, CNST_LIMB(1));
770         }
771     }
772   else if (cc == 1)
773     {
774       if (pla >= 2 * pl)
775         {
776           while ((cc = mpn_add_1 (p + pla - 2 * pl, p + pla - 2 * pl, 2 * pl, cc)))
777             ;
778         }
779       else
780         {
781           cc = mpn_sub_1 (p + pla - pl, p + pla - pl, pl, cc);
782           ASSERT (cc == 0);
783         }
784     }
785   else
786     ASSERT (cc == 0);
787
788   /* here p < 2^(2M) [K 2^(M(K-1)) + (K-1) 2^(M(K-2)) + ... ]
789      < K 2^(2M) [2^(M(K-1)) + 2^(M(K-2)) + ... ]
790      < K 2^(2M) 2^(M(K-1))*2 = 2^(M*K+M+k+1) */
791   return mpn_fft_norm_modF (op, pl, p, pla);
792 }
793
794 /* return the lcm of a and 2^k */
795 static unsigned long int
796 mpn_mul_fft_lcm (unsigned long int a, unsigned int k)
797 {
798   unsigned long int l = k;
799
800   while (a % 2 == 0 && k > 0)
801     {
802       a >>= 1;
803       k --;
804     }
805   return a << l;
806 }
807
808
809 mp_limb_t
810 mpn_mul_fft (mp_ptr op, mp_size_t pl,
811              mp_srcptr n, mp_size_t nl,
812              mp_srcptr m, mp_size_t ml,
813              int k)
814 {
815   int K, maxLK, i;
816   mp_size_t N, Nprime, nprime, M, Mp, l;
817   mp_ptr *Ap, *Bp, A, T, B;
818   int **fft_l;
819   int sqr = (n == m && nl == ml);
820   mp_limb_t h;
821   TMP_DECL;
822
823   TRACE (printf ("\nmpn_mul_fft pl=%ld nl=%ld ml=%ld k=%d\n", pl, nl, ml, k));
824   ASSERT_ALWAYS (mpn_fft_next_size (pl, k) == pl);
825
826   TMP_MARK;
827   N = pl * GMP_NUMB_BITS;
828   fft_l = TMP_ALLOC_TYPE (k + 1, int *);
829   for (i = 0; i <= k; i++)
830     fft_l[i] = TMP_ALLOC_TYPE (1 << i, int);
831   mpn_fft_initl (fft_l, k);
832   K = 1 << k;
833   M = N >> k;   /* N = 2^k M */
834   l = 1 + (M - 1) / GMP_NUMB_BITS;
835   maxLK = mpn_mul_fft_lcm ((unsigned long) GMP_NUMB_BITS, k); /* lcm (GMP_NUMB_BITS, 2^k) */
836
837   Nprime = (1 + (2 * M + k + 2) / maxLK) * maxLK;
838   /* Nprime = ceil((2*M+k+3)/maxLK)*maxLK; */
839   nprime = Nprime / GMP_NUMB_BITS;
840   TRACE (printf ("N=%ld K=%d, M=%ld, l=%ld, maxLK=%d, Np=%ld, np=%ld\n",
841                  N, K, M, l, maxLK, Nprime, nprime));
842   /* we should ensure that recursively, nprime is a multiple of the next K */
843   if (nprime >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
844     {
845       unsigned long K2;
846       for (;;)
847         {
848           K2 = 1L << mpn_fft_best_k (nprime, sqr);
849           if ((nprime & (K2 - 1)) == 0)
850             break;
851           nprime = (nprime + K2 - 1) & -K2;
852           Nprime = nprime * GMP_LIMB_BITS;
853           /* warning: since nprime changed, K2 may change too! */
854         }
855       TRACE (printf ("new maxLK=%d, Np=%ld, np=%ld\n", maxLK, Nprime, nprime));
856     }
857   ASSERT_ALWAYS (nprime < pl); /* otherwise we'll loop */
858
859   T = TMP_ALLOC_LIMBS (2 * (nprime + 1));
860   Mp = Nprime >> k;
861
862   TRACE (printf ("%ldx%ld limbs -> %d times %ldx%ld limbs (%1.2f)\n",
863                 pl, pl, K, nprime, nprime, 2.0 * (double) N / Nprime / K);
864          printf ("   temp space %ld\n", 2 * K * (nprime + 1)));
865
866   A = TMP_ALLOC_LIMBS (K * (nprime + 1));
867   Ap = TMP_ALLOC_MP_PTRS (K);
868   mpn_mul_fft_decompose (A, Ap, K, nprime, n, nl, l, Mp, T);
869   if (sqr)
870     {
871       mp_size_t pla;
872       pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */
873       B = TMP_ALLOC_LIMBS (pla);
874       Bp = TMP_ALLOC_MP_PTRS (K);
875     }
876   else
877     {
878       B = TMP_ALLOC_LIMBS (K * (nprime + 1));
879       Bp = TMP_ALLOC_MP_PTRS (K);
880       mpn_mul_fft_decompose (B, Bp, K, nprime, m, ml, l, Mp, T);
881     }
882   h = mpn_mul_fft_internal (op, pl, k, Ap, Bp, A, B, nprime, l, Mp, fft_l, T, sqr);
883
884   TMP_FREE;
885   return h;
886 }
887
888 #if WANT_OLD_FFT_FULL
889 /* multiply {n, nl} by {m, ml}, and put the result in {op, nl+ml} */
890 void
891 mpn_mul_fft_full (mp_ptr op,
892                   mp_srcptr n, mp_size_t nl,
893                   mp_srcptr m, mp_size_t ml)
894 {
895   mp_ptr pad_op;
896   mp_size_t pl, pl2, pl3, l;
897   int k2, k3;
898   int sqr = (n == m && nl == ml);
899   int cc, c2, oldcc;
900
901   pl = nl + ml; /* total number of limbs of the result */
902
903   /* perform a fft mod 2^(2N)+1 and one mod 2^(3N)+1.
904      We must have pl3 = 3/2 * pl2, with pl2 a multiple of 2^k2, and
905      pl3 a multiple of 2^k3. Since k3 >= k2, both are multiples of 2^k2,
906      and pl2 must be an even multiple of 2^k2. Thus (pl2,pl3) =
907      (2*j*2^k2,3*j*2^k2), which works for 3*j <= pl/2^k2 <= 5*j.
908      We need that consecutive intervals overlap, i.e. 5*j >= 3*(j+1),
909      which requires j>=2. Thus this scheme requires pl >= 6 * 2^FFT_FIRST_K. */
910
911   /*  ASSERT_ALWAYS(pl >= 6 * (1 << FFT_FIRST_K)); */
912
913   pl2 = (2 * pl - 1) / 5; /* ceil (2pl/5) - 1 */
914   do
915     {
916       pl2++;
917       k2 = mpn_fft_best_k (pl2, sqr); /* best fft size for pl2 limbs */
918       pl2 = mpn_fft_next_size (pl2, k2);
919       pl3 = 3 * pl2 / 2; /* since k>=FFT_FIRST_K=4, pl2 is a multiple of 2^4,
920                             thus pl2 / 2 is exact */
921       k3 = mpn_fft_best_k (pl3, sqr);
922     }
923   while (mpn_fft_next_size (pl3, k3) != pl3);
924
925   TRACE (printf ("mpn_mul_fft_full nl=%ld ml=%ld -> pl2=%ld pl3=%ld k=%d\n",
926                  nl, ml, pl2, pl3, k2));
927
928   ASSERT_ALWAYS(pl3 <= pl);
929   cc = mpn_mul_fft (op, pl3, n, nl, m, ml, k3);     /* mu */
930   ASSERT(cc == 0);
931   pad_op = __GMP_ALLOCATE_FUNC_LIMBS (pl2);
932   cc = mpn_mul_fft (pad_op, pl2, n, nl, m, ml, k2); /* lambda */
933   cc = -cc + mpn_sub_n (pad_op, pad_op, op, pl2);    /* lambda - low(mu) */
934   /* 0 <= cc <= 1 */
935   ASSERT(0 <= cc && cc <= 1);
936   l = pl3 - pl2; /* l = pl2 / 2 since pl3 = 3/2 * pl2 */
937   c2 = mpn_add_n (pad_op, pad_op, op + pl2, l);
938   cc = mpn_add_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2) - cc;
939   ASSERT(-1 <= cc && cc <= 1);
940   if (cc < 0)
941     cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
942   ASSERT(0 <= cc && cc <= 1);
943   /* now lambda-mu = {pad_op, pl2} - cc mod 2^(pl2*GMP_NUMB_BITS)+1 */
944   oldcc = cc;
945 #if HAVE_NATIVE_mpn_add_n_sub_n
946   c2 = mpn_add_n_sub_n (pad_op + l, pad_op, pad_op, pad_op + l, l);
947   /* c2 & 1 is the borrow, c2 & 2 is the carry */
948   cc += c2 >> 1; /* carry out from high <- low + high */
949   c2 = c2 & 1; /* borrow out from low <- low - high */
950 #else
951   {
952     mp_ptr tmp;
953     TMP_DECL;
954
955     TMP_MARK;
956     tmp = TMP_ALLOC_LIMBS (l);
957     MPN_COPY (tmp, pad_op, l);
958     c2 = mpn_sub_n (pad_op,      pad_op, pad_op + l, l);
959     cc += mpn_add_n (pad_op + l, tmp,    pad_op + l, l);
960     TMP_FREE;
961   }
962 #endif
963   c2 += oldcc;
964   /* first normalize {pad_op, pl2} before dividing by 2: c2 is the borrow
965      at pad_op + l, cc is the carry at pad_op + pl2 */
966   /* 0 <= cc <= 2 */
967   cc -= mpn_sub_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2);
968   /* -1 <= cc <= 2 */
969   if (cc > 0)
970     cc = -mpn_sub_1 (pad_op, pad_op, pl2, (mp_limb_t) cc);
971   /* now -1 <= cc <= 0 */
972   if (cc < 0)
973     cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
974   /* now {pad_op, pl2} is normalized, with 0 <= cc <= 1 */
975   if (pad_op[0] & 1) /* if odd, add 2^(pl2*GMP_NUMB_BITS)+1 */
976     cc += 1 + mpn_add_1 (pad_op, pad_op, pl2, CNST_LIMB(1));
977   /* now 0 <= cc <= 2, but cc=2 cannot occur since it would give a carry
978      out below */
979   mpn_rshift (pad_op, pad_op, pl2, 1); /* divide by two */
980   if (cc) /* then cc=1 */
981     pad_op [pl2 - 1] |= (mp_limb_t) 1 << (GMP_NUMB_BITS - 1);
982   /* now {pad_op,pl2}-cc = (lambda-mu)/(1-2^(l*GMP_NUMB_BITS))
983      mod 2^(pl2*GMP_NUMB_BITS) + 1 */
984   c2 = mpn_add_n (op, op, pad_op, pl2); /* no need to add cc (is 0) */
985   /* since pl2+pl3 >= pl, necessary the extra limbs (including cc) are zero */
986   MPN_COPY (op + pl3, pad_op, pl - pl3);
987   ASSERT_MPN_ZERO_P (pad_op + pl - pl3, pl2 + pl3 - pl);
988   __GMP_FREE_FUNC_LIMBS (pad_op, pl2);
989   /* since the final result has at most pl limbs, no carry out below */
990   mpn_add_1 (op + pl2, op + pl2, pl - pl2, (mp_limb_t) c2);
991 }
992 #endif