From 1600809b78dcb284a271eb9a34541829fb49736d Mon Sep 17 00:00:00 2001
From: metamben <103100869+metamben@users.noreply.github.com>
Date: Tue, 7 Feb 2023 00:44:33 +0300
Subject: [PATCH] Add aggregation expression support for mongo (#28061)

---
 .../mongo/src/metabase/driver/mongo.clj       |   1 +
 .../metabase/driver/mongo/query_processor.clj | 152 +++++++++++-------
 .../driver/mongo/query_processor_test.clj     |  15 +-
 .../expression_aggregations_test.clj          |  17 ++
 4 files changed, 124 insertions(+), 61 deletions(-)

diff --git a/modules/drivers/mongo/src/metabase/driver/mongo.clj b/modules/drivers/mongo/src/metabase/driver/mongo.clj
index e01c9fc176d..8dc392c8b0e 100644
--- a/modules/drivers/mongo/src/metabase/driver/mongo.clj
+++ b/modules/drivers/mongo/src/metabase/driver/mongo.clj
@@ -229,6 +229,7 @@
                         column-info))})))
 
 (doseq [feature [:basic-aggregations
+                 :expression-aggregations
                  :nested-fields
                  :nested-queries
                  :native-parameters
diff --git a/modules/drivers/mongo/src/metabase/driver/mongo/query_processor.clj b/modules/drivers/mongo/src/metabase/driver/mongo/query_processor.clj
index f45114942da..8c931fd024d 100644
--- a/modules/drivers/mongo/src/metabase/driver/mongo/query_processor.clj
+++ b/modules/drivers/mongo/src/metabase/driver/mongo/query_processor.clj
@@ -900,62 +900,96 @@
   [ag]
   {:group {(annotate/aggregation-name ag) (aggregation->rvalue ag)}})
 
-(defn- extract-aggregation
-  "Separate the expression `aggregation` named `aggr-name` into two parts:
-  an simple expression and an aggregation expression, where the simple expression
-  references the result of the aggregation expression such that first evaluating
-  the aggregation expression and binding its result to `aggr-name` and then
-  evaluating the simple expression in this context, the result is the same as
-  evaluating the whole expression `aggregation`.
-  This separation is necessary, because MongoDB doesn't support embedding
-  aggregations in `normal' expressions.
-
-  For example the aggregation
-    [:aggregation-options
-     [:+ [:/ [:sum
-              [:case [[[:< [:field 12 nil] [:field 7 nil]]
-                       [:field 12 nil]]]
-               {:default 0}]]
-             2]
-         1]
-     {:name \"expression\"}]
-  is transformed into the simple expression
-    [:+ [:/ \"$expression\" 2] 1]
-  and the aggregation expression
-    [:aggregation-options
-     [:sum
-      [:case [[[:< [:field 12 nil] [:field 7 nil]]
-               [:field 12 nil]]]
-       {:default 0}]]
-     {:name \"expression\"}]"
-  [aggregation aggr-name]
-  (when (and (vector? aggregation) (seq aggregation))
-    (let [[op & args] aggregation]
-      (cond
-        (= op :aggregation-options)
-        (let [[embedding-expr aggregation'] (extract-aggregation (first args) aggr-name)]
-          [embedding-expr (into [:aggregation-options aggregation'] (rest args))])
-
-        (aggregation-op op)
-        [(str \$ aggr-name) aggregation]
-
-        :else
-        (let [ges (map #(extract-aggregation % aggr-name) args)
-              [embedding-expr aggregation'] (first (filter some? ges))]
-          (when-not aggregation'
-            (throw
-             (ex-info (tru "Don''t know how to handle aggregation {0}" aggregation)
-                      {:type :invalid-query, :clause aggregation})))
-          [(into [op] (map (fn [arg ge] (if ge embedding-expr arg)) args ges))
-           aggregation'])))))
-
-(defn- expand-embedded-aggregation [aggregation]
-  (let [aggr-name (annotate/aggregation-name aggregation)
-        [embedding-expr aggregation-expr] (extract-aggregation aggregation aggr-name)
-        expanded (expand-aggregation aggregation-expr)]
-    (cond-> expanded
-      (not (string? embedding-expr))
-      (update :post conj {aggr-name (->rvalue embedding-expr)}))))
+(defn- extract-aggregations
+  "Extract aggregation expressions embedded in `aggr-expr` using `parent-name`
+  as a namespace for the names introduced for the aggregation expressions.
+  The function returns a pair with the first element an expression like
+  `aggr-expr` with aggregations replaced by new names. The second element of
+  the pair is a map from the extracted aggregations to the new names conjoined
+  on `aggregations-seen`. `:aggregation-options` wrappers are ignored.
+  Distinct aggregations that would get the same name (e.g. two different
+  sums) are told apart by appending 1, 2, ... to the later ones.
+
+  For example, given \"expression\" as `parent-name`, the expression
+  [:aggregation-options [:+ [:count [:field 1144 nil]]
+                            [:* [:count [:field 1144 nil]]
+                                [:sum [:+ [:field 1142 nil] 1]]]]
+                        {:name \"expression\"}]
+  is mapped to
+  [[:+ \"$expression~count\" [:* \"$expression~count\" \"$expression~sum\"]]
+   {[:count [:field 1144 nil]] \"expression~count\"
+    [:sum [:+ [:field 1142 nil] 1]] \"expression~sum\"}]"
+  ([aggr-expr parent-name] (extract-aggregations aggr-expr parent-name {}))
+  ([aggr-expr parent-name aggregations-seen]
+   (if (and (vector? aggr-expr) (seq aggr-expr))
+     (let [[op & args] aggr-expr
+           seen (get aggregations-seen aggr-expr)]
+       (cond
+         seen
+         [(str \$ seen) aggregations-seen]
+         (= :aggregation-options op)
+         (extract-aggregations (first args) parent-name aggregations-seen)
+         (aggregation-op op)
+         (let [base (str parent-name "~" (annotate/aggregation-name aggr-expr))
+               taken (set (vals aggregations-seen))
+               ;; first free alias among base, base1, base2, ...
+               aggr-name (first (remove taken (map #(str base %) (cons "" (iterate inc 1)))))]
+           [(str \$ aggr-name) (assoc aggregations-seen aggr-expr aggr-name)])
+         :else
+         (reduce (fn [[ges as] arg]
+                   (let [[ge as] (extract-aggregations arg parent-name as)]
+                     [(conj ges ge) as]))
+                 [[op] aggregations-seen]
+                 args)))
+     [aggr-expr aggregations-seen])))
+
+(defn- simplify-extracted-aggregations
+  "Drops the namespacing from an extraction result when the whole expression
+  for `aggr-name` is a single aggregation: if the rewritten expression is just
+  a reference to the only extracted aggregation, that aggregation is renamed
+  to `aggr-name` itself. Any other `extracted-aggr` is returned unchanged.
+  `extracted-aggr` is typically the result of [[extract-aggregations]]."
+  [aggr-name [aggr-expr aggregations-seen :as extracted-aggr]]
+  (let [[aggr-group group-name] (first aggregations-seen)]
+    (if (and (string? aggr-expr)
+             (str/starts-with? aggr-expr (str \$ aggr-name "~"))
+             (= (count aggregations-seen) 1)
+             (= group-name (subs aggr-expr 1))
+             aggr-group)
+      [(str \$ aggr-name) {aggr-group aggr-name}]
+      extracted-aggr)))
+
+(defn- expand-aggregations
+  "Splits the aggregation expression `aggr-expr` into the pieces needed to
+  compute it in a pipeline. Returns a map with the keys
+  `:group` - a map with the group entries of all embedded aggregations and
+  `:post` - a vector of maps of follow-up expressions over the grouped
+  fields; a map in this vector may (and usually does) refer to the fields
+  introduced by the maps before it."
+  [aggr-expr]
+  (let [aggr-name (annotate/aggregation-name aggr-expr)
+        extracted (extract-aggregations aggr-expr aggr-name)
+        [top-expr named-aggrs] (simplify-extracted-aggregations aggr-name extracted)
+        top-rvalue (->rvalue top-expr)
+        parts (for [[aggr group-name] named-aggrs]
+                (expand-aggregation [:aggregation-options aggr {:name group-name}]))
+        first-stage (into {} (mapcat :post) parts)]
+    {:group (into {} (map :group) parts)
+     :post (if (= top-rvalue (str \$ aggr-name))
+             [first-stage]
+             [first-stage {aggr-name top-rvalue}])}))
+
+(defn- order-postprocessing
+  "Merges a sequence of post processing vectors (see [[expand-aggregations]])
+  level by level: the i-th map of the result combines the i-th maps of all
+  the vectors that are long enough to have one.
+  This keeps the number of pipeline stages low and relies on two assumptions:
+    a) maps can only depend on maps preceding them in their own vector and
+    b) the keys in the maps at the same level are unique."
+  [posts]
+  (when (seq posts)
+    (let [depth (reduce max (map count posts))]
+      (map (fn [level] (into {} (keep #(get % level)) posts)) (range depth)))))
 
 (defn- group-and-post-aggregations
   "Mongo is picky about which top-level aggregations it allows with groups. Eg. even
@@ -967,11 +1001,11 @@
    of preceding stages.
    The intermittent results accrued in `$group` stage are discarded in the final `$project` stage."
   [id aggregations]
-  (let [expanded-ags (map expand-embedded-aggregation aggregations)
+  (let [expanded-ags (map expand-aggregations aggregations)
         group-ags    (mapcat :group expanded-ags)
-        post-ags     (mapcat :post expanded-ags)]
+        post-ags     (order-postprocessing (map :post expanded-ags))]
     (into [{$group (into (ordered-map/ordered-map "_id" id) group-ags)}]
-          (map (fn [p] {:$addFields p}))
+          (keep (fn [p] (when (seq p) {:$addFields p})))
           post-ags)))
 
 (defn- projection-group-map [fields]
diff --git a/modules/drivers/mongo/test/metabase/driver/mongo/query_processor_test.clj b/modules/drivers/mongo/test/metabase/driver/mongo/query_processor_test.clj
index a7bfac2b0e3..b3661df7540 100644
--- a/modules/drivers/mongo/test/metabase/driver/mongo/query_processor_test.clj
+++ b/modules/drivers/mongo/test/metabase/driver/mongo/query_processor_test.clj
@@ -32,6 +32,17 @@
                                                          {:native []}
                                                          :joins [{:source-query "wow"}]}}))))))
 
+(deftest order-postprocessing-test
+  (testing "maps at the same index of the post processing vectors are merged"
+    (let [share {"expression_2~share" {"$divide" ["$count-where-141638" "$count-141639"]}}
+          expr  {"expression" {"$add" ["$expression~count" {"$multiply" ["$expression~count" "$expression~sum"]}]}}
+          expr2 {"expression_2" {"$multiply" [2 "$expression_2~share"]}}]
+      (is (= [share (merge expr expr2)]
+             (#'mongo.qp/order-postprocessing [[{} expr]
+                                               [{}]
+                                               [{}]
+                                               [share expr2]]))))))
+
 (deftest relative-datetime-test
   (mt/test-driver :mongo
     (testing "Make sure relative datetimes are compiled sensibly"
@@ -49,8 +60,8 @@
                   :mbql?       true}
                  (qp/compile
                   (mt/mbql-query attempts
-                    {:aggregation [[:count]]
-                     :filter      [:time-interval $datetime :last :month]})))))))))
+                                 {:aggregation [[:count]]
+                                  :filter      [:time-interval $datetime :last :month]})))))))))
 
 (deftest absolute-datetime-test
   (mt/test-driver :mongo
diff --git a/test/metabase/query_processor_test/expression_aggregations_test.clj b/test/metabase/query_processor_test/expression_aggregations_test.clj
index b7b84a57653..3a8d5dee814 100644
--- a/test/metabase/query_processor_test/expression_aggregations_test.clj
+++ b/test/metabase/query_processor_test/expression_aggregations_test.clj
@@ -114,6 +114,23 @@
                                  [:* [:count $id] [:sum $price]]]]
                   :breakout    [$price]})))))))
 
+(deftest nested-post-multi-aggregation-test ; several aggregations, some nested in math, in one query
+  (mt/test-drivers (mt/normal-drivers-with-feature :expression-aggregations)
+    (testing "nested post-aggregation math: count + (count * sum)"
+      (is (= [[1   990 22 22 2.0] ; columns: price, count + count * sum(price + 1), count(id), count, 2 * share
+              [2 10502 59 59 2.0]
+              [3   689 13 13 2.0]
+              [4   186  6  6 0.0]] ; price 4 never satisfies price < 4, so the share is 0
+             (mt/formatted-rows [int int int int float]
+               (mt/run-mbql-query venues
+                 {:aggregation [[:+
+                                 [:count $id] ; the same count is used twice in this expression
+                                 [:* [:count $id] [:sum [:+ $price 1]]]]
+                                [:count $id] ; ... and once more as an aggregation of its own
+                                [:count]
+                                [:* 2 [:share [:< $price 4]]]] ; math on top of the :share aggregation
+                  :breakout    [$price]})))))))
+
 (deftest math-inside-aggregations-test
   (mt/test-drivers (mt/normal-drivers-with-feature :expression-aggregations)
     (testing "post aggregation math + math inside aggregations: max(venue_price) + min(venue_price - id)"
-- 
GitLab