import { Decimal } from "decimal.js";
import { Class } from "utility-types";

// Module-level decimal.js configuration: results of decimal.js operations are
// rounded to 28 significant digits.
// NOTE(review): Decimal.set() reconfigures the imported Decimal constructor
// itself, so this presumably also affects any other code in the process that
// uses decimal.js — consider a Decimal.clone() if that is not intended.
Decimal.set({ precision: 28 });

/** Base class for formula-specific errors. */
class FormulaException extends Error {
    // Reported by toString() and stack traces instead of the generic "Error".
    name = "FormulaException";
}
/** Raised when a serialized formula is malformed or names an unknown function. */
class ParseException extends FormulaException {
    name = "ParseException";
}

/** Variable bindings handed to `Formula.evaluate`, keyed by variable name. */
type Params = Record<string, Decimal>;

/**
 * Base class of a parsed formula tree.
 *
 * Serialized (JSON) form:
 * - a string matching `IDENT_RE` is a variable reference;
 * - any other string is a decimal literal, parsed by decimal.js;
 * - an array `[name, arg1, ...]` calls the built-in function `name` with one
 *   or more arguments, each of which is itself a serialized formula.
 *
 * Strings such as "NaN" or "Infinity" match `IDENT_RE`, so they are parsed as
 * variable names rather than as numeric literals.
 */
export class Formula {
    /** Identifier syntax for variable names. */
    static IDENT_RE = /^[a-zA-Z_][a-zA-Z0-9_]*$/;

    /**
     * Parses a JSON-encoded formula.
     *
     * @throws SyntaxError if `o` is not valid JSON (propagated from JSON.parse).
     * @throws ParseException if the decoded value is not a valid formula.
     */
    static parseBlob(o: string) {
        return Formula.parse(JSON.parse(o));
    }

    /** Alias of `parseBlob`. */
    static fromSerialized(o: string) {
        return Formula.parseBlob(o);
    }

    /**
     * Builds a formula tree from an already-decoded JSON value.
     *
     * @param o A string, or a (possibly nested) array `[name, ...args]`.
     *   Typed `unknown` because it normally comes straight from JSON.parse
     *   and is validated here; nested calls are not `Array<string>`.
     * @throws ParseException for malformed input, an unknown function name,
     *   or a call with fewer arguments than the function's `arity`.
     */
    static parse(o: unknown): Formula {
        if (Array.isArray(o)) {
            return Formula.parseCall(o);
        }
        if (typeof o !== "string") {
            // String() instead of .toString(): a JSON null has no methods.
            throw new ParseException(`Unexpected JSON object ${String(o)}`);
        }
        if (Formula.IDENT_RE.test(o)) {
            return new VariableLookup(o);
        }
        let value: Decimal;
        try {
            value = new Decimal(o);
        } catch (err: unknown) {
            throw new ParseException(String(err));
        }
        return new Constant(value);
    }

    /** Parses a call expression `[name, ...args]` (see `parse`). */
    private static parseCall(o: Array<unknown>): Formula {
        if (o.length < 2) {
            throw new ParseException(
                `Function call ${String(o)} must take at least 1 argument`
            );
        }
        const [name, ...rawArgs] = o;
        if (typeof name !== "string") {
            throw new ParseException(
                `Expected string function name, got ${String(name)}`
            );
        }
        // Own-property check: a plain FUNCTIONS[name] lookup would also find
        // inherited members such as "constructor" or "toString".
        if (!Object.prototype.hasOwnProperty.call(FUNCTIONS, name)) {
            throw new ParseException(`Unknown function ${name}`);
        }
        const args: Array<Formula> = rawArgs.map((arg) => Formula.parse(arg));
        const call = new FUNCTIONS[name](args);
        // Subclasses declare `arity` via a field initializer, which only runs
        // after the FunctionCall constructor has returned, so the minimum
        // argument count can only be checked on the constructed instance.
        if (call.arity !== undefined && args.length < call.arity) {
            throw new ParseException(
                `Expected ${call.arity} args, got ${args.length}`
            );
        }
        return call;
    }

    /**
     * Evaluates the formula against the given variable values.
     * Abstract: throws unless overridden by a subclass.
     */
    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    evaluate(params: Params): Decimal {
        throw new Error("abstract method");
    }

    /** Names of the variables this formula reads; none by default. */
    getVars(): Set<string> {
        return new Set();
    }
}

/** A reference to a named input, resolved against `params` at evaluation time. */
class VariableLookup extends Formula {
    name: string;

    constructor(name: string) {
        super();
        this.name = name;
    }

    /**
     * Returns the value bound to this variable.
     *
     * @throws FormulaException if `params` has no value for the variable;
     *   without this check `undefined` would be returned (or fed into later
     *   Decimal arithmetic) despite the `Decimal` return type.
     */
    evaluate(params: Params): Decimal {
        const value = params[this.name];
        if (value === undefined) {
            throw new FormulaException(`Missing value for variable ${this.name}`);
        }
        return value;
    }

    /** The single variable this node reads. */
    getVars(): Set<string> {
        return new Set([this.name]);
    }
}

/** A literal decimal value; evaluation does not consult any parameters. */
class Constant extends Formula {
    /** The literal, returned unchanged by every evaluation. */
    value: Decimal;

    constructor(literal: Decimal) {
        super();
        this.value = literal;
    }

    /** Returns the stored literal. */
    evaluate(): Decimal {
        const { value } = this;
        return value;
    }
}

/**
 * Base class for built-in function calls: holds the parsed argument formulas
 * and reports the union of their variables.
 */
class FunctionCall extends Formula {
    args: Array<Formula>;

    /**
     * Minimum number of arguments the function requires, or `undefined` when
     * there is no minimum beyond the parser's "at least 1 argument" rule.
     *
     * Subclasses set this with a field initializer (e.g. `arity = 2`). Such
     * initializers run only after this base constructor has returned, so the
     * value is NOT visible inside `constructor` below; arity must be validated
     * on the fully constructed instance instead.
     */
    arity?: number | undefined = undefined;

    constructor(args: Array<Formula>) {
        super();
        this.args = args;
    }

    /** Union of the variables referenced by every argument. */
    getVars(): Set<string> {
        const result = new Set<string>();
        for (const arg of this.args) {
            for (const name of arg.getVars()) {
                result.add(name);
            }
        }
        return result;
    }
}

/**
 * A call whose arguments are all evaluated, left to right, before the
 * function body runs; subclasses implement `evalStrict` on those values.
 */
class StrictFunctionCall extends FunctionCall {
    evaluate(params: Params): Decimal {
        const evaluated: Array<Decimal> = [];
        for (const arg of this.args) {
            evaluated.push(arg.evaluate(params));
        }
        return this.evalStrict(evaluated);
    }

    /** Computes the result from already-evaluated argument values (abstract). */
    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    evalStrict(values: Array<Decimal>): Decimal {
        throw new Error("Abstract method");
    }
}

/**
 * Binary operator implemented by calling the decimal.js instance method named
 * by `opName` on the first value with the second as its operand. Any further
 * arguments are evaluated (by StrictFunctionCall) but otherwise ignored.
 */
class DecimalOpCall extends StrictFunctionCall {
    opName!: "sub" | "div" | "mod" | "pow";

    arity = 2;

    evalStrict(values: Array<Decimal>): Decimal {
        const [lhs, rhs] = values;
        const method = this.opName;
        return lhs[method](rhs);
    }
}

/**
 * Left fold of the evaluated arguments with the binary `op`,
 * i.e. `op(op(v0, v1), v2) ...`.
 */
class AssociativeFunctionCall extends StrictFunctionCall {
    evalStrict(values: Array<Decimal>): Decimal {
        return values.reduce((acc, next) => this.op(acc, next));
    }

    /** Combines two values (abstract). */
    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    op(lhs: Decimal, rhs: Decimal): Decimal {
        throw new Error("Abstract method");
    }
}

/**
 * Chained comparison: yields 1 when every adjacent pair of arguments satisfies
 * the decimal.js comparison `opName` (so `a < b < c`), otherwise 0. Arguments
 * are evaluated left to right, stopping at the first failing pair.
 */
class ComparisonCall extends FunctionCall {
    opName!: "lt" | "lte" | "gt" | "gte" | "eq";

    evaluate(params: Params): Decimal {
        const { args, opName } = this;
        let previous = args[0].evaluate(params);
        for (let i = 1; i < args.length; i++) {
            const current = args[i].evaluate(params);
            if (!previous[opName](current)) {
                return new Decimal(0);
            }
            previous = current;
        }
        return new Decimal(1);
    }
}

/**
 * Rounds the first argument to a multiple of the second: `n / roundTo` is
 * rounded to an integer and multiplied back by `roundTo`. The rounding mode
 * comes from `Context`, a configured Decimal constructor supplied by each
 * concrete subclass. The result is converted back to the module's Decimal.
 */
class RoundingCall extends StrictFunctionCall {
    arity = 2;

    Context!: Class<Decimal>;

    evalStrict([n, roundTo]: Array<Decimal>): Decimal {
        // The division is done inside Context, not with the global Decimal.
        // The quotient gets rounded to the configured precision, and doing that
        // in the global (half-up) mode can push it onto the next integer before
        // the directed rounding runs. Example: for round_down of
        // 0.9999999999999999999999999999 to 0.05, the quotient
        // 19.999999999999999999999999998 has 29 significant digits. Half-up
        // rounding turns it into 20, which would give 1.00 (> n). Rounding
        // towards -Infinity keeps 19.99...9 and yields the correct 0.95.
        // NOTE(review): with ROUND_HALF_EVEN the quotient is still rounded
        // twice (to the precision, then to an integer), so quotients within one
        // unit in the last place of a .5 tie can round the wrong way.
        // decimal.js toNearest() rounds exactly but, presumably, computes every
        // integer digit of the quotient, which is unbounded for huge n / tiny
        // roundTo.
        const steps = new this.Context(n).div(roundTo).round();
        return new Decimal(steps.mul(roundTo));
    }
}

/**
 * Built-in functions, keyed by the name used in the serialized form
 * `["name", ...args]`.
 */
const FUNCTIONS: {
    [key: string]: Class<FunctionCall>;
} = {
    /** Sum of all arguments. */
    "+": class extends AssociativeFunctionCall {
        op(lhs: Decimal, rhs: Decimal) {
            return lhs.add(rhs);
        }
    },
    /** `a - b`. */
    "-": class extends DecimalOpCall {
        opName = "sub" as const;
    },
    /** Product of all arguments. */
    "*": class extends AssociativeFunctionCall {
        op(lhs: Decimal, rhs: Decimal) {
            return lhs.mul(rhs);
        }
    },
    // todo(ben) catch zeroes here
    /** `a / b`. */
    "/": class extends DecimalOpCall {
        opName = "div" as const;
    },
    /** `a % b` (decimal.js `mod`). */
    "%": class extends DecimalOpCall {
        opName = "mod" as const;
    },
    /** `a ^ b` (decimal.js `pow`). */
    "^": class extends DecimalOpCall {
        opName = "pow" as const;
    },
    /** Smallest argument. */
    min: class extends StrictFunctionCall {
        evalStrict(values: Array<Decimal>) {
            return Decimal.min(...values);
        }
    },
    /** Largest argument. */
    max: class extends StrictFunctionCall {
        evalStrict(values: Array<Decimal>) {
            return Decimal.max(...values);
        }
    },
    /**
     * `["if", cond, then, else]`: evaluates `then` when `cond` is non-zero,
     * otherwise `else`; only the chosen branch is evaluated.
     */
    if: class extends FunctionCall {
        arity = 3;

        evaluate(params: Params): Decimal {
            const cond = this.args[0].evaluate(params);
            if (cond.equals(0)) {
                return this.args[2].evaluate(params);
            }
            return this.args[1].evaluate(params);
        }
    },
    /**
     * Short-circuit AND: 0 at the first zero argument (later arguments are not
     * evaluated), otherwise the value of the last argument.
     */
    and: class extends FunctionCall {
        evaluate(params: Params): Decimal {
            let result = new Decimal(1);
            for (const arg of this.args) {
                result = arg.evaluate(params);
                if (result.equals(0)) {
                    return new Decimal(0);
                }
            }
            return result;
        }
    },
    /** Short-circuit OR: the first non-zero argument's value, otherwise 0. */
    or: class extends FunctionCall {
        evaluate(params: Params): Decimal {
            for (const arg of this.args) {
                const result = arg.evaluate(params);
                if (!result.equals(0)) {
                    return result;
                }
            }
            return new Decimal(0);
        }
    },
    /** Chained `<`, returning 1 or 0. */
    "<": class extends ComparisonCall {
        opName = "lt" as const;
    },
    /** Chained `<=`, returning 1 or 0. */
    "<=": class extends ComparisonCall {
        opName = "lte" as const;
    },
    /** Chained `>`, returning 1 or 0. */
    ">": class extends ComparisonCall {
        opName = "gt" as const;
    },
    /** Chained `>=`, returning 1 or 0. */
    ">=": class extends ComparisonCall {
        opName = "gte" as const;
    },
    /** Chained equality, returning 1 or 0. */
    "=": class extends ComparisonCall {
        opName = "eq" as const;
    },
    /** `[n, step]`: multiple of step, quotient rounded towards +Infinity (ceiling). */
    round_up: class extends RoundingCall {
        Context = Decimal.clone({ rounding: Decimal.ROUND_CEIL });
    },
    /** `[n, step]`: multiple of step, quotient rounded towards -Infinity (floor). */
    round_down: class extends RoundingCall {
        Context = Decimal.clone({ rounding: Decimal.ROUND_FLOOR });
    },
    /** `[n, step]`: nearest multiple of step, ties to even. */
    round_even: class extends RoundingCall {
        Context = Decimal.clone({ rounding: Decimal.ROUND_HALF_EVEN });
    },
};

// Drop Object.prototype so that a plain FUNCTIONS[name] lookup cannot mistake
// inherited members such as "constructor", "toString" or "__proto__" for
// built-in functions.
Object.setPrototypeOf(FUNCTIONS, null);
