From 1290ecf78dd1eb74b5b871ce7c155c92fb8053ee Mon Sep 17 00:00:00 2001
From: dpsutton <dan@dpsutton.com>
Date: Mon, 20 Nov 2023 14:46:06 -0600
Subject: [PATCH] Alternative fix mu defn (#35939)

* Revert "Clarify stacktraces from malli validation errors (#34712)"

This reverts commit 1a0f2d6587bc8ecad652c5722a56a615ec8b10c7.

* Alternative fix for mu/defn correctness issue

https://github.com/metabase/metabase/pull/34712 introduced a correctness
bug:

```clojure
(macroexpand '(metabase.util.malli/defn foo [x :- string?] :foo))

(def
 foo
 "Inputs: [x :- string?]\n  Return: :any"
 (clojure.core/fn
  ([a] (metabase.util.malli.fn/validate-input {:fn-name 'user/foo} string? a) ((clojure.core/fn [x] :foo) a))))
```

The simplest demo of the bug is

```clojure
user=> (metabase.util.malli/defn foo [x] a)
 #'user/foo
```

The macro used simple and predictable args for it's emitted
functions. This was not a problem with the original shape:

(different macroexpansions of `(mu/defn foo [x :- int?] (inc x))`):

prior to 34712:

```clojure
(def
 foo
 "Inputs: [x :- int?]\n  Return: :any"
 (clojure.core/let
  [&f (clojure.core/fn [x] (inc x))]
  (clojure.core/fn
   ([a] ;; simple argument name not a problem
    (metabase.util.malli.fn/validate-input
      {:fn-name (quote metabase.util.malli.fn-test/foo)}
      int?
      a)
    (&f a)))))
```

after 34712:

```clojure
(def
 foo
 "Inputs: [x :- int?]\n  Return: :any"
 (clojure.core/fn
  ([a] ;; simple argument name is now a problem
   (metabase.util.malli.fn/validate-input {:fn-name 'user/foo} int? a)
   ((clojure.core/fn [x] (inc x))
    a))))
```

after this PR:
```clojure
(def
 foo
 "Inputs: [x :- int?]\n  Return: :any"
 (clojure.core/let
  [&f (clojure.core/fn [x] (inc x))]
  (clojure.core/fn
   ([a] ;; again simple argument name is not a problem
    (try
     (metabase.util.malli.fn/validate-input
      {:fn-name (quote metabase.util.malli.fn-test/foo)}
      int?
      a)
     (&f a)
     (catch
      java.lang.Exception
      error
      (throw (metabase.util.malli.fn/fixup-stacktrace error))))))))
```

And the fix here is just to remove the StackTraceElements which
reference `validate`, `validate-input`, or `validate-output`.

* fix tests

tests are run in the user namespace similar to

```clojure
user=> (clojure.test/run-tests 'metabase.util.malli.fn-test)

Testing metabase.util.malli.fn-test

Ran 8 tests containing 43 assertions.
0 failures, 0 errors.
{:test 8, :pass 43, :fail 0, :error 0, :type :summary}
```

But we still want the namespace of _this_ file m.u.m.fn-test
---
 src/metabase/util/malli/fn.clj       |  86 ++++++++---------
 test/metabase/util/malli/fn_test.clj | 134 +++++++++++++++++++--------
 2 files changed, 137 insertions(+), 83 deletions(-)

diff --git a/src/metabase/util/malli/fn.clj b/src/metabase/util/malli/fn.clj
index e2419c1218c..3ba5f3d4787 100644
--- a/src/metabase/util/malli/fn.clj
+++ b/src/metabase/util/malli/fn.clj
@@ -2,7 +2,6 @@
   (:refer-clojure :exclude [fn])
   (:require
    [clojure.core :as core]
-   [clojure.string :as str]
    [malli.core :as mc]
    [malli.destructure :as md]
    [malli.error :as me]
@@ -146,36 +145,20 @@
   use [[metabase.util.malli/disable-enforcement]] to bind this only in Clojure code."
   true)
 
-(defn- fixup-stacktrace
-  "This function removes stack trace elements that came from this namespace. When we throw validation errors, they
-  shouldn't originate from *this* namespace, they should appear to be thrown from the instrumented function itself."
-  [^Exception e]
-  (let [trace (.getStackTrace e)
-        fixed-trace (into-array StackTraceElement
-                                (drop-while
-                                 #(str/starts-with? (.getClassName ^StackTraceElement %)
-                                                    ;; this is... hacky, but it works.
-                                                    (namespace ::x))
-                                            trace))]
-    (.setStackTrace e fixed-trace)))
-
 (defn- validate [error-context schema value error-type]
   (when *enforce*
     (when-let [error (mr/explain schema value)]
       (let [humanized (me/humanize error)]
-        (throw
-         (doto (ex-info
-                (case error-type
-                  ::invalid-input  (i18n/tru "Invalid input: {0}" (pr-str humanized))
-                  ::invalid-output (i18n/tru "Invalid output: {0}" (pr-str humanized)))
-                (merge
-                 {:type      error-type
-                  :error     error
-                  :humanized humanized
-                  :schema    schema
-                  :value     value}
-                 error-context))
-           fixup-stacktrace))))))
+        (throw (ex-info (case error-type
+                          ::invalid-input  (i18n/tru "Invalid input: {0}" (pr-str humanized))
+                          ::invalid-output (i18n/tru "Invalid output: {0}" (pr-str humanized)))
+                        (merge
+                         {:type      error-type
+                          :error     error
+                          :humanized humanized
+                          :schema    schema
+                          :value     value}
+                         error-context)))))))
 
 (defn validate-input
   "Impl for [[metabase.util.malli.fn/fn]]; validates an input argument with `value` against `schema` using a cached
@@ -231,37 +214,57 @@
               schemas)
          (filter some?))))
 
-(defn- input-schema->application-form [input-schema deparameterized-fn]
+(defn- input-schema->application-form [input-schema]
   (let [arg-names (input-schema-arg-names input-schema)]
     (if (varargs-schema? input-schema)
-      (list* `apply deparameterized-fn arg-names)
-      (list* deparameterized-fn arg-names))))
+      (list* `apply '&f arg-names)
+      (list* '&f arg-names))))
 
-(defn- instrumented-arity [error-context [_=> input-schema output-schema] deparameterized-fn]
+(defn fixup-stacktrace
+  "If exception is thrown from the [[validate]] machinery, remove those stack trace elements so the top of the stack is
+  the calling function."
+  [^Exception e]
+  (if (#{::invalid-input ::invalid-output} (-> e ex-data :type))
+    (let [trace (.getStackTrace e)
+          cleaned (when trace
+                    (into-array StackTraceElement
+                                (drop-while (comp #{(.getName (class validate))
+                                                    (.getName (class validate-input))
+                                                    (.getName (class validate-output))}
+                                                  #(.getClassName ^StackTraceElement %))
+                                            trace)))]
+      (doto e
+        (.setStackTrace cleaned)))
+    e))
+
+(defn- instrumented-arity [error-context [_=> input-schema output-schema]]
   (let [input-schema           (if (= input-schema :cat)
                                  [:cat]
                                  input-schema)
         arglist                (input-schema->arglist input-schema)
         input-validation-forms (input-schema->validation-forms error-context input-schema)
-        result-form            (input-schema->application-form input-schema deparameterized-fn)
+        result-form            (input-schema->application-form input-schema)
         result-form            (if (and output-schema
                                         (not= output-schema :any))
                                  `(->> ~result-form
                                        (validate-output ~error-context ~output-schema))
                                  result-form)]
-    `(~arglist ~@input-validation-forms ~result-form)))
-
-(defn- instrumented-fn-tail [error-context
-                             [schema-type :as schema]
-                             deparameterized-fn]
+    `(~arglist
+      (try
+        ~@input-validation-forms
+        ~result-form
+        (catch Exception ~'error
+          (throw (fixup-stacktrace ~'error)))))))
+
+(defn- instrumented-fn-tail [error-context [schema-type :as schema]]
   (case schema-type
     :=>
-    [(instrumented-arity error-context schema deparameterized-fn)]
+    [(instrumented-arity error-context schema)]
 
     :function
     (let [[_function & schemas] schema]
       (for [schema schemas]
-        (instrumented-arity error-context schema deparameterized-fn)))))
+        (instrumented-arity error-context schema)))))
 
 (defn instrumented-fn-form
   "Given a `fn-tail` like
@@ -275,9 +278,8 @@
     (mc/-instrument {:schema [:=> [:cat :int :any] :any]}
                     (fn [x y] (+ 1 2)))"
   [error-context parsed]
-  `(core/fn ~@(instrumented-fn-tail error-context
-                                    (fn-schema parsed)
-                                    (deparameterized-fn-form parsed))))
+  `(let [~'&f ~(deparameterized-fn-form parsed)]
+     (core/fn ~@(instrumented-fn-tail error-context (fn-schema parsed)))))
 
 (defmacro fn
   "Malli version of [[schema.core/fn]]. A form like
diff --git a/test/metabase/util/malli/fn_test.clj b/test/metabase/util/malli/fn_test.clj
index 723709c58d7..595717ebf9e 100644
--- a/test/metabase/util/malli/fn_test.clj
+++ b/test/metabase/util/malli/fn_test.clj
@@ -1,7 +1,7 @@
 (ns ^:mb/once metabase.util.malli.fn-test
   (:require
-   [clojure.string :as str]
    [clojure.test :refer :all]
+   [clojure.tools.macro :as tools.macro]
    [clojure.walk :as walk]
    [metabase.util.malli :as mu]
    [metabase.util.malli.fn :as mu.fn]
@@ -71,44 +71,61 @@
   (are [form expected] (= expected
                           (walk/macroexpand-all (mu.fn/instrumented-fn-form {} (mu.fn/parse-fn-tail form))))
     '([x :- :int y])
-    '(fn* ([a b]
-           (metabase.util.malli.fn/validate-input {} :int a)
-           ((fn* ([x y])) a b)))
+    '(let* [&f (fn* ([x y]))]
+       (fn* ([a b]
+             (try
+               (metabase.util.malli.fn/validate-input {} :int a)
+               (&f a b)
+               (catch java.lang.Exception error
+                 (throw (metabase.util.malli.fn/fixup-stacktrace error)))))))
 
     '(:- :int [x :- :int y])
-    '(fn* ([a b]
-           (metabase.util.malli.fn/validate-input {} :int a)
-           (metabase.util.malli.fn/validate-output {} :int ((fn* ([x y])) a b))))
+    '(let* [&f (fn* ([x y]))]
+       (fn* ([a b]
+             (try
+               (metabase.util.malli.fn/validate-input {} :int a)
+               (metabase.util.malli.fn/validate-output {} :int (&f a b))
+               (catch java.lang.Exception error
+                 (throw (metabase.util.malli.fn/fixup-stacktrace error)))))))
 
     '(:- :int [x :- :int y] (+ x y))
-    '(fn* ([a b]
-           (metabase.util.malli.fn/validate-input {} :int a)
-           (metabase.util.malli.fn/validate-output {} :int ((fn* ([x y] (+ x y))) a b))))
+    '(let* [&f (fn* ([x y] (+ x y)))]
+       (fn* ([a b]
+             (try
+               (metabase.util.malli.fn/validate-input {} :int a)
+               (metabase.util.malli.fn/validate-output {} :int (&f a b))
+               (catch java.lang.Exception error
+                 (throw (metabase.util.malli.fn/fixup-stacktrace error)))))))
 
     '([x :- :int y] {:pre [(int? x)]})
-    '(fn* ([a b]
-           (metabase.util.malli.fn/validate-input {} :int a)
-           ((fn* ([x y]
-                  {:pre [(int? x)]}))
-            a b)))
+    '(let* [&f (fn* ([x y]
+                     {:pre [(int? x)]}))]
+       (fn* ([a b]
+             (try
+               (metabase.util.malli.fn/validate-input {} :int a)
+               (&f a b)
+               (catch java.lang.Exception error
+                 (throw (metabase.util.malli.fn/fixup-stacktrace error)))))))
 
     '(:- :int
          ([x] (inc x))
          ([x :- :int y] (+ x y)))
-    '(fn*
-      ([a]
-       (metabase.util.malli.fn/validate-output {} :int ((fn* ([x]
-                                                              (inc x))
-                                                             ([x y]
-                                                              (+ x y)))
-                                                        a)))
-      ([a b]
-       (metabase.util.malli.fn/validate-input {} :int a)
-       (metabase.util.malli.fn/validate-output {} :int ((fn* ([x]
-                                                              (inc x))
-                                                             ([x y]
-                                                              (+ x y)))
-                                                        a b))))))
+    '(let* [&f (fn* ([x]
+                     (inc x))
+                    ([x y]
+                     (+ x y)))]
+       (fn*
+        ([a]
+         (try
+           (metabase.util.malli.fn/validate-output {} :int (&f a))
+           (catch java.lang.Exception error
+             (throw (metabase.util.malli.fn/fixup-stacktrace error)))))
+        ([a b]
+         (try
+           (metabase.util.malli.fn/validate-input {} :int a)
+           (metabase.util.malli.fn/validate-output {} :int (&f a b))
+           (catch java.lang.Exception error
+             (throw (metabase.util.malli.fn/fixup-stacktrace error)))))))))
 
 (deftest ^:parallel fn-test
   (let [f (mu.fn/fn :- :int [y] y)]
@@ -119,11 +136,7 @@
     (is (thrown-with-msg?
          clojure.lang.ExceptionInfo
          #"Invalid output:.*should be an integer"
-         (f nil)))
-    (testing "the stacktrace does not begin in the validation function"
-      (let [e ^Exception (is (thrown? clojure.lang.ExceptionInfo (f nil)))]
-        (is (not (str/starts-with? (.getClassName (first (.getStackTrace e)))
-                                   "metabase.util.malli.fn")))))))
+         (f nil)))))
 
 (deftest ^:parallel registry-test
   (mr/def ::number :int)
@@ -149,13 +162,16 @@
                  & {:keys [token-check?]
                     :or   {token-check? true}}]
                 (merge {:path path, :token-check? token-check?} opts))]
-    (is (= '(fn*
-              ([a b & more]
-               (metabase.util.malli.fn/validate-input {:fn-name 'my-fn} :map b)
-               (clojure.core/apply (clojure.core/fn
-                                     [path opts & {:keys [token-check?], :or {token-check? true}}]
-                                     (merge {:path path, :token-check? token-check?} opts))
-                                   a b more)))
+    (is (= '(let* [&f (clojure.core/fn
+                        [path opts & {:keys [token-check?], :or {token-check? true}}]
+                        (merge {:path path, :token-check? token-check?} opts))]
+              (clojure.core/fn
+                ([a b & more]
+                 (try
+                   (metabase.util.malli.fn/validate-input {:fn-name 'my-fn} :map b)
+                   (clojure.core/apply &f a b more)
+                   (catch java.lang.Exception error
+                     (throw (metabase.util.malli.fn/fixup-stacktrace error)))))))
            (macroexpand form)))
     (is (= [:=>
             [:cat :any :map [:* :any]]
@@ -181,3 +197,39 @@
              :args
              meta
              :tag))))
+
+(mu/defn ^:private foo :- keyword? [_x :- string?] "bad output")
+(mu/defn ^:private bar :- keyword?
+  ([_x :- string? _y] "bad output")
+  ([_x :- string? _y & _xs] "bad output"))
+
+(mu/defn ^:private works? :- keyword? [_x :- string?] :yes)
+
+(defn from-here? [^Exception e]
+  (let [top-trace (-> e (.getStackTrace) first)
+        cn        (when top-trace
+                    (.getClassName ^StackTraceElement top-trace))]
+    (when cn
+      (is (re-find (re-pattern (munge (namespace `foo))) cn))
+      (is (not (re-find #"metabase.util.malli.fn\$validate" cn))))))
+
+(deftest ^:parallel error-location-tests
+  (tools.macro/macrolet [(check-error-location [expr]
+                           `(try ~expr
+                                 (is false "Did not throw")
+                                 (catch Exception e# (from-here? e#))))]
+    (testing "Top stack trace is this namespace, not in validate"
+      (testing "single arity input"
+        (check-error-location (foo 1)))
+      (testing "single arity output"
+        (check-error-location (foo "good input")))
+      (testing "multi arity input"
+        (check-error-location (bar 1 2)))
+      (testing "multi arity output"
+        (check-error-location (bar "good input" 2)))
+      (testing "var args input"
+        (check-error-location (bar 1 2 3 4 5)))
+      (testing "var args output"
+        (check-error-location (bar "good input" 2 3 4 5))))
+    (testing "sanity check-error-location that it works"
+      (is (= :yes (works? "valid input"))))))
-- 
GitLab