From 20b3dcb2e97d8188df44e7801e19f7a8b200f6cf Mon Sep 17 00:00:00 2001
From: Ariya Hidayat <ariya@metabase.com>
Date: Fri, 3 Dec 2021 15:18:48 -0800
Subject: [PATCH] Custom expression: improve the type-checker (#19187)

Use an improved semantic validation layered on top of the existing
resolver (around ~40 LoC) instead of its own dedicated logic
(~240 LoC).  Not only the improved type-checker is leaner and faster,
it's also more accurate in some corner cases.
The type-checking process is also deferred to the compilation stage,
instead of earlier in the parsing stage.
---
 .../src/metabase/lib/expressions/parser.js    |   3 -
 .../src/metabase/lib/expressions/process.js   |   6 +-
 .../src/metabase/lib/expressions/resolver.js  |  70 ++++-
 .../metabase/lib/expressions/typechecker.js   | 243 ------------------
 .../lib/expressions/compile.unit.spec.js      | 181 -------------
 .../lib/expressions/diagnostics.unit.spec.js  |   2 +-
 .../lib/expressions/parser.unit.spec.js       |  15 --
 .../lib/expressions/resolver.unit.spec.js     | 102 +++++++-
 .../lib/expressions/typechecker.unit.spec.js  | 238 +----------------
 .../expressions/typeinferencer.unit.spec.js   |   4 +-
 ...s-percentile-accepts-two-params.cy.spec.js |   5 +-
 11 files changed, 175 insertions(+), 694 deletions(-)
 delete mode 100644 frontend/src/metabase/lib/expressions/typechecker.js
 delete mode 100644 frontend/test/metabase/lib/expressions/compile.unit.spec.js
 rename frontend/test/metabase/scenarios/{custom-column => question}/reproductions/15714-cc-postgres-percentile-accepts-two-params.cy.spec.js (84%)

diff --git a/frontend/src/metabase/lib/expressions/parser.js b/frontend/src/metabase/lib/expressions/parser.js
index 074e7e16a39..be608e24d3b 100644
--- a/frontend/src/metabase/lib/expressions/parser.js
+++ b/frontend/src/metabase/lib/expressions/parser.js
@@ -26,7 +26,6 @@ import {
 } from "./lexer";
 
 import { isExpressionType, getFunctionArgType } from ".";
-import { typeCheck } from "./typechecker";
 
 export class ExpressionParser extends CstParser {
   constructor(config = {}) {
@@ -443,7 +442,6 @@ export function parse({
   if (parserErrors.length > 0 && !recover) {
     throw parserErrors;
   }
-  const { typeErrors } = typeCheck(cst, startRule || "expression");
   const parserRecovered = !!(cst && parserErrors.length > 0);
 
   return {
@@ -453,7 +451,6 @@ export function parse({
     parserRecovered,
     parserErrors,
     lexerErrors,
-    typeErrors,
   };
 }
 
diff --git a/frontend/src/metabase/lib/expressions/process.js b/frontend/src/metabase/lib/expressions/process.js
index b8cb9210b8d..358792cd60d 100644
--- a/frontend/src/metabase/lib/expressions/process.js
+++ b/frontend/src/metabase/lib/expressions/process.js
@@ -11,14 +11,14 @@ export function processSource(options) {
   let compileError;
 
   // PARSE
-  const { cst, tokenVector, parserErrors, typeErrors } = parse({
+  const { cst, tokenVector, parserErrors } = parse({
     ...options,
     recover: true,
   });
 
   // COMPILE
-  if (typeErrors.length > 0 || parserErrors.length > 0) {
-    compileError = typeErrors.concat(parserErrors);
+  if (parserErrors.length > 0) {
+    compileError = parserErrors;
   } else {
     try {
       expression = compile({ cst, tokenVector, ...options });
diff --git a/frontend/src/metabase/lib/expressions/resolver.js b/frontend/src/metabase/lib/expressions/resolver.js
index bb0dabdc6b9..47bc0d0d320 100644
--- a/frontend/src/metabase/lib/expressions/resolver.js
+++ b/frontend/src/metabase/lib/expressions/resolver.js
@@ -1,3 +1,5 @@
+import { ngettext, msgid, t } from "ttag";
+
 import { OPERATOR as OP } from "./tokenizer";
 import { MBQL_CLAUSES } from "./index";
 
@@ -35,7 +37,20 @@ function findMBQL(op) {
   return clause;
 }
 
-export function resolve(expression, type, fn) {
+const isCompatible = (a, b) => {
+  if (a === b) {
+    return true;
+  }
+  if (a === "expression" && (b === "number" || b === "string")) {
+    return true;
+  }
+  if (a === "aggregation" && b === "number") {
+    return true;
+  }
+  return false;
+};
+
+export function resolve(expression, type = "expression", fn = undefined) {
   if (Array.isArray(expression)) {
     const [op, ...operands] = expression;
 
@@ -49,13 +64,20 @@ export function resolve(expression, type, fn) {
     if (LOGICAL_OPS.includes(op)) {
       operandType = "boolean";
     } else if (NUMBER_OPS.includes(op) || op === "coalesce") {
-      operandType = type;
+      operandType = type === "aggregation" ? type : "number";
     } else if (COMPARISON_OPS.includes(op)) {
       operandType = "expression";
+      const [firstOperand] = operands;
+      if (typeof firstOperand !== "undefined" && !Array.isArray(firstOperand)) {
+        throw new Error(t`Expecting field but found ${firstOperand}`);
+      }
     } else if (op === "concat") {
       operandType = "string";
     } else if (op === "case") {
       const [pairs, options] = operands;
+      if (pairs.length < 1) {
+        throw new Error(t`CASE expects 2 arguments or more`);
+      }
 
       const resolvedPairs = pairs.map(([tst, val]) => [
         resolve(tst, "boolean", fn),
@@ -80,13 +102,45 @@ export function resolve(expression, type, fn) {
     }
 
     const clause = findMBQL(op);
-    if (clause) {
-      const { args } = clause;
-      return [
-        op,
-        ...operands.map((operand, i) => resolve(operand, args[i], fn)),
-      ];
+    if (!clause) {
+      throw new Error(t`Unknown function ${op}`);
+    }
+    const { displayName, args, multiple, hasOptions } = clause;
+    if (!isCompatible(type, clause.type)) {
+      throw new Error(
+        t`Expecting ${type} but found function ${displayName} returning ${clause.type}`,
+      );
     }
+    if (!multiple) {
+      const expectedArgsLength = args.length;
+      const maxArgCount = hasOptions
+        ? expectedArgsLength + 1
+        : expectedArgsLength;
+      if (
+        operands.length < expectedArgsLength ||
+        operands.length > maxArgCount
+      ) {
+        throw new Error(
+          ngettext(
+            msgid`Function ${displayName} expects ${expectedArgsLength} argument`,
+            `Function ${displayName} expects ${expectedArgsLength} arguments`,
+            expectedArgsLength,
+          ),
+        );
+      }
+    }
+    const resolvedOperands = operands.map((operand, i) => {
+      if (i >= args.length) {
+        // as-is, optional object for e.g. ends-with, time-interval, etc
+        return operand;
+      }
+      return resolve(operand, args[i], fn);
+    });
+    return [op, ...resolvedOperands];
+  } else if (!isCompatible(type, typeof expression)) {
+    throw new Error(
+      t`Expecting ${type} but found ${JSON.stringify(expression)}`,
+    );
   }
   return expression;
 }
diff --git a/frontend/src/metabase/lib/expressions/typechecker.js b/frontend/src/metabase/lib/expressions/typechecker.js
deleted file mode 100644
index 82ce6632100..00000000000
--- a/frontend/src/metabase/lib/expressions/typechecker.js
+++ /dev/null
@@ -1,243 +0,0 @@
-import { getIn } from "icepick";
-import { ngettext, msgid, t } from "ttag";
-import { ExpressionVisitor } from "./visitor";
-import { CLAUSE_TOKENS } from "./lexer";
-
-import { MBQL_CLAUSES, getMBQLName } from "./config";
-
-export function typeCheck(cst, rootType) {
-  class TypeChecker extends ExpressionVisitor {
-    constructor() {
-      super();
-      this.typeStack = [rootType];
-      this.errors = [];
-    }
-
-    logicalOrExpression(ctx) {
-      this.typeStack.unshift("boolean");
-      const result = super.logicalOrExpression(ctx);
-      this.typeStack.shift();
-      return result;
-    }
-    logicalAndExpression(ctx) {
-      this.typeStack.unshift("boolean");
-      const result = super.logicalAndExpression(ctx);
-      this.typeStack.shift();
-      return result;
-    }
-    logicalNotExpression(ctx) {
-      this.typeStack.unshift("boolean");
-      const result = super.logicalNotExpression(ctx);
-      this.typeStack.shift();
-      return result;
-    }
-    relationalExpression(ctx) {
-      this.typeStack.unshift("number");
-      const result = super.relationalExpression(ctx);
-      this.typeStack.shift();
-
-      // backward-compatibility: literal on the left-hand side isn't allowed (MBQL limitation)
-      if (ctx.operands.length > 1) {
-        const lhs = ctx.operands[0];
-        if (lhs.name === "numberLiteral") {
-          const literal = getIn(lhs, ["children", "NumberLiteral", 0, "image"]);
-          const message = t`Expecting field but found ${literal}`;
-          this.errors.push({ message });
-        }
-      }
-      return result;
-    }
-
-    caseExpression(ctx) {
-      const type = this.typeStack[0];
-      const args = ctx.arguments || [];
-      if (args.length < 2) {
-        this.errors.push({ message: t`CASE expects 2 arguments or more` });
-        return [];
-      }
-      return args.map((arg, index) => {
-        // argument 0, 2, 4, ...(even) is always a boolean, ...
-        const argType = index & 1 ? type : "boolean";
-        // ... except the very last one
-        const lastArg = index === args.length - 1;
-        this.typeStack.unshift(lastArg ? type : argType);
-        const result = this.visit(arg);
-        this.typeStack.shift();
-        return result;
-      });
-    }
-
-    functionExpression(ctx) {
-      const args = ctx.arguments || [];
-      const functionToken = ctx.functionName[0].tokenType;
-      const clause = CLAUSE_TOKENS.get(functionToken);
-      const name = functionToken.name;
-
-      // check for return value sub-type mismatch
-      const type = this.typeStack[0];
-      if (type === "number") {
-        const op = getMBQLName(name);
-        const returnType = MBQL_CLAUSES[op].type;
-        if (
-          returnType !== "number" &&
-          returnType !== "string" &&
-          returnType !== "expression"
-        ) {
-          const message = t`Expecting ${type} but found function ${name} returning ${returnType}`;
-          this.errors.push({ message });
-        }
-      }
-
-      if (!clause.multiple) {
-        const expectedArgsLength = clause.args.length;
-        const maxArgCount = clause.hasOptions
-          ? expectedArgsLength + 1
-          : expectedArgsLength;
-        if (args.length < expectedArgsLength || args.length > maxArgCount) {
-          const message = ngettext(
-            msgid`Function ${name} expects ${expectedArgsLength} argument`,
-            `Function ${name} expects ${expectedArgsLength} arguments`,
-            expectedArgsLength,
-          );
-          this.errors.push({ message });
-        }
-
-        // check for argument type matching
-        return args.map((arg, index) => {
-          const argType = clause.args[index];
-          const genericType =
-            argType === "number" || argType === "string"
-              ? "expression"
-              : argType;
-          this.typeStack.unshift(genericType);
-          const result = this.visit(arg);
-          this.typeStack.shift();
-          return result;
-        });
-      }
-    }
-
-    identifierExpression(ctx) {
-      const type = this.typeStack[0];
-      if (type === "aggregation") {
-        ctx.resolveAs = "metric";
-      } else if (type === "boolean") {
-        ctx.resolveAs = "segment";
-      } else {
-        ctx.resolveAs = "dimension";
-        if (type === "aggregation") {
-          throw new Error("Incorrect type for dimension");
-        }
-      }
-      return super.identifierExpression(ctx);
-    }
-
-    numberLiteral(ctx) {
-      const type = this.typeStack[0];
-      if (type === "boolean") {
-        const literal = getIn(ctx, ["NumberLiteral", 0, "image"]);
-        const message = t`Expecting boolean but found ${literal}`;
-        this.errors.push({ message });
-      }
-    }
-
-    stringLiteral(ctx) {
-      const type = this.typeStack[0];
-      if (type === "boolean") {
-        const literal = getIn(ctx, ["StringLiteral", 0, "image"]);
-        const message = t`Expecting boolean but found ${literal}`;
-        this.errors.push({ message });
-      }
-    }
-  }
-  const checker = new TypeChecker();
-  const compactCst = compactSyntaxTree(cst);
-  checker.visit(compactCst);
-  return { typeErrors: checker.errors };
-}
-
-/*
-
-  Create a copy of the syntax tree where the unnecessary intermediate nodes
-  are not present anymore.
-
-  Example:
-  For a simple expression "42", the syntax tree produced by the parser is
-
-  expression <-- this is the root node
-    relationalExpression
-      additionExpression
-        multiplicationExpression
-          atomicExpression
-            numberLiteral
-
-  Meanwhile, the compact variant of the syntax tree:
-
-    numberLiteral
-
-*/
-
-export function compactSyntaxTree(node) {
-  if (!node) {
-    return;
-  }
-  const { name, children } = node;
-
-  switch (name) {
-    case "any":
-    case "aggregation":
-    case "atomicExpression":
-    case "boolean":
-    case "booleanExpression":
-    case "booleanUnaryExpression":
-    case "expression":
-    case "parenthesisExpression":
-    case "string":
-      if (children.expression) {
-        const expression = children.expression.map(compactSyntaxTree);
-        return expression.length === 1
-          ? expression[0]
-          : { name, children: { expression: expression } };
-      }
-      break;
-
-    case "logicalNotExpression":
-      if (children.operands) {
-        const operands = children.operands.map(compactSyntaxTree);
-        const operators = children.operators;
-        return { name, children: { operators, operands } };
-      }
-      break;
-
-    case "additionExpression":
-    case "multiplicationExpression":
-    case "logicalAndExpression":
-    case "logicalOrExpression":
-    case "relationalExpression":
-      if (children.operands) {
-        const operands = children.operands.map(compactSyntaxTree);
-        const operators = children.operators;
-        return operands.length === 1
-          ? operands[0]
-          : { name, children: { operators, operands } };
-      }
-      break;
-
-    case "functionExpression":
-    case "caseExpression": {
-      const { functionName, LParen, RParen } = children;
-      const args = children.arguments
-        ? children.arguments.map(compactSyntaxTree)
-        : [];
-      return {
-        name,
-        children: { functionName, arguments: args, LParen, RParen },
-      };
-    }
-
-    default:
-      break;
-  }
-
-  return { name, children };
-}
diff --git a/frontend/test/metabase/lib/expressions/compile.unit.spec.js b/frontend/test/metabase/lib/expressions/compile.unit.spec.js
deleted file mode 100644
index 4de90fb4125..00000000000
--- a/frontend/test/metabase/lib/expressions/compile.unit.spec.js
+++ /dev/null
@@ -1,181 +0,0 @@
-// import { compile } from "metabase/lib/expressions/compile";
-
-import {
-  shared,
-  aggregationOpts,
-  expressionOpts,
-} from "./__support__/expressions";
-
-const ENABLE_PERF_TESTS = false; //!process.env["CI"];
-
-function expectFast(fn, milliseconds = 1000) {
-  const start = Date.now();
-  fn();
-  const end = Date.now();
-  if (ENABLE_PERF_TESTS) {
-    expect(end - start).toBeLessThan(milliseconds);
-  }
-}
-
-describe("metabase/lib/expressions/compile", () => {
-  let compile;
-  it("should load compile quickly", () => {
-    expectFast(() => {
-      ({ compile } = require("metabase/lib/expressions/compile"));
-    });
-  });
-
-  describe("compile()", () => {
-    for (const [name, cases, opts] of shared) {
-      describe(name, () => {
-        for (const [source, mbql, description] of cases) {
-          if (mbql) {
-            it(`should compile ${description}`, () => {
-              expectFast(() => {
-                expect(compile({ source, ...opts })).toEqual(mbql);
-              }, 250);
-            });
-          } else {
-            it(`should not compile ${description}`, () => {
-              expectFast(() => {
-                expect(() => compile({ source, ...opts })).toThrow();
-              }, 250);
-            });
-          }
-        }
-      });
-    }
-
-    // NOTE: only add tests below for things that don't fit the shared test cases above
-
-    it("should throw exception on invalid input", () => {
-      expect(() => compile({ source: "1 + ", ...expressionOpts })).toThrow();
-    });
-
-    it("should treat aggregations as case-insensitive", () => {
-      expect(compile({ source: "count", ...aggregationOpts })).toEqual([
-        "count",
-      ]);
-      expect(compile({ source: "cOuNt", ...aggregationOpts })).toEqual([
-        "count",
-      ]);
-      expect(compile({ source: "average(A)", ...aggregationOpts })).toEqual([
-        "avg",
-        ["field", 1, null],
-      ]);
-    });
-
-    it("should not take a long time to parse long string literals", () => {
-      expectFast(() => {
-        try {
-          compile({
-            source: '"12345678901234567901234567890',
-            ...expressionOpts,
-          });
-        } catch (e) {}
-      });
-    });
-  });
-
-  function mockResolve(kind, name) {
-    return [kind, name];
-  }
-  function compileSource(source, startRule) {
-    let mbql = null;
-    try {
-      mbql = compile({ source, startRule, resolve: mockResolve });
-    } catch (e) {
-      let err = e;
-      if (err.length && err.length > 0) {
-        err = err[0];
-        if (typeof err.message === "string") {
-          err = err.message;
-        }
-      }
-      throw err;
-    }
-    return mbql;
-  }
-
-  describe("(for an expression)", () => {
-    function expr(source) {
-      return compileSource(source, "expression");
-    }
-    it("should compile literals", () => {
-      expect(expr("42")).toEqual(42);
-      expect(expr("'Universe'")).toEqual("Universe");
-    });
-    it("should compile dimensions", () => {
-      expect(expr("[Price]")).toEqual(["dimension", "Price"]);
-      expect(expr("([X])")).toEqual(["dimension", "X"]);
-    });
-    it("should compile arithmetic operations", () => {
-      expect(expr("1+2")).toEqual(["+", 1, 2]);
-      expect(expr("3-4")).toEqual(["-", 3, 4]);
-      expect(expr("5*6")).toEqual(["*", 5, 6]);
-      expect(expr("7/8")).toEqual(["/", 7, 8]);
-    });
-    it("should compile comparisons", () => {
-      expect(expr("1<2")).toEqual(["<", 1, 2]);
-      expect(expr("3>4")).toEqual([">", 3, 4]);
-      expect(expr("5<=6")).toEqual(["<=", 5, 6]);
-      expect(expr("7>=8")).toEqual([">=", 7, 8]);
-      expect(expr("9=9")).toEqual(["=", 9, 9]);
-      expect(expr("9!=0")).toEqual(["!=", 9, 0]);
-    });
-    it("should handle parenthesized expression", () => {
-      expect(expr("(42)")).toEqual(42);
-      expect(expr("((43))")).toEqual(43);
-      expect(expr("('Universe')")).toEqual("Universe");
-      expect(expr("(('Answer'))")).toEqual("Answer");
-      expect(expr("(1+2)")).toEqual(["+", 1, 2]);
-      expect(expr("(1+2)/3")).toEqual(["/", ["+", 1, 2], 3]);
-      expect(expr("4-(5*6)")).toEqual(["-", 4, ["*", 5, 6]]);
-    });
-  });
-
-  describe("(for a filter)", () => {
-    function filter(source) {
-      return compileSource(source, "boolean");
-    }
-    it("should compile logical operations", () => {
-      expect(filter("NOT A")).toEqual(["not", ["segment", "A"]]);
-      expect(filter("NOT 0")).toEqual(["not", 0]);
-      expect(filter("NOT 'Answer'")).toEqual(["not", "Answer"]);
-      expect(filter("NOT NOT 0")).toEqual(["not", ["not", 0]]);
-      expect(filter("1 OR 2")).toEqual(["or", 1, 2]);
-      expect(filter("2 AND 3")).toEqual(["and", 2, 3]);
-      expect(filter("1 OR 2 AND 3")).toEqual(["or", 1, ["and", 2, 3]]);
-      expect(filter("NOT 4 OR 5")).toEqual(["or", ["not", 4], 5]);
-    });
-    it("should compile comparisons", () => {
-      expect(filter("Tax>5")).toEqual([">", ["dimension", "Tax"], 5]);
-      expect(filter("X=0")).toEqual(["=", ["dimension", "X"], 0]);
-    });
-    it("should compile segments", () => {
-      expect(filter("[Expensive]")).toEqual(["segment", "Expensive"]);
-      expect(filter("NOT [Good]")).toEqual(["not", ["segment", "Good"]]);
-      expect(filter("NOT Answer")).toEqual(["not", ["segment", "Answer"]]);
-    });
-    it("should compile negative filters", () => {
-      expect(filter("NOT CONTAINS('X','Y')")).toEqual([
-        "does-not-contain",
-        "X",
-        "Y",
-      ]);
-      expect(filter("NOT ISNULL('P')")).toEqual(["not-null", "P"]);
-      expect(filter("NOT ISEMPTY('Q')")).toEqual(["not-empty", "Q"]);
-    });
-  });
-
-  describe("(for an aggregation)", () => {
-    function aggr(source) {
-      return compileSource(source, "aggregation");
-    }
-    it("should handle metric vs dimension vs segment", () => {
-      expect(aggr("[TotalOrder]")).toEqual(["metric", "TotalOrder"]);
-      expect(aggr("AVERAGE(X)")).toEqual(["avg", ["dimension", "X"]]);
-      expect(aggr("COUNTIF(Y)")).toEqual(["count-where", ["segment", "Y"]]);
-    });
-  });
-});
diff --git a/frontend/test/metabase/lib/expressions/diagnostics.unit.spec.js b/frontend/test/metabase/lib/expressions/diagnostics.unit.spec.js
index 699759122ae..a357f6191d8 100644
--- a/frontend/test/metabase/lib/expressions/diagnostics.unit.spec.js
+++ b/frontend/test/metabase/lib/expressions/diagnostics.unit.spec.js
@@ -48,7 +48,7 @@ describe("metabase/lib/expressions/diagnostics", () => {
   });
 
   it("should show the correct number of function arguments in a custom expression", () => {
-    expect(diagnose("contains([Category])").message).toEqual(
+    expect(diagnose("contains([Category])", "boolean").message).toEqual(
       "Function contains expects 2 arguments",
     );
   });
diff --git a/frontend/test/metabase/lib/expressions/parser.unit.spec.js b/frontend/test/metabase/lib/expressions/parser.unit.spec.js
index bfceb6583a5..68a213c7e00 100644
--- a/frontend/test/metabase/lib/expressions/parser.unit.spec.js
+++ b/frontend/test/metabase/lib/expressions/parser.unit.spec.js
@@ -5,9 +5,6 @@ describe("metabase/lib/expressions/parser", () => {
     let result = null;
     try {
       result = parse({ source, tokenVector: null, startRule });
-      if (result.typeErrors.length > 0) {
-        throw new Error(result.typeErrors);
-      }
     } catch (e) {
       let err = e;
       if (err.length && err.length > 0) {
@@ -74,12 +71,6 @@ describe("metabase/lib/expressions/parser", () => {
         parseExpression("case(isempty([Discount]),[P])"),
       ).not.toThrow();
     });
-    it("should reject CASE with only one argument", () => {
-      expect(() => parseExpression("case([Deal])")).toThrow();
-    });
-    it("should accept CASE with two arguments", () => {
-      expect(() => parseExpression("case([Deal],x)")).not.toThrow();
-    });
   });
 
   describe("(in aggregation mode)", () => {
@@ -149,11 +140,5 @@ describe("metabase/lib/expressions/parser", () => {
     it("should accept a function", () => {
       expect(() => parseFilter("between([Subtotal], 1, 2)")).not.toThrow();
     });
-    it("should reject CASE with only one argument", () => {
-      expect(() => parseFilter("case([Deal])")).toThrow();
-    });
-    it("should reject a number on the left-hand side", () => {
-      expect(() => parseFilter("10 < [DiscountPercent]")).toThrow();
-    });
   });
 });
diff --git a/frontend/test/metabase/lib/expressions/resolver.unit.spec.js b/frontend/test/metabase/lib/expressions/resolver.unit.spec.js
index f21e244d8e7..835fe83199e 100644
--- a/frontend/test/metabase/lib/expressions/resolver.unit.spec.js
+++ b/frontend/test/metabase/lib/expressions/resolver.unit.spec.js
@@ -32,6 +32,8 @@ describe("metabase/lib/expressions/resolve", () => {
   const Q = ["dimension", "Q"];
   const R = ["dimension", "R"];
   const S = ["dimension", "S"];
+  const X = ["segment", "X"];
+  const Y = ["dimension", "Y"];
 
   describe("for filters", () => {
     const filter = e => collect(e, "boolean");
@@ -44,8 +46,8 @@ describe("metabase/lib/expressions/resolve", () => {
       expect(filter(["and", ["<", Q, 1], R]).segments).toEqual(["R"]);
       expect(filter(["is-null", S]).segments).toEqual([]);
       expect(filter(["not-empty", S]).segments).toEqual([]);
-      expect(filter(["lower", A]).segments).toEqual([]);
-      expect(filter(["sqrt", B]).segments).toEqual([]);
+      expect(filter([">", ["lower", A], "X"]).segments).toEqual([]);
+      expect(filter(["<", ["sqrt", B], 1]).segments).toEqual([]);
       expect(filter(["contains", C, "SomeString"]).segments).toEqual([]);
       expect(filter(["or", P, [">", Q, 3]]).segments).toEqual(["P"]);
     });
@@ -58,11 +60,57 @@ describe("metabase/lib/expressions/resolve", () => {
       expect(filter(["and", ["<", Q, 1], R]).dimensions).toEqual(["Q"]);
       expect(filter(["is-null", Q]).dimensions).toEqual(["Q"]);
       expect(filter(["not-empty", S]).dimensions).toEqual(["S"]);
-      expect(filter(["lower", A]).dimensions).toEqual(["A"]);
-      expect(filter(["sqrt", B]).dimensions).toEqual(["B"]);
+      expect(filter([">", ["lower", A], "X"]).dimensions).toEqual(["A"]);
+      expect(filter(["<", ["sqrt", B], 1]).dimensions).toEqual(["B"]);
       expect(filter(["contains", C, "SomeString"]).dimensions).toEqual(["C"]);
       expect(filter(["or", P, [">", Q, 3]]).dimensions).toEqual(["Q"]);
     });
+
+    it("should reject a number literal", () => {
+      expect(() => filter("3.14159")).toThrow();
+    });
+
+    it("should reject a string literal", () => {
+      expect(() => filter('"TheAnswer"')).toThrow();
+    });
+
+    it("should catch mismatched number of function parameters", () => {
+      expect(() => filter(["contains"])).toThrow();
+      expect(() => filter(["contains", Y])).toThrow();
+      expect(() => filter(["contains", Y, "A", "B", "C"])).toThrow();
+      expect(() => filter(["starts-with"])).toThrow();
+      expect(() => filter(["starts-with", A])).toThrow();
+      expect(() => filter(["starts-with", A, "P", "Q", "R"])).toThrow();
+      expect(() => filter(["ends-with"])).toThrow();
+      expect(() => filter(["ends-with", B])).toThrow();
+      expect(() => filter(["ends-with", B, "P", "Q", "R"])).toThrow();
+    });
+
+    it("should allow a comparison (lexicographically) on strings", () => {
+      // P <= "abc"
+      expect(() => filter(["<=", P, "abc"])).not.toThrow();
+    });
+
+    it("should allow a comparison (lexicographically) on functions returning string", () => {
+      // Lower([A]) <= "P"
+      expect(() => filter(["<=", ["lower", A], "P"])).not.toThrow();
+    });
+
+    it("should reject a less/greater comparison on functions returning boolean", () => {
+      // IsEmpty([A]) < 0
+      expect(() => filter(["<", ["is-empty", A], 0])).toThrow();
+    });
+
+    // backward-compatibility
+    it("should reject a literal on the left-hand side of a comparison", () => {
+      // 0 < [A]
+      expect(() => filter(["<", 0, A])).toThrow();
+    });
+
+    it("should work on functions with optional flag", () => {
+      const flag = { "include-current": true };
+      expect(() => filter(["time-interval", A, 3, "day", flag])).not.toThrow();
+    });
   });
 
   describe("for expressions (for custom columns)", () => {
@@ -85,6 +133,12 @@ describe("metabase/lib/expressions/resolve", () => {
       expect(expr(["coalesce", P]).dimensions).toEqual(["P"]);
       expect(expr(["coalesce", P, Q, R]).dimensions).toEqual(["P", "Q", "R"]);
     });
+
+    it("should allow any number of arguments in a variadic function", () => {
+      expect(() => expr(["concat", "1"])).not.toThrow();
+      expect(() => expr(["concat", "1", "2"])).not.toThrow();
+      expect(() => expr(["concat", "1", "2", "3"])).not.toThrow();
+    });
   });
 
   describe("for aggregations", () => {
@@ -107,6 +161,11 @@ describe("metabase/lib/expressions/resolve", () => {
       expect(aggregation(["max", ["*", 4, Q]]).metrics).toEqual([]);
       expect(aggregation(["+", R, ["median", S]]).metrics).toEqual(["R"]);
     });
+
+    it("should accept PERCENTILE with two arguments", () => {
+      // PERCENTILE(A, 0.5)
+      expect(() => aggregation(["percentile", A, 0.5])).not.toThrow();
+    });
   });
 
   describe("for CASE expressions", () => {
@@ -153,16 +212,39 @@ describe("metabase/lib/expressions/resolve", () => {
       expect(expr(["coalesce", ["case", [[A, B]]]]).segments).toEqual(["A"]);
       expect(expr(["coalesce", ["case", [[A, B]]]]).dimensions).toEqual(["B"]);
     });
-  });
 
-  it("should handle unknown MBQL gracefully", () => {
-    expect(() => collect(["abc-xyz", B])).not.toThrow();
+    it("should reject a CASE expression with only one argument", () => {
+      // CASE(X)
+      expect(() => expr(["case", [], { default: Y }])).toThrow();
+    });
+    it("should reject a CASE expression with incorrect argument type", () => {
+      // CASE(X, 1, 2, 3)
+      expect(() =>
+        expr([
+          "case",
+          [
+            [X, 1],
+            [2, 3],
+          ],
+        ]),
+      ).toThrow();
+    });
+
+    it("should accept a CASE expression with complex arguments", () => {
+      // CASE(X, 0.5*Y, A-B)
+      const def = { default: ["-", A, B] };
+      expect(() => expr(["case", [[X, ["*", 0.5, Y]]], def])).not.toThrow();
+    });
   });
 
   it("should not fail on literal 0", () => {
     const opt = { default: 0 };
-    expect(resolve(["case", [[1, 0]]])).toEqual(["case", [[1, 0]]]);
-    expect(resolve(["case", [[1, 0]], opt])).toEqual(["case", [[1, 0]], opt]);
-    expect(resolve(["case", [[1, 2]], opt])).toEqual(["case", [[1, 2]], opt]);
+    expect(resolve(["case", [[X, 0]]])).toEqual(["case", [[X, 0]]]);
+    expect(resolve(["case", [[X, 0]], opt])).toEqual(["case", [[X, 0]], opt]);
+    expect(resolve(["case", [[X, 2]], opt])).toEqual(["case", [[X, 2]], opt]);
+  });
+
+  it("should reject unknown function", () => {
+    expect(() => resolve(["foobar", 42])).toThrow();
   });
 });
diff --git a/frontend/test/metabase/lib/expressions/typechecker.unit.spec.js b/frontend/test/metabase/lib/expressions/typechecker.unit.spec.js
index 9a3795f9115..b81c89008ce 100644
--- a/frontend/test/metabase/lib/expressions/typechecker.unit.spec.js
+++ b/frontend/test/metabase/lib/expressions/typechecker.unit.spec.js
@@ -1,7 +1,5 @@
 import { parse } from "metabase/lib/expressions/parser";
-import { ExpressionVisitor } from "metabase/lib/expressions/visitor";
-import { parseIdentifierString } from "metabase/lib/expressions/index";
-import { compactSyntaxTree } from "metabase/lib/expressions/typechecker";
+import { compile } from "metabase/lib/expressions/compile";
 
 // Since the type checking is inserted as the last stage in the expression parser,
 // the whole tests must continue to pass (i.e. none of them should thrown
@@ -9,89 +7,14 @@ import { compactSyntaxTree } from "metabase/lib/expressions/typechecker";
 
 describe("type-checker", () => {
   function parseSource(source, startRule) {
-    let cst = null;
-    let typeErrors = [];
-    try {
-      const result = parse({ source, tokenVector: null, startRule });
-      cst = result.cst;
-      typeErrors = result.typeErrors;
-    } catch (e) {
-      let err = e;
-      if (err.length && err.length > 0) {
-        err = err[0];
-        if (typeof err.message === "string") {
-          err = err.message;
-        }
-      }
-      throw err;
-    }
-    return { cst, typeErrors };
-  }
-
-  function collect(source, startRule) {
-    class Collector extends ExpressionVisitor {
-      constructor() {
-        super();
-        this.metrics = [];
-        this.segments = [];
-        this.dimensions = [];
-      }
-      identifier(ctx) {
-        return ctx.Identifier[0].image;
-      }
-      identifierString(ctx) {
-        return parseIdentifierString(ctx.IdentifierString[0].image);
-      }
-      identifierExpression(ctx) {
-        const name = this.visit(ctx.identifierName);
-        if (ctx.resolveAs === "metric") {
-          this.metrics.push(name);
-        } else if (ctx.resolveAs === "segment") {
-          this.segments.push(name);
-        } else {
-          this.dimensions.push(name);
-        }
-      }
-    }
-    const { cst } = parseSource(source, startRule);
-    const collector = new Collector();
-    collector.visit(cst);
-    return collector;
+    const mockResolve = (kind, name) => [kind, name];
+    const { cst } = parse({ source, tokenVector: null, startRule });
+    compile({ source, startRule, cst, resolve: mockResolve });
+    return { cst };
   }
 
   describe("for an expression", () => {
-    function expr(source) {
-      return collect(source, "expression");
-    }
-    function validate(source) {
-      const { typeErrors } = parseSource(source, "expression");
-      if (typeErrors.length > 0) {
-        throw new Error(typeErrors[0].message);
-      }
-    }
-
-    it("should resolve dimensions correctly", () => {
-      expect(expr("[Price]+[Tax]").dimensions).toEqual(["Price", "Tax"]);
-      expect(expr("ABS([Discount])").dimensions).toEqual(["Discount"]);
-      expect(expr("CASE([Deal],10,20)").dimensions).toEqual([]);
-    });
-
-    it("should resolve segments correctly", () => {
-      expect(expr("[Price]+[Tax]").segments).toEqual([]);
-      expect(expr("ABS([Discount])").segments).toEqual([]);
-      expect(expr("CASE([Deal],10,20)").segments).toEqual(["Deal"]);
-    });
-
-    it("should resolve dimensions and segments correctly", () => {
-      expect(expr("[X]+CASE([Y],4,5)").dimensions).toEqual(["X"]);
-      expect(expr("[X]+CASE([Y],4,5)").segments).toEqual(["Y"]);
-      expect(expr("CASE([Z]>100,'Pricey')").dimensions).toEqual(["Z"]);
-      expect(expr("CASE([Z]>100,'Pricey')").segments).toEqual([]);
-      expect(expr("CASE(A,B,C)").dimensions).toEqual(["B", "C"]);
-      expect(expr("CASE(A,B,C)").segments).toEqual(["A"]);
-      expect(expr("CASE(P,Q,R,S)").dimensions).toEqual(["Q", "S"]);
-      expect(expr("CASE(P,Q,R,S)").segments).toEqual(["P", "R"]);
-    });
+    const validate = e => parseSource(e, "expression");
 
     it("should allow any number of arguments in a variadic function", () => {
       expect(() => validate("CONCAT('1')")).not.toThrow();
@@ -110,6 +33,10 @@ describe("type-checker", () => {
     it("should accept a CASE expression with complex arguments", () => {
       expect(() => validate("CASE(Deal, 0.5*X, Y-Z)")).not.toThrow();
     });
+  });
+
+  describe("for an aggregation", () => {
+    const validate = e => parseSource(e, "aggregation");
 
     it("should accept PERCENTILE with two arguments", () => {
       expect(() => validate("PERCENTILE([Rating], .5)")).not.toThrow();
@@ -117,51 +44,7 @@ describe("type-checker", () => {
   });
 
   describe("for a filter", () => {
-    function filter(source) {
-      return collect(source, "boolean");
-    }
-    function validate(source) {
-      const { typeErrors } = parseSource(source, "boolean");
-      if (typeErrors.length > 0) {
-        throw new Error(typeErrors[0].message);
-      }
-    }
-    it("should resolve segments correctly", () => {
-      expect(filter("[Clearance]").segments).toEqual(["Clearance"]);
-      expect(filter("NOT [Deal]").segments).toEqual(["Deal"]);
-      expect(filter("NOT NOT [Deal]").segments).toEqual(["Deal"]);
-      expect(filter("P > 3").segments).toEqual([]);
-      expect(filter("R<1 AND [S]>4").segments).toEqual([]);
-      expect(filter("5 <= Q").segments).toEqual([]);
-      expect(filter("Between([BIG],3,7)").segments).toEqual([]);
-      expect(filter("Contains([GI],'Joe')").segments).toEqual([]);
-      expect(filter("IsEmpty([Discount])").segments).toEqual([]);
-      expect(filter("IsNull([Tax])").segments).toEqual([]);
-    });
-
-    it("should resolve dimensions correctly", () => {
-      expect(filter("[Clearance]").dimensions).toEqual([]);
-      expect(filter("NOT [Deal]").dimensions).toEqual([]);
-      expect(filter("NOT NOT [Deal]").dimensions).toEqual([]);
-      expect(filter("P > 3").dimensions).toEqual(["P"]);
-      expect(filter("R<1 AND [S]>4").dimensions).toEqual(["R", "S"]);
-      expect(filter("5 <= Q").dimensions).toEqual(["Q"]);
-      expect(filter("Between([BIG],3,7)").dimensions).toEqual(["BIG"]);
-      expect(filter("Contains([GI],'Joe')").dimensions).toEqual(["GI"]);
-      expect(filter("IsEmpty([Discount])").dimensions).toEqual(["Discount"]);
-      expect(filter("IsNull([Tax])").dimensions).toEqual(["Tax"]);
-    });
-
-    it("should resolve dimensions and segments correctly", () => {
-      expect(filter("[A] OR [B]>0").segments).toEqual(["A"]);
-      expect(filter("[A] OR [B]>0").dimensions).toEqual(["B"]);
-      expect(filter("[X]=4 AND NOT [Y]").segments).toEqual(["Y"]);
-      expect(filter("[X]=4 AND NOT [Y]").dimensions).toEqual(["X"]);
-      expect(filter("T OR Between([R],0,9)").segments).toEqual(["T"]);
-      expect(filter("T OR Between([R],0,9)").dimensions).toEqual(["R"]);
-      expect(filter("NOT between(P, 3, 14) OR Q").dimensions).toEqual(["P"]);
-      expect(filter("NOT between(P, 3, 14) OR Q").segments).toEqual(["Q"]);
-    });
+    const validate = e => parseSource(e, "boolean");
 
     it("should reject a number literal", () => {
       expect(() => validate("3.14159")).toThrow();
@@ -183,113 +66,16 @@ describe("type-checker", () => {
     });
 
     it("should allow a comparison (lexicographically) on strings", () => {
-      expect(() => validate("'A' <= 'B'")).not.toThrow();
+      expect(() => validate("A <= 'B'")).not.toThrow();
     });
 
     it("should allow a comparison (lexicographically) on functions returning string", () => {
       expect(() => validate("Lower([State]) > 'AB'")).not.toThrow();
     });
 
-    it("should allow a comparison on the result of COALESCE", () => {
-      expect(() => validate("Coalesce([X],[Y]) > 0")).not.toThrow();
-    });
-
     it("should reject a less/greater comparison on functions returning boolean", () => {
       expect(() => validate("IsEmpty([Tax]) < 5")).toThrow();
       expect(() => validate("IsEmpty([Tax]) >= 0")).toThrow();
     });
   });
-
-  describe("for an aggregation", () => {
-    function aggregation(source) {
-      return collect(source, "aggregation");
-    }
-    it("should resolve dimensions correctly", () => {
-      expect(aggregation("Sum([Discount])").dimensions).toEqual(["Discount"]);
-      expect(aggregation("5-Average([Rating])").dimensions).toEqual(["Rating"]);
-      expect(aggregation("Share(contains([P],'Q'))").dimensions).toEqual(["P"]);
-      expect(aggregation("CountIf([Tax]>13)").dimensions).toEqual(["Tax"]);
-      expect(aggregation("Sum([Total]*2)").dimensions).toEqual(["Total"]);
-      expect(aggregation("[Total]").dimensions).toEqual([]);
-      expect(aggregation("CountIf(4>[A]+[B])").dimensions).toEqual(["A", "B"]);
-    });
-
-    it("should resolve metrics correctly", () => {
-      expect(aggregation("Sum([Discount])").metrics).toEqual([]);
-      expect(aggregation("5-Average([Rating])").metrics).toEqual([]);
-      expect(aggregation("Share(contains([P],'Q'))").metrics).toEqual([]);
-      expect(aggregation("CountIf([Tax]>13)").metrics).toEqual([]);
-      expect(aggregation("Sum([Total]*2)").metrics).toEqual([]);
-      expect(aggregation("[Total]").metrics).toEqual(["Total"]);
-      expect(aggregation("CountIf(4>[A]+[B])").metrics).toEqual([]);
-    });
-
-    it("should resolve dimensions and metrics correctly", () => {
-      expect(aggregation("[X]+Sum([Y])").dimensions).toEqual(["Y"]);
-      expect(aggregation("[X]+Sum([Y])").metrics).toEqual(["X"]);
-    });
-  });
-
-  describe("compactSyntaxTree", () => {
-    function exprRoot(source) {
-      const tokenVector = null;
-      const startRule = "expression";
-      const { cst } = parse({ source, tokenVector, startRule });
-      const compactCst = compactSyntaxTree(cst);
-      const { name } = compactCst;
-      return name;
-    }
-    function filterRoot(source) {
-      const tokenVector = null;
-      const startRule = "boolean";
-      const { cst } = parse({ source, tokenVector, startRule });
-      const compactCst = compactSyntaxTree(cst);
-      const { name } = compactCst;
-      return name;
-    }
-
-    it("should handle literals", () => {
-      expect(exprRoot("42")).toEqual("numberLiteral");
-      expect(exprRoot("(43)")).toEqual("numberLiteral");
-      expect(exprRoot("'Answer'")).toEqual("stringLiteral");
-      expect(exprRoot('"Answer"')).toEqual("stringLiteral");
-      expect(exprRoot('("The Answer")')).toEqual("stringLiteral");
-    });
-    it("should handle binary expressions", () => {
-      expect(exprRoot("1+2")).toEqual("additionExpression");
-      expect(exprRoot("3-4")).toEqual("additionExpression");
-      expect(exprRoot("1+2-3")).toEqual("additionExpression");
-      expect(exprRoot("(1+2-3)")).toEqual("additionExpression");
-      expect(exprRoot("(1+2)-3")).toEqual("additionExpression");
-      expect(exprRoot("1+(2-3)")).toEqual("additionExpression");
-      expect(exprRoot("5*6")).toEqual("multiplicationExpression");
-      expect(exprRoot("7/8")).toEqual("multiplicationExpression");
-      expect(exprRoot("5*6/7")).toEqual("multiplicationExpression");
-      expect(exprRoot("(5*6/7)")).toEqual("multiplicationExpression");
-      expect(exprRoot("5*(6/7)")).toEqual("multiplicationExpression");
-      expect(exprRoot("(5*6)/7")).toEqual("multiplicationExpression");
-    });
-    it("should handle function expressions", () => {
-      expect(exprRoot("LOWER(A)")).toEqual("functionExpression");
-      expect(exprRoot("UPPER(B)")).toEqual("functionExpression");
-      expect(filterRoot("BETWEEN(C,0,9)")).toEqual("functionExpression");
-    });
-    it("should handle case expressions", () => {
-      expect(exprRoot("CASE(X,1)")).toEqual("caseExpression");
-      expect(exprRoot("CASE(Y,2,3)")).toEqual("caseExpression");
-    });
-    it("should handle relational expressions", () => {
-      expect(filterRoot("1<2")).toEqual("relationalExpression");
-      expect(filterRoot("3>4")).toEqual("relationalExpression");
-      expect(filterRoot("5=6")).toEqual("relationalExpression");
-      expect(filterRoot("7!=8")).toEqual("relationalExpression");
-    });
-    it("should handle logical expressions", () => {
-      expect(filterRoot("A AND B")).toEqual("logicalAndExpression");
-      expect(filterRoot("C OR D")).toEqual("logicalOrExpression");
-      expect(filterRoot("A AND B OR C")).toEqual("logicalOrExpression");
-      expect(filterRoot("NOT E")).toEqual("logicalNotExpression");
-      expect(filterRoot("NOT NOT F")).toEqual("logicalNotExpression");
-    });
-  });
 });
diff --git a/frontend/test/metabase/lib/expressions/typeinferencer.unit.spec.js b/frontend/test/metabase/lib/expressions/typeinferencer.unit.spec.js
index 95e160557f4..fca678d1dd0 100644
--- a/frontend/test/metabase/lib/expressions/typeinferencer.unit.spec.js
+++ b/frontend/test/metabase/lib/expressions/typeinferencer.unit.spec.js
@@ -16,7 +16,7 @@ describe("metabase/lib/expressions/typeinferencer", () => {
   // workaround the limitation of the parsing expecting a strict top-level grammar rule
   function tryCompile(source) {
     let mbql = compileAs(source, "expression");
-    if (!mbql) {
+    if (mbql === null) {
       mbql = compileAs(source, "boolean");
     }
     return mbql;
@@ -58,7 +58,7 @@ describe("metabase/lib/expressions/typeinferencer", () => {
   it("should infer the result of comparisons", () => {
     expect(type("[Discount] > 0")).toEqual("boolean");
     expect(type("[Revenue] <= [Limit] * 2")).toEqual("boolean");
-    expect(type("1 != 2")).toEqual("boolean");
+    expect(type("[Price] != 2")).toEqual("boolean");
   });
 
   it("should infer the result of logical operations", () => {
diff --git a/frontend/test/metabase/scenarios/custom-column/reproductions/15714-cc-postgres-percentile-accepts-two-params.cy.spec.js b/frontend/test/metabase/scenarios/question/reproductions/15714-cc-postgres-percentile-accepts-two-params.cy.spec.js
similarity index 84%
rename from frontend/test/metabase/scenarios/custom-column/reproductions/15714-cc-postgres-percentile-accepts-two-params.cy.spec.js
rename to frontend/test/metabase/scenarios/question/reproductions/15714-cc-postgres-percentile-accepts-two-params.cy.spec.js
index c33dbb81fbd..4456136e015 100644
--- a/frontend/test/metabase/scenarios/custom-column/reproductions/15714-cc-postgres-percentile-accepts-two-params.cy.spec.js
+++ b/frontend/test/metabase/scenarios/question/reproductions/15714-cc-postgres-percentile-accepts-two-params.cy.spec.js
@@ -14,11 +14,12 @@ describe("postgres > question > custom columns", () => {
   });
 
   it("`Percentile` custom expression function should accept two parameters (metabase#15714)", () => {
-    cy.icon("add_data").click();
+    cy.findByText("Pick the metric you want to see").click();
+    cy.findByText("Custom Expression").click();
     cy.get("[contenteditable='true']")
       .click()
       .type("Percentile([Subtotal], 0.1)");
-    cy.findByPlaceholderText("Something nice and descriptive")
+    cy.findByPlaceholderText("Name (required)")
       .as("description")
       .click();
 
-- 
GitLab