Blob Blame Raw


#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <ctime>

#include <complex>



// Scalar type for all coefficients, intermediate values and tolerances.
typedef double Real;
// Complex counterpart used for roots, which may be non-real.
typedef std::complex<double> Complex;


// Absolute tolerance used by every isZero() test in this file.
const Real epsilon(1e-6);
// Square of the tolerance, for comparing against squared moduli without a sqrt.
const Real epsilonSqr(epsilon * epsilon);

// Squared modulus re^2 + im^2 of x; cheaper than std::abs because no sqrt is taken.
inline Real magSqr(Complex x) {
  const Real re = x.real();
  const Real im = x.imag();
  return re*re + im*im;
}
// Real overload: true when x lies strictly inside (-epsilon, epsilon).
// NaN fails both comparisons, so NaN is never considered zero.
inline bool isZero(Real x) {
  return -epsilon < x && x < epsilon;
}
// Complex overload: true when |x|^2 does not exceed epsilon^2, i.e. roughly
// |x| <= epsilon, checked on the squared modulus to avoid a sqrt.
inline bool isZero(Complex x) {
  return epsilonSqr >= magSqr(x);
}


// Debug guard used by solve2: asserts that r satisfies a*r^2 + b*r + c = 0
// within the isZero tolerance, then returns r unchanged so it can wrap the
// root expression in place. With NDEBUG defined this is an identity function.
inline Complex verify2(Complex r, Real a, Real b, Real c)
{
  assert(isZero(a*r*r + b*r + c));
  return r;
}


// Debug guard used by solve3: asserts that r satisfies
// a*r^3 + b*r^2 + c*r + d = 0 within the isZero tolerance, then returns r
// unchanged. With NDEBUG defined this is an identity function.
inline Complex verify3(Complex r, Real a, Real b, Real c, Real d)
{
  assert(isZero(a*r*r*r + b*r*r + c*r + d));
  return r;
}


// Debug guard used by solve4: asserts that r satisfies
// a*r^4 + b*r^3 + c*r^2 + d*r + e = 0 within the isZero tolerance, then
// returns r unchanged. With NDEBUG defined this is an identity function.
inline Complex verify4(Complex r, Real a, Real b, Real c, Real d, Real e)
{
  assert(isZero(a*r*r*r*r + b*r*r*r + c*r*r + d*r + e));
  return r;
}


// Solves a*x^2 + b*x + c = 0 over the complex numbers.
// Writes the roots into `roots` (room for 2 needed) and returns how many were
// written: 2 for a genuine quadratic, 1 when a is (numerically) zero, and 0
// when both a and b are (numerically) zero. Every root passes through verify2,
// which asserts in debug builds that it really satisfies the equation.
int solve2(Complex *roots, Real a, Real b, Real c) {
  if (isZero(a)) {
    // Degenerate linear equation b*x + c = 0.
    if (isZero(b)) return 0;
    roots[0] = verify2(-c/b, a, b, c);
    return 1;
  }

  Real D = b*b - a*c*4;
  Complex d = D < 0 ? Complex(0, sqrt(-D)) : Complex(sqrt(D));

  // Numerically stable form of the quadratic formula. The textbook
  // (-b +- d)/(2a) subtracts two nearly equal numbers for one root when
  // |b| >> |4ac|, losing most significant digits. Instead pick the sign that
  // adds magnitudes, q = -(b + sgn(b)*d)/2, take that root as q/a and recover
  // the other from Vieta's product x1*x2 = c/a as c/q.
  Complex q = b < 0 ? (d - b)*Real(0.5) : (-b - d)*Real(0.5);
  Complex fromQ = q/a;
  // q is exactly zero only when b == 0 and d == 0, i.e. a double root at 0.
  Complex fromVieta = q == Complex(0) ? fromQ : c/q;

  // Keep the ordering roots[0] = (-b - d)/(2a), roots[1] = (-b + d)/(2a).
  if (b < 0) {
    roots[0] = verify2(fromVieta, a, b, c);
    roots[1] = verify2(fromQ, a, b, c);
  } else {
    roots[0] = verify2(fromQ, a, b, c);
    roots[1] = verify2(fromVieta, a, b, c);
  }
  return 2;
}


// Solves a*x^3 + b*x^2 + c*x + d = 0 over the complex numbers with Cardano's
// formula. Writes the roots into `roots` (room for 3 needed) and returns how
// many were written; when a is (numerically) zero the problem is handed to
// solve2. Every root passes through verify3 (asserted in debug builds).
int solve3(Complex *roots, Real a, Real b, Real c, Real d) {
  if (isZero(a)) return solve2(roots, b, c, d);

  // x = y - b/(3*a)
  // y*y*y + p*y + q = 0
  Real p = (3*a*c - b*b)/(3*a*a);
  Real q = (27*a*a*d - 9*a*b*c + 2*b*b*b)/(27*a*a*a);

  Real Q = p*p*p/27 + q*q/4;
  Complex Qs = Q < 0 ? Complex(0, sqrt(-Q)) : Complex(sqrt(Q));

  // Cardano: y = A + B with A^3 = -q/2 + Qs, B^3 = -q/2 - Qs and A*B = -p/3.
  // Take the cube root of whichever radicand has the larger modulus (the
  // other may suffer cancellation) and derive its partner exactly from
  // A*B = -p/3. This replaces searching the three cube roots with a tolerance
  // test, which silently keeps a wrong partner when no candidate is within
  // epsilon. If the chosen radicand is zero, both are (p == q == 0).
  Complex tA = -q/2 + Qs;
  Complex tB = -q/2 - Qs;
  Complex A, B;
  if (magSqr(tA) >= magSqr(tB)) {
    A = pow(tA, Real(1)/3);
    B = A == Complex(0) ? Complex(0) : -p/(Real(3)*A);
  } else {
    B = pow(tB, Real(1)/3);
    A = B == Complex(0) ? Complex(0) : -p/(Real(3)*B);
  }

  // The other two roots are A*w + B*w^2 and A*w^2 + B*w, with w a primitive
  // cube root of unity.
  Complex Y = (A - B)*Complex(0, sqrt(Real(3)));
  Complex y0 = A + B;
  Complex y1 = (-y0 - Y)/Real(2);
  Complex y2 = (-y0 + Y)/Real(2);

  Real dd = b/(3*a);
  roots[0] = verify3(y0 - dd, a, b, c, d);
  roots[1] = verify3(y1 - dd, a, b, c, d);
  roots[2] = verify3(y2 - dd, a, b, c, d);
  return 3;
}


// Solves a*x^4 + b*x^3 + c*x^2 + d*x + e = 0 over the complex numbers.
// Writes the roots into `roots` (room for 4 needed) and returns how many were
// written; when a is (numerically) zero the problem is handed to solve3.
// Method: shift x = y - b/(4a) to get y^4 + p*y^2 + q*y + r = 0, then either
// solve it as a quadratic in y^2 (q ~ 0) or go through a resolvent cubic.
// Every root passes through verify4 (asserted in debug builds).
int solve4(Complex *roots, Real a, Real b, Real c, Real d, Real e)
{
  if (isZero(a))
    return solve3(roots, b, c, d, e);

  // Coefficients of the depressed quartic y^4 + p*y^2 + q*y + r = 0.
  const Real shift = b/(4*a);
  const Real p = (8*a*c - 3*b*b)/(8*a*a);
  const Real q = (8*a*a*d - 4*a*b*c + b*b*b)/(8*a*a*a);
  const Real r = (256*a*a*a*e - 64*a*a*b*d + 16*a*b*b*c - 3*b*b*b*b)/(256*a*a*a*a);

  if (isZero(q)) {
    // Biquadratic y^4 + p*y^2 + r = 0: solve for w = y^2, then y = +-sqrt(w).
    Complex w[2];
    solve2(w, 1, p, r);
    for (int k = 0; k < 2; ++k) {
      const Complex y = sqrt(w[k]);
      roots[2*k]     = verify4(-y - shift, a, b, c, d, e);
      roots[2*k + 1] = verify4( y - shift, a, b, c, d, e);
    }
    return 4;
  }

  // Resolvent cubic z^3 + (p/2)*z^2 + ((p*p - 4*r)/16)*z - q*q/64 = 0;
  // the square roots of its three roots are combined into the four y's.
  Complex z[3];
  solve3(z, 1, p/2, (p*p - 4*r)/16, -q*q/64);
  for (int k = 0; k < 3; ++k)
    z[k] = sqrt(z[k]);

  // Each square root is fixed only up to sign; the combination used must
  // satisfy z0*z1*z2 == -q/8. The asserts check the product is (close to)
  // real with magnitude |q/8|, so negating z0 when its sign matches q's fixes it.
  const Complex zzz = z[0]*z[1]*z[2];
  assert(isZero(zzz.imag()));
  assert(isZero(fabs(zzz.real()) - fabs(q/8)));
  const bool signsMatch = (zzz.real() > 0) == (q > 0);
  if (signsMatch)
    z[0] = -z[0];
  assert(isZero(z[0]*z[1]*z[2] + q/8));

  roots[0] = verify4( z[0] - z[1] - z[2] - shift, a, b, c, d, e);
  roots[1] = verify4(-z[0] + z[1] - z[2] - shift, a, b, c, d, e);
  roots[2] = verify4(-z[0] - z[1] + z[2] - shift, a, b, c, d, e);
  roots[3] = verify4( z[0] + z[1] + z[2] - shift, a, b, c, d, e);
  return 4;
}


// Randomized self-test for solve4 (and, through its fall-backs, solve3/solve2).
// Draws one million quartics with integer coefficients in [-10, 9] and prints
// each equation with its roots. It stops at the first case where the number
// of roots differs from the polynomial's actual degree, or where a root's
// residual is not isZero. In builds without NDEBUG the verify* asserts inside
// the solvers may abort first.
// Exit status: EXIT_SUCCESS if every case passed, EXIT_FAILURE otherwise.
int main() {
  // Time-based seed: each run samples different equations; a failing equation
  // is printed in full, so it can still be reproduced by hand.
  srand(static_cast<unsigned>(time(NULL)));

  bool failed = false;
  for(int i = 0; i < 1000000; ++i) {
    Real a = rand()%20 - 10;
    Real b = rand()%20 - 10;
    Real c = rand()%20 - 10;
    Real d = rand()%20 - 10;
    Real e = rand()%20 - 10;
    printf("%6d: %g*x^4 + %g*x^3 + %g*x^2 + %g*x + %g = 0", i, a, b, c, d, e);

    // Expected root count = degree after dropping zero leading coefficients
    // (exact tests are fine here: the coefficients are small integers).
    int vcnt = a ? 4
             : b ? 3
             : c ? 2
             : d ? 1 : 0;

    Complex roots[4];
    int cnt = solve4(roots, a, b, c, d, e);
    bool f = cnt != vcnt;
    // Distinct index name so the case counter `i` is not shadowed.
    for(int k = 0; k < cnt; ++k) {
      const Complex &x = roots[k];
      printf(", (%g, %g)", x.real(), x.imag());
      Complex r = a*x*x*x*x + b*x*x*x + c*x*x + d*x + e;
      if (!isZero(r))
        { printf(" ![%g, %g]", r.real(), r.imag()); f = true; }
    }

    printf(", %s\n", f ? "FAILED" : "passed");
    if (f) {
      printf("^^^^^^^^^^\n");
      failed = true;
      break;
    }
  }

  printf("\n%s\n\n", failed ? "FAILED" : "success");
  return failed ? EXIT_FAILURE : EXIT_SUCCESS;
}