#include "lexer.c"
/* Runtime type tag carried in Var.type; selects the live union member. */
typedef enum {
    T_U32 = 0, /* integer payload, stored in Var.u */
    T_F32 = 1, /* floating-point payload, stored in Var.f */
} Type;
/* Syntactic role of an OpDesc entry. */
typedef enum {
    M_VALUE   = 0, /* leaf value such as "t" */
    M_FUNC    = 1, /* call taking one argument, e.g. sin(...) */
    M_FUNC2   = 2, /* presumably a two-argument call; no ops[] entry uses it here */
    M_ARRAY   = 3, /* "[" in value position — presumably an array literal */
    M_INDEX   = 4, /* "[" after a value — presumably indexing */
    M_PREFIX  = 5, /* unary prefix operator or cast */
    M_BINARY  = 6, /* infix binary operator */
    M_TERNARY = 7, /* the ?: conditional */
} Mode;
/* Tagged scalar value: `type` says which member of the anonymous union is
 * meaningful (vari() sets u with T_U32, varf() sets f with T_F32).
 * u and f share storage, so reading .u of a float yields its raw bits. */
typedef struct {
Type type;
union {
U32 u; /* meaningful when type == T_U32 */
F32 f; /* meaningful when type == T_F32 */
};
} Var;
/* Argument bundle handed to every Func: the current t value plus up to
 * three operands (unary ops use a, binary ops a and b, ?: uses a, b, c). */
typedef struct {
    U32 t; /* value produced by the "t" leaf (see f_var) */
    Var a; /* first operand */
    Var b; /* second operand */
    Var c; /* third operand (false branch of ?:) */
} Params;
/* Operator implementation: receives already-evaluated operands by value
 * and returns the resulting Var. */
typedef Var (*Func)(Params p);
/* Describes one lexeme/operator known to the evaluator.
 * Field order is part of the interface: aggregate initializers may fill
 * these fields positionally, so keep the order stable. */
typedef struct {
Mode mode;        /* syntactic role (value, prefix, binary, ...) */
int priority;     /* precedence level; values in ops[] follow C's precedence order, tightest first */
const char *key;  /* source text that selects this entry */
Func func;        /* implementation; NULL for entries that have none */
U32 k;            /* extra constant; not set by any entry in this file — purpose not visible here */
} OpDesc;
/* Box an unsigned 32-bit integer as a T_U32 Var. */
static inline Var vari(U32 u)
{
    return (Var){ .type = T_U32, .u = u };
}
/* Box a 32-bit float as a T_F32 Var. */
static inline Var varf(F32 f)
{
    return (Var){ .type = T_F32, .f = f };
}
/* Returns 1 when v is tagged T_U32, otherwise 0. */
static inline int isint(Var v)
{
    switch (v.type) {
    case T_U32:
        return 1;
    default:
        return 0;
    }
}
/* Returns 1 when v is tagged T_F32, otherwise 0. */
static inline int isflt(Var v)
{
    switch (v.type) {
    case T_F32:
        return 1;
    default:
        return 0;
    }
}
/* Returns 1 if either operand is a float, i.e. whether a binary
 * operation on a and b should be carried out in floating point. */
static inline int isflt2(Var a, Var b)
{
    if (isflt(a))
        return 1;
    return isflt(b);
}
/* Read a Var as a U32.
 * Integers are returned unchanged. Floats are truncated toward zero and
 * reduced modulo 2^32 (two's-complement wrap), so -1.0f -> 0xFFFFFFFF and
 * 3.0e9f -> 3000000000. NaN, infinities and magnitudes >= 9.0e18 give 0.
 *
 * A direct (I32) cast of the float is undefined behaviour whenever the
 * value is NaN, infinite or outside the I32 range (C11 6.3.1.4p1), which
 * expressions such as `t*t*0.5 & 255` reach quickly. Going through
 * long long keeps every in-range conversion defined. For floats inside
 * the I32 range the result is identical to the (U32)(I32) form. */
static inline U32 asint(Var v)
{
    if (!isflt(v))
        return v.u;

    /* Floats with magnitude >= 2^55 are exact multiples of 2^32, so their
     * low 32 integer bits are zero anyway. 9.0e18 is below 2^63, so the
     * long long conversion below is always defined. The negated
     * comparison also sends NaN down this path. */
    if (!(v.f > -9.0e18f && v.f < 9.0e18f))
        return 0;

    /* Conversion to an unsigned type is defined as reduction mod 2^32. */
    return (U32)(long long)v.f;
}
/* Read a Var as an F32. Floats pass through. Integer payloads are first
 * reinterpreted as I32 (so 0xFFFFFFFF becomes -1.0f on the usual
 * two's-complement implementations) and then widened to float. */
static inline F32 asflt(Var v)
{
    F32 out;
    if (isflt(v))
        out = v.f;
    else
        out = (F32)(I32)v.u;
    return out;
}
/* Leaf "t": yields the t value carried in Params as an integer Var. */
Var f_var(Params p)
{
    const U32 now = p.t;
    return vari(now);
}
/* Prefix '-': float negation for floats, modular (unsigned) negation for
 * integers. */
Var f_neg(Params p)
{
    const Var x = p.a;
    if (isflt(x))
        return varf(-x.f);
    return vari(-x.u);
}
/* Prefix '+': identity — the operand is returned with its type intact. */
Var f_pos(Params p)
{
    const Var unchanged = p.a;
    return unchanged;
}
/* Prefix '~': bitwise complement of the operand's integer value.
 * Float operands are converted with asint first, the same way &, ^ and |
 * treat them. Reading .u would instead invert the float's raw IEEE-754
 * bit pattern, e.g. ~1.0f would give 0xC07FFFFF rather than ~1u. */
Var f_not(Params p)
{
    return vari(~asint(p.a));
}
/* Prefix '!': 1 if the operand is zero/false, else 0.
 * Floats follow C's rule (!x is x == 0), so -0.0f counts as zero. Testing
 * the raw bits in .u would treat -0.0f (bits 0x80000000) as true. */
Var f_boolnot(Params p)
{
    const int isZero = isflt(p.a) ? p.a.f == 0.0f : p.a.u == 0;
    return vari(isZero);
}
/* Integer casts ((int), (U32), (unsigned), ...): convert via asint. */
Var f_casti(Params p)
{
    const U32 n = asint(p.a);
    return vari(n);
}
/* Float casts ((float), (F32), (double)): convert via asflt and keep the
 * result tagged as a float.
 * Wrapping the F32 in vari() would implicitly convert it back to U32.
 * That truncates the value, tags it T_U32, and is undefined behaviour
 * for negative or out-of-range floats. */
Var f_castf(Params p)
{
    return varf(asflt(p.a));
}
/* Binary '*': float product if either side is float, otherwise U32
 * product (wraps modulo 2^32). */
Var f_mul(Params p)
{
    if (!isflt2(p.a, p.b))
        return vari(asint(p.a) * asint(p.b));
    return varf(asflt(p.a) * asflt(p.b));
}
/* Binary '/': always performed in floating point, even for two integers.
 * Division by zero yields the numerator unchanged. */
Var f_div(Params p)
{
    const F32 num = asflt(p.a);
    const F32 den = asflt(p.b);
    if (den == 0.0f)
        return varf(num);
    return varf(num / den);
}
/* Binary '%': unsigned integer remainder. A zero divisor yields the
 * dividend unchanged instead of trapping. */
Var f_mod(Params p)
{
    const U32 lhs = asint(p.a);
    const U32 rhs = asint(p.b);
    if (rhs == 0)
        return vari(lhs);
    return vari(lhs % rhs);
}
/* Binary '+': float sum if either side is float, else wrapping U32 sum. */
Var f_add(Params p)
{
    Var sum;
    if (isflt2(p.a, p.b))
        sum = varf(asflt(p.a) + asflt(p.b));
    else
        sum = vari(asint(p.a) + asint(p.b));
    return sum;
}
/* Binary '-': float difference if either side is float, else wrapping
 * U32 difference. */
Var f_sub(Params p)
{
    Var diff;
    if (isflt2(p.a, p.b))
        diff = varf(asflt(p.a) - asflt(p.b));
    else
        diff = vari(asint(p.a) - asint(p.b));
    return diff;
}
/* Binary '<<': left shift of the integer value of a.
 * The count is reduced to 0..31 because shifting a 32-bit value by 32 or
 * more is undefined behaviour in C (C11 6.5.7p3). Masking matches what
 * x86 shift instructions and JavaScript's << do. */
Var f_shl(Params p)
{
    const U32 count = asint(p.b) & 31u;
    return vari(asint(p.a) << count);
}
/* Binary '>>': logical (unsigned) right shift of the integer value of a.
 * The count is reduced to 0..31 because a shift count >= the operand
 * width is undefined behaviour in C (C11 6.5.7p3). */
Var f_shr(Params p)
{
    const U32 count = asint(p.b) & 31u;
    return vari(asint(p.a) >> count);
}
/* Binary '<': 1/0 result; compares as floats if either side is float,
 * otherwise as unsigned integers. */
Var f_lt(Params p)
{
    int less;
    if (isflt2(p.a, p.b))
        less = asflt(p.a) < asflt(p.b);
    else
        less = asint(p.a) < asint(p.b);
    return vari(less);
}
/* Binary '>': 1/0 result; compares as floats if either side is float,
 * otherwise as unsigned integers. */
Var f_gt(Params p)
{
    int greater;
    if (isflt2(p.a, p.b))
        greater = asflt(p.a) > asflt(p.b);
    else
        greater = asint(p.a) > asint(p.b);
    return vari(greater);
}
/* Binary '<=': 1/0 result; compares as floats if either side is float,
 * otherwise as unsigned integers. */
Var f_leq(Params p)
{
    int atMost;
    if (isflt2(p.a, p.b))
        atMost = asflt(p.a) <= asflt(p.b);
    else
        atMost = asint(p.a) <= asint(p.b);
    return vari(atMost);
}
/* Binary '>=': 1/0 result; compares as floats if either side is float,
 * otherwise as unsigned integers. */
Var f_geq(Params p)
{
    int atLeast;
    if (isflt2(p.a, p.b))
        atLeast = asflt(p.a) >= asflt(p.b);
    else
        atLeast = asint(p.a) >= asint(p.b);
    return vari(atLeast);
}
/* Binary '==': 1/0 result, using the same promotion rule as < and >:
 * compare as floats if either side is float, else as unsigned integers.
 * Comparing the raw .u bits instead would make 1 == 1.0 false,
 * -0.0 == 0.0 false and NaN == NaN true. */
Var f_eq(Params p)
{
    return isflt2(p.a, p.b) ? vari(asflt(p.a) == asflt(p.b))
                            : vari(asint(p.a) == asint(p.b));
}
/* Binary '!=': 1/0 result, using the same promotion rule as < and >:
 * compare as floats if either side is float, else as unsigned integers.
 * Comparing the raw .u bits instead would make 1 != 1.0 true and
 * -0.0 != 0.0 true. */
Var f_neq(Params p)
{
    return isflt2(p.a, p.b) ? vari(asflt(p.a) != asflt(p.b))
                            : vari(asint(p.a) != asint(p.b));
}
/* Binary '&': bitwise AND of both operands' integer values. */
Var f_and(Params p)
{
    const U32 lhs = asint(p.a);
    const U32 rhs = asint(p.b);
    return vari(lhs & rhs);
}
/* Binary '^': bitwise XOR of both operands' integer values. */
Var f_xor(Params p)
{
    const U32 lhs = asint(p.a);
    const U32 rhs = asint(p.b);
    return vari(lhs ^ rhs);
}
/* Binary '|': bitwise OR of both operands' integer values. */
Var f_or(Params p)
{
    const U32 lhs = asint(p.a);
    const U32 rhs = asint(p.b);
    return vari(lhs | rhs);
}
/* Binary '&&': 1 if both operands are true, else 0.
 * Truthiness follows C: a float is true iff it compares != 0.0f, so
 * -0.0f is false. Testing the raw .u bits would treat -0.0f as true.
 * Both operands arrive already evaluated, so there is no short-circuit. */
Var f_booland(Params p)
{
    const int lhs = isflt(p.a) ? p.a.f != 0.0f : p.a.u != 0;
    const int rhs = isflt(p.b) ? p.b.f != 0.0f : p.b.u != 0;
    return vari(lhs && rhs);
}
/* Binary '||': 1 if either operand is true, else 0.
 * Truthiness follows C: a float is true iff it compares != 0.0f, so
 * -0.0f is false. Testing the raw .u bits would treat -0.0f as true. */
Var f_boolor(Params p)
{
    const int lhs = isflt(p.a) ? p.a.f != 0.0f : p.a.u != 0;
    const int rhs = isflt(p.b) ? p.b.f != 0.0f : p.b.u != 0;
    return vari(lhs || rhs);
}
/* Ternary 'a ? b : c'. Returns b when a is true, else c.
 * The result is a float if either branch is a float, and an integer
 * otherwise (the chosen branch is converted with asflt/asint).
 * The condition uses C truthiness: a float condition is true iff it is
 * != 0.0f, so -0.0f selects c. Reading its raw .u bits would not. */
Var f_cond(Params p)
{
    const int cond = isflt(p.a) ? p.a.f != 0.0f : p.a.u != 0;
    const Var chosen = cond ? p.b : p.c;
    return isflt2(p.b, p.c) ? varf(asflt(chosen)) : vari(asint(chosen));
}
/* sin(x): single-precision sine of the operand, read as a float. */
Var f_sin(Params p)
{
    const F32 x = asflt(p.a);
    return varf(sinf(x));
}
/* cos(x): single-precision cosine of the operand, read as a float. */
Var f_cos(Params p)
{
    const F32 x = asflt(p.a);
    return varf(cosf(x));
}
OpDesc ops[] = {
{ M_VALUE, 0, "" },
{ M_VALUE, 0, "t", f_var },
{ M_FUNC, 0, "sin", f_sin },
{ M_FUNC, 0, "cos", f_cos },
{ M_ARRAY, 0, "[" },
{ M_INDEX, 1, "[" },
{ M_PREFIX, 2, "+", f_pos },
{ M_PREFIX, 2, "-", f_neg },
{ M_PREFIX, 2, "~", f_not },
{ M_PREFIX, 2, "!", f_boolnot },
{ M_PREFIX, 2, "(U32)", f_casti },
{ M_PREFIX, 2, "(int)", f_casti },
{ M_PREFIX, 2, "(unsigned)", f_casti },
{ M_PREFIX, 2, "(unsigned int)", f_casti },
{ M_PREFIX, 2, "(F32)", f_castf },
{ M_PREFIX, 2, "(float)", f_castf },
{ M_PREFIX, 2, "(double)", f_castf },
{ M_BINARY, 3, "*", f_mul },
{ M_BINARY, 3, "/", f_div },
{ M_BINARY, 3, "%", f_mod },
{ M_BINARY, 4, "+", f_add },
{ M_BINARY, 4, "-", f_sub },
{ M_BINARY, 5, "<<", f_shl },
{ M_BINARY, 5, ">>", f_shr },
{ M_BINARY, 6, "<", f_lt },
{ M_BINARY, 6, ">", f_gt },
{ M_BINARY, 6, "<=", f_leq },
{ M_BINARY, 6, ">=", f_geq },
{ M_BINARY, 7, "==", f_eq },
{ M_BINARY, 7, "!=", f_neq },
{ M_BINARY, 8, "&", f_and },
{ M_BINARY, 9, "^", f_xor },
{ M_BINARY, 10, "|", f_or },
{ M_BINARY, 11, "&&", f_booland },
{ M_BINARY, 12, "||", f_boolor },
{ M_TERNARY, 13, "?", f_cond },
};
/* Number of entries in ops[]. */
int opsCnt = (int)(sizeof ops / sizeof ops[0]);