From 001055dfb508a3e6844becc2f897f5e0ab0b285c Mon Sep 17 00:00:00 2001
From: metamben <103100869+metamben@users.noreply.github.com>
Date: Mon, 20 Jun 2022 22:00:20 +0300
Subject: [PATCH] Fix compilation of temporal arithmetic in between filters
 (#23292)

Fix compilation of temporal arithmetic for BigQuery and Mongo 5+

* Mongo 4 doesn't support $dateAdd so the generated filters result in an exception.
* Support adding field to interval too (time intervals were not allowed in the first place of an addition)
* Support temporal arithmetic with more than two operands for Mongo
---
 .../bigquery_cloud_sdk/query_processor.clj    | 60 ++++++-------
 .../metabase/driver/mongo/query_processor.clj | 65 +++++++++++++-
 .../driver/mongo/query_processor_test.clj     | 38 ++++++++
 shared/src/metabase/mbql/schema.cljc          |  2 +-
 src/metabase/driver/sql/query_processor.clj   |  8 +-
 src/metabase/util.clj                         | 11 +++
 .../query_processor_test/filter_test.clj      | 87 ++++++++++++++++++-
 .../query_processor_test/timezones_test.clj   |  2 +-
 test/metabase/util_test.clj                   | 13 +++
 9 files changed, 250 insertions(+), 36 deletions(-)

diff --git a/modules/drivers/bigquery-cloud-sdk/src/metabase/driver/bigquery_cloud_sdk/query_processor.clj b/modules/drivers/bigquery-cloud-sdk/src/metabase/driver/bigquery_cloud_sdk/query_processor.clj
index 6cb57c210a7..972fd4afd3b 100644
--- a/modules/drivers/bigquery-cloud-sdk/src/metabase/driver/bigquery_cloud_sdk/query_processor.clj
+++ b/modules/drivers/bigquery-cloud-sdk/src/metabase/driver/bigquery_cloud_sdk/query_processor.clj
@@ -604,9 +604,10 @@
 (doseq [filter-type [:between := :!= :> :>= :< :<=]]
   (defmethod sql.qp/->honeysql [:bigquery-cloud-sdk filter-type]
     [driver clause]
-    ((get-method sql.qp/->honeysql [:sql filter-type])
-     driver
-     (reconcile-temporal-types clause))))
+    (reconcile-temporal-types
+     ((get-method sql.qp/->honeysql [:sql filter-type])
+      driver
+      (reconcile-temporal-types clause)))))
 
 
 ;;; +----------------------------------------------------------------------------------------------------------------+
@@ -616,48 +617,49 @@
 (defn- interval [amount unit]
   (hsql/raw (format "INTERVAL %d %s" (int amount) (name unit))))
 
-(defn- assert-addable-unit [t-type unit]
-  (when-not (contains? (temporal-type->supported-units t-type) unit)
-    ;; trying to add an `hour` to a `date` or a `year` to a `time` is something we shouldn't be allowing in the UI in
-    ;; the first place
-    (throw (ex-info (tru "Invalid query: you cannot add a {0} to a {1} column."
-                         (name unit) (name t-type))
-             {:type qp.error-type/invalid-query}))))
-
 ;; We can coerce the HoneySQL form this wraps to whatever we want and generate the appropriate SQL.
 ;; Thus for something like filtering against a relative datetime
 ;;
 ;; [:time-interval <datetime field> -1 :day]
 ;;
 ;;
+(def ^:private temporal-type->arithmetic-function
+  {:timestamp :timestamp_add
+   :datetime  :datetime_add
+   :date      :date_add
+   :time      :time_add})
+
 (defrecord AddIntervalForm [hsql-form amount unit]
   hformat/ToSql
   (to-sql [_]
-    (loop [hsql-form hsql-form]
-      (let [t      (temporal-type hsql-form)
-            add-fn (case t
-                     :timestamp :timestamp_add
-                     :datetime  :datetime_add
-                     :date      :date_add
-                     :time      :time_add
-                     nil)]
-        (if-not add-fn
-          (recur (->temporal-type :datetime hsql-form))
-          (do
-            (assert-addable-unit t unit)
-            (hformat/to-sql (hsql/call add-fn hsql-form (interval amount unit)))))))))
+    (let [t      (temporal-type hsql-form)
+          add-fn (temporal-type->arithmetic-function t)]
+      (hformat/to-sql (hsql/call add-fn hsql-form (interval amount unit))))))
+
+(defn- add-interval-form [hsql-form amount unit]
+  (let [t         (temporal-type hsql-form)
+        add-fn    (temporal-type->arithmetic-function t)
+        hsql-form (if (or (not add-fn)
+                          (and (not (contains? (temporal-type->supported-units t) unit))
+                               (contains? (temporal-type->supported-units :datetime) unit)))
+                    (->temporal-type :datetime hsql-form)
+                    hsql-form)]
+    (AddIntervalForm. hsql-form amount unit)))
 
 (defmethod temporal-type AddIntervalForm
   [add-interval]
   (temporal-type (:hsql-form add-interval)))
 
 (defmethod ->temporal-type [:temporal-type AddIntervalForm]
-  [target-type add-interval-form]
-  (let [current-type (temporal-type (:hsql-form add-interval-form))]
+  [target-type form]
+  (let [current-type (temporal-type (:hsql-form form))]
     (when (#{[:date :time] [:time :date]} [current-type target-type])
       (throw (ex-info (tru "It doesn''t make sense to convert between DATEs and TIMEs!")
-               {:type qp.error-type/invalid-query}))))
-  (map->AddIntervalForm (update add-interval-form :hsql-form (partial ->temporal-type target-type))))
+                      {:type qp.error-type/invalid-query}))))
+  (let [new-form (add-interval-form (->temporal-type target-type (:hsql-form form)) (:amount form) (:unit form))]
+    (if (= (temporal-type new-form) target-type)
+      new-form
+      (hx/cast target-type form))))
 
 (defmethod sql.qp/add-interval-honeysql-form :bigquery-cloud-sdk
   [_ hsql-form amount unit]
@@ -667,7 +669,7 @@
                     (and (= (temporal-type hsql-form) :timestamp)
                          (not (contains? (temporal-type->supported-units :timestamp) unit)))
                     (hx/cast :datetime))]
-    (AddIntervalForm. hsql-form amount unit)))
+    (add-interval-form hsql-form amount unit)))
 
 (defmethod driver/mbql->native :bigquery-cloud-sdk
   [driver outer-query]
diff --git a/modules/drivers/mongo/src/metabase/driver/mongo/query_processor.clj b/modules/drivers/mongo/src/metabase/driver/mongo/query_processor.clj
index 5aa4b8f21ac..42a5d300a06 100644
--- a/modules/drivers/mongo/src/metabase/driver/mongo/query_processor.clj
+++ b/modules/drivers/mongo/src/metabase/driver/mongo/query_processor.clj
@@ -6,6 +6,7 @@
             [clojure.walk :as walk]
             [flatland.ordered.map :as ordered-map]
             [java-time :as t]
+            [metabase.driver :as driver]
             [metabase.driver.common :as driver.common]
             [metabase.mbql.schema :as mbql.s]
             [metabase.mbql.util :as mbql.u]
@@ -363,8 +364,68 @@
 (defmethod ->rvalue :concat    [[_ & args]] {"$concat" (mapv ->rvalue args)})
 (defmethod ->rvalue :substring [[_ & args]] {"$substrCP" (mapv ->rvalue args)})
 
-(defmethod ->rvalue :+ [[_ & args]] {"$add" (mapv ->rvalue args)})
-(defmethod ->rvalue :- [[_ & args]] {"$subtract" (mapv ->rvalue args)})
+;;; Intervals are not first class Mongo citizens, so they cannot be translated on their own.
+;;; The only thing we can do with them is adding to or subtracting from a date valued expression.
+;;; Also, date arithmetic with intervals was first implemented in version 5. (Before that only
+;;; ordinary addition could be used: one of the operands of the addition could be a date, their
+;;; rest of the operands had to be integers and would be treated as milliseconds.)
+;;; Because of this, whenever we translate date arithmetic with intervals, we check the major
+;;; version of the database and throw a nice exception if it's less than 5.
+
+(defn- get-mongo-version []
+  (:version (driver/describe-database :mongo (qp.store/database))))
+
+(defn- get-major-version [version]
+  (some-> version (str/split #"\.") first parse-long))
+
+(defn- check-date-operations-supported []
+  (let [mongo-version (get-mongo-version)
+        major-version (get-major-version mongo-version)]
+    (when (and major-version (< major-version 5))
+      (throw (ex-info "Date arithmetic not supported in versions before 5"
+                      {:database-version mongo-version})))))
+
+(defn- interval? [expr]
+  (and (vector? expr) (= (first expr) :interval)))
+
+(defn- summarize-interval [op date-expr [_ amount unit]]
+  {op {:startDate date-expr
+       :unit unit
+       :amount amount}})
+
+(defn- summarize-num-or-interval [number-op date-op mongo-expr mbql-expr]
+  (cond
+    (interval? mbql-expr) (summarize-interval date-op mongo-expr mbql-expr)
+    (contains? mongo-expr number-op) (update mongo-expr number-op conj (->rvalue mbql-expr))
+    :else {number-op [mongo-expr (->rvalue mbql-expr)]}))
+
+(def ^:private num-or-interval-reducer
+  {:+ (partial summarize-num-or-interval "$add" "$dateAdd")
+   :- (partial summarize-num-or-interval "$subtract" "$dateSubtract")})
+
+(defmethod ->rvalue :+ [[_ & args]]
+  ;; Addition is commutative and any but not all elements of `args` can be intervals.
+  ;; We pick the first arg that is not an interval and add the rest of args to it.
+  ;; (It's the callers responsibility to make sure that the first non-interval argument
+  ;; represents a date and not an offset like an integer would.)
+  ;; If none of the args is an interval, we shortcut with a simple addition.
+  (if (some interval? args)
+    (if-let [[arg others] (u/pick-first (complement interval?) args)]
+      (do
+        (check-date-operations-supported)
+        (reduce (num-or-interval-reducer :+) (->rvalue arg) others))
+      (throw (ex-info "Summing intervals is not supported" {:args args})))
+    {"$add" (mapv ->rvalue args)}))
+
+(defmethod ->rvalue :- [[_ & [arg & others :as args]]]
+  ;; Subtraction is not commutative so `arg` cannot be an interval.
+  ;; If none of the args is an interval, we shortcut with a simple subtraction.
+  (if (some interval? others)
+    (do
+      (check-date-operations-supported)
+      (reduce (num-or-interval-reducer :-) (->rvalue arg) others))
+    {"$subtract" (mapv ->rvalue args)}))
+
 (defmethod ->rvalue :* [[_ & args]] {"$multiply" (mapv ->rvalue args)})
 (defmethod ->rvalue :/ [[_ & args]] {"$divide" (mapv ->rvalue args)})
 
diff --git a/modules/drivers/mongo/test/metabase/driver/mongo/query_processor_test.clj b/modules/drivers/mongo/test/metabase/driver/mongo/query_processor_test.clj
index ea3f3c2cd41..73e00722849 100644
--- a/modules/drivers/mongo/test/metabase/driver/mongo/query_processor_test.clj
+++ b/modules/drivers/mongo/test/metabase/driver/mongo/query_processor_test.clj
@@ -1,6 +1,7 @@
 (ns metabase.driver.mongo.query-processor-test
   (:require [clojure.set :as set]
             [clojure.test :refer :all]
+            [java-time :as t]
             [metabase.driver.mongo.query-processor :as mongo.qp]
             [metabase.models :refer [Field Table]]
             [metabase.query-processor :as qp]
@@ -272,3 +273,40 @@
                    (mt/mbql-query checkins
                      {:filter   [:time-interval $date -4 :month]
                       :breakout [[:datetime-field $date :day]]}))))))))))
+
+(deftest temporal-arithmetic-test
+  (testing "Mixed integer and date arithmetic works with Mongo 5+"
+    (with-redefs [mongo.qp/get-mongo-version (constantly "5.2.13")]
+      (mt/with-clock #t "2022-06-21T15:36:00+02:00[Europe/Berlin]"
+        (is (= {:$expr
+                {"$lt"
+                 [{"$dateAdd"
+                   {:startDate {"$add" [{"$dateAdd" {:startDate "$date-field"
+                                                     :unit :year
+                                                     :amount 1}}
+                                        3600000]}
+                    :unit :month
+                    :amount -1}}
+                  {"$subtract"
+                   [{"$dateSubtract" {:startDate {:$dateFromString {:dateString "2008-05-31"}}
+                                      :unit :week
+                                      :amount -1}}
+                    86400000]}]}}
+               (mongo.qp/compile-filter [:<
+                                         [:+
+                                          [:interval 1 :year]
+                                          [:field "date-field"]
+                                          3600000
+                                          [:interval -1 :month]]
+                                         [:-
+                                          [:absolute-datetime (t/local-date "2008-05-31")]
+                                          [:interval -1 :week]
+                                          86400000]]))))))
+  (testing "Date arithmetic fails with Mongo 4-"
+    (with-redefs [mongo.qp/get-mongo-version (constantly "4")]
+      (is (thrown-with-msg? clojure.lang.ExceptionInfo  #"Date arithmetic not supported in versions before 5"
+                            (mongo.qp/compile-filter [:<
+                                                      [:+
+                                                       [:interval 1 :year]
+                                                       [:field "date-field"]]
+                                                      [:absolute-datetime (t/local-date "2008-05-31")]]))))))
diff --git a/shared/src/metabase/mbql/schema.cljc b/shared/src/metabase/mbql/schema.cljc
index 22191be2382..5b8c00d5c58 100644
--- a/shared/src/metabase/mbql/schema.cljc
+++ b/shared/src/metabase/mbql/schema.cljc
@@ -548,7 +548,7 @@
   s StringExpressionArg, pattern s/Str)
 
 (defclause ^{:requires-features #{:expressions}} +
-  x NumericExpressionArg, y NumericExpressionArgOrInterval, more (rest NumericExpressionArgOrInterval))
+  x NumericExpressionArgOrInterval, y NumericExpressionArgOrInterval, more (rest NumericExpressionArgOrInterval))
 
 (defclause ^{:requires-features #{:expressions}} -
   x NumericExpressionArg, y NumericExpressionArgOrInterval, more (rest NumericExpressionArgOrInterval))
diff --git a/src/metabase/driver/sql/query_processor.clj b/src/metabase/driver/sql/query_processor.clj
index 59f49388a0c..65363499810 100644
--- a/src/metabase/driver/sql/query_processor.clj
+++ b/src/metabase/driver/sql/query_processor.clj
@@ -413,14 +413,18 @@
 (defmethod ->honeysql [:sql :power] [driver [_ field power]]
   (hsql/call :power (->honeysql driver field) (->honeysql driver power)))
 
+(defn- interval? [expr]
+  (and (vector? expr) (= (first expr) :interval)))
+
 (defmethod ->honeysql [:sql :+]
   [driver [_ & args]]
   (if (mbql.u/datetime-arithmetics? args)
-    (let [[field & intervals] args]
+    (if-let [[field intervals] (u/pick-first (complement interval?) args)]
       (reduce (fn [hsql-form [_ amount unit]]
                 (add-interval-honeysql-form driver hsql-form amount unit))
               (->honeysql driver field)
-              intervals))
+              intervals)
+      (throw (ex-info "Summing intervals is not supported" {:args args})))
     (apply hsql/call :+ (map (partial ->honeysql driver) args))))
 
 (defmethod ->honeysql [:sql :-] [driver [_ & args]] (apply hsql/call :- (map (partial ->honeysql driver) args)))
diff --git a/src/metabase/util.clj b/src/metabase/util.clj
index dcf10cdac9b..1e08b33e843 100644
--- a/src/metabase/util.clj
+++ b/src/metabase/util.clj
@@ -983,3 +983,14 @@
   "Generates a random NanoID string. Usually these are used for the entity_id field of various models."
   []
   (nano-id))
+
+(defn pick-first
+  "Returns a pair [match others] where match is the first element of `coll` for which `pred` returns
+  a truthy value and others is a sequence of the other elements of `coll` with the order preserved.
+  Returns nil if no element satisfies `pred`."
+  [pred coll]
+  (loop [xs (seq coll), prefix []]
+    (when-let [[x & xs] xs]
+      (if (pred x)
+        [x (concat prefix xs)]
+        (recur xs (conj prefix x))))))
diff --git a/test/metabase/query_processor_test/filter_test.clj b/test/metabase/query_processor_test/filter_test.clj
index 75dc3d087e9..d960189980c 100644
--- a/test/metabase/query_processor_test/filter_test.clj
+++ b/test/metabase/query_processor_test/filter_test.clj
@@ -1,9 +1,12 @@
 (ns metabase.query-processor-test.filter-test
   "Tests for the `:filter` clause."
-  (:require [clojure.test :refer :all]
+  (:require [clojure.set :as set]
+            [clojure.string :as str]
+            [clojure.test :refer :all]
             [metabase.driver :as driver]
             [metabase.query-processor :as qp]
             [metabase.query-processor-test :as qp.test]
+            [metabase.query-processor-test.timezones-test :as timezones-test]
             [metabase.test :as mt]))
 
 (deftest and-test
@@ -109,6 +112,88 @@
                    {:aggregation [[:count]]
                     :filter      [:between [:datetime-field $date :day] "2015-04-01" "2015-05-01"]}))))))))
 
+(defn- mongo-major-version [db]
+  (when (= driver/*driver* :mongo)
+    (-> (driver/describe-database :mongo db)
+        :version (str/split #"\.") first parse-long)))
+
+(defn- timezone-arithmetic-drivers []
+  (set/intersection
+   ;; we also want to test this against MongoDB but [[mt/normal-drivers-with-feature]] would normally not include that
+   ;; since MongoDB only supports expressions if version is 4.0 or above and [[mt/normal-drivers-with-feature]]
+   ;; currently uses [[driver/supports?]] rather than [[driver/database-supports?]] (TODO FIXME, see #23422)
+   (conj (mt/normal-drivers-with-feature :expressions) :mongo)
+   (timezones-test/timezone-aware-column-drivers)))
+
+(deftest temporal-arithmetic-test
+  (testing "Should be able to use temporal arithmetic expressions in filters (#22531)"
+    (mt/test-drivers (timezone-arithmetic-drivers)
+      (mt/dataset attempted-murders
+        (when-not (some-> (mongo-major-version (mt/db))
+                          (< 5))
+          (doseq [offset-unit [:year :day]
+                  interval-unit [:year :day]
+                  compare-op [:between := :< :<= :> :>=]
+                  compare-order (cond-> [:field-first]
+                                  (not= compare-op :between) (conj :value-first))]
+            (let [query (mt/mbql-query attempts
+                          {:aggregation [[:count]]
+                           :filter      (cond-> [compare-op]
+                                          (= compare-order :field-first)
+                                          (conj [:+ !default.datetime_tz [:interval 3 offset-unit]]
+                                                [:relative-datetime -7 interval-unit])
+                                          (= compare-order :value-first)
+                                          (conj [:relative-datetime -7 interval-unit]
+                                                [:+ !default.datetime_tz [:interval 3 offset-unit]])
+                                          (= compare-op :between)
+                                          (conj [:relative-datetime 0 interval-unit]))})]
+              ;; we are not interested in the exact result, just want to check
+              ;; that the query can be compiled and executed
+              (mt/with-native-query-testing-context query
+                (let [[[result]] (mt/formatted-rows [int]
+                                   (qp/process-query query))]
+                  (if (= driver/*driver* :mongo)
+                    (is (or (nil? result)
+                            (pos-int? result)))
+                    (is (nat-int? result))))))))))))
+
+(deftest nonstandard-temporal-arithmetic-test
+  (testing "Nonstandard temporal arithmetic should also be supported"
+    (mt/test-drivers (timezone-arithmetic-drivers)
+      (mt/dataset attempted-murders
+        (when-not (some-> (mongo-major-version (mt/db))
+                          (< 5))
+          (doseq [offset-unit [:year :day]
+                  interval-unit [:year :day]
+                  compare-op [:between := :< :<= :> :>=]
+                  add-op [:+ #_:-] ; TODO support subtraction like sql.qp/add-interval-honeysql-form (#23423)
+                  compare-order (cond-> [:field-first]
+                                  (not= compare-op :between) (conj :value-first))]
+            (let [add-fn (fn [field interval]
+                           (if (= add-op :-)
+                             [add-op field interval interval]
+                             [add-op interval field interval]))
+                  query (mt/mbql-query attempts
+                          {:aggregation [[:count]]
+                           :filter      (cond-> [compare-op]
+                                          (= compare-order :field-first)
+                                          (conj (add-fn !default.datetime_tz [:interval 3 offset-unit])
+                                                [:relative-datetime -7 interval-unit])
+                                          (= compare-order :value-first)
+                                          (conj [:relative-datetime -7 interval-unit]
+                                                (add-fn !default.datetime_tz [:interval 3 offset-unit]))
+                                          (= compare-op :between)
+                                          (conj [:relative-datetime 0 interval-unit]))})]
+              ;; we are not interested in the exact result, just want to check
+              ;; that the query can be compiled and executed
+              (mt/with-native-query-testing-context query
+                (let [[[result]] (mt/formatted-rows [int]
+                                   (qp/process-query query))]
+                  (if (= driver/*driver* :mongo)
+                    (is (or (nil? result)
+                            (pos-int? result)))
+                    (is (nat-int? result))))))))))))
+
 (deftest or-test
   (mt/test-drivers (mt/normal-drivers)
     (testing ":or, :<=, :="
diff --git a/test/metabase/query_processor_test/timezones_test.clj b/test/metabase/query_processor_test/timezones_test.clj
index f4423e1461f..50449e30fec 100644
--- a/test/metabase/query_processor_test/timezones_test.clj
+++ b/test/metabase/query_processor_test/timezones_test.clj
@@ -31,7 +31,7 @@
    (set (mt/normal-drivers-with-feature :set-timezone))
    broken-drivers))
 
-(defn- timezone-aware-column-drivers
+(defn timezone-aware-column-drivers
   "Drivers that support the equivalent of `TIMESTAMP WITH TIME ZONE` columns."
   []
   (conj (set-timezone-drivers) :h2 :bigquery-cloud-sdk :sqlserver :mongo))
diff --git a/test/metabase/util_test.clj b/test/metabase/util_test.clj
index 6566e795818..f20799f6059 100644
--- a/test/metabase/util_test.clj
+++ b/test/metabase/util_test.clj
@@ -368,3 +368,16 @@
        1.3     2 1.251
        12300.0 3 12345.67
        0.00321 3 0.003209817))
+
+(defspec pick-first-test 100
+  (prop/for-all [coll (gen/list gen/int)]
+    (let [result (u/pick-first pos? coll)]
+      (or (and (nil? result)
+               (every? (complement pos?) coll))
+          (let [[x ys] result
+                [non-pos [m & rest]] (split-with (complement pos?) coll)]
+            (and (vector? result)
+                 (= (count result) 2)
+                 (pos? x)
+                 (= x m)
+                 (= ys (concat non-pos rest))))))))
-- 
GitLab