From 00737e991085f9c633557dd0ddc6601709868613 Mon Sep 17 00:00:00 2001
From: adam-james <21064735+adam-james-v@users.noreply.github.com>
Date: Wed, 13 Nov 2024 13:58:01 -0800
Subject: [PATCH] Apply Column Sort To Pivot Sections (#49802)

* Apply Column Sort To Pivot Sections

Fixes #49437

This PR fixes the dataset API so that column sorts (ascending/descending settings on pivot-rows and pivot-cols) are
properly parsed and passed through the QP.

After that, I've also modified the post processor to use the sort orders properly in pivoted exports.

* println debugging :)

* fixing a few tests

* fix a few more tests
---
 src/metabase/api/dataset.clj                  |  2 +-
 .../middleware/pivot_export.clj               |  3 +-
 src/metabase/query_processor/pivot.clj        | 34 +++++++++--
 .../query_processor/pivot/postprocess.clj     | 57 ++++++++++++++-----
 test/metabase/api/downloads_exports_test.clj  |  8 ---
 test/metabase/query_processor/pivot_test.clj  |  6 +-
 6 files changed, 80 insertions(+), 30 deletions(-)

diff --git a/src/metabase/api/dataset.clj b/src/metabase/api/dataset.clj
index b4f107d27e0..e017a4b5d41 100644
--- a/src/metabase/api/dataset.clj
+++ b/src/metabase/api/dataset.clj
@@ -140,7 +140,7 @@
         query                         (dissoc query :was-pivot)
         viz-settings                  (-> (json/parse-string visualization_settings viz-setting-key-fn)
                                           (update :table.columns mbql.normalize/normalize)
-                                          mb.viz/db->norm)
+                                          mb.viz/norm->db)
         query                         (-> query
                                           (assoc :viz-settings viz-settings)
                                           (dissoc :constraints)
diff --git a/src/metabase/query_processor/middleware/pivot_export.clj b/src/metabase/query_processor/middleware/pivot_export.clj
index fc35b628273..b89c6f6b2e5 100644
--- a/src/metabase/query_processor/middleware/pivot_export.clj
+++ b/src/metabase/query_processor/middleware/pivot_export.clj
@@ -1,7 +1,8 @@
 (ns metabase.query-processor.middleware.pivot-export)
 
 (defn add-data-for-pivot-export
-  "Provide `:pivot-export-options` in the query metadata so that `qp.si/streaming-resuilts-writer` implementations can post-process query results."
+  "Provide `:pivot-export-options` in the query metadata so that `qp.si/streaming-results-writer` implementations can
+  post-process query results."
   [query rff]
   (fn add-query-for-pivot-rff* [metadata]
     ;; the `qp.si/streaming-results-writer` implmementations can apply/not-apply formatting based on the key's value
diff --git a/src/metabase/query_processor/pivot.clj b/src/metabase/query_processor/pivot.clj
index 3cbf5957c35..7ef4353638c 100644
--- a/src/metabase/query_processor/pivot.clj
+++ b/src/metabase/query_processor/pivot.clj
@@ -4,6 +4,7 @@
   dumb, right? It's not just me? Why don't we just generate a big ol' UNION query so we can run one single query
   instead of running like 10 separate queries? -- Cam"
   (:require
+   [cheshire.core :as json]
    [medley.core :as m]
    [metabase.lib.convert :as lib.convert]
    [metabase.lib.core :as lib]
@@ -15,6 +16,7 @@
    [metabase.lib.schema.id :as lib.schema.id]
    [metabase.lib.schema.info :as lib.schema.info]
    [metabase.lib.util :as lib.util]
+   [metabase.models.visualization-settings :as mb.viz]
    [metabase.query-processor :as qp]
    [metabase.query-processor.error-type :as qp.error-type]
    [metabase.query-processor.middleware.add-dimension-projections :as qp.add-dimension-projections]
@@ -371,6 +373,27 @@
     (when (some some? (vals pivot-opts))
       pivot-opts)))
 
+(mu/defn- column-sort-order :- [:map-of :int :keyword]
+  "Looks at the `pivot_table.column_sort_order` key of each entry in the card's column settings and generates a map
+  from the column's index to the setting (either ascending or descending). Entries whose column has no index or
+  that carry no sort order are omitted, so the result never contains a nil key or a nil direction."
+  [query        :- [:map
+                    [:database ::lib.schema.id/database]]
+   viz-settings :- [:maybe :map]]
+  (let [metadata-provider (or (:lib/metadata query)
+                              (lib.metadata.jvm/application-database-metadata-provider (:database query)))
+        column-index      (into {}
+                                (comp (filter (comp #{:source/breakouts :source/aggregations} :lib/source))
+                                      (map-indexed (fn [i column] [(:name column) i])))
+                                (lib/returned-columns (lib/query metadata-provider query)))]
+    (into {}
+          (keep (fn [[k settings]]
+                  (let [idx       (column-index (if (string? k) (last (json/parse-string k)) (::mb.viz/column-name k)))
+                        direction (keyword (:pivot_table.column_sort_order settings))]
+                    ;; a nil direction would later be looked up and invoked as a sort fn, throwing an NPE
+                    (when (and idx direction) [idx direction]))))
+          (or (:column_settings viz-settings) (::mb.viz/column-settings viz-settings)))))
+
 (mu/defn- field-ref-pivot-options :- ::pivot-opts
   "Looks at the `pivot_table.column_split` key in the card's visualization settings and generates `pivot-rows` and
   `pivot-cols` to use for generating subqueries. Supports field ref-based settings only."
@@ -420,10 +443,13 @@
   [query        :- [:map
                     [:database ::lib.schema.id/database]]
    viz-settings :- [:maybe :map]]
-  (let [{:keys [rows columns]} (:pivot_table.column_split viz-settings)]
-    (if (and (every? string? rows) (every? string? columns))
-      (column-name-pivot-options query viz-settings)
-      (field-ref-pivot-options query viz-settings))))
+  (when viz-settings
+    (let [{:keys [rows columns]} (:pivot_table.column_split viz-settings)]
+      (merge
+       (if (and (every? string? rows) (every? string? columns))
+         (column-name-pivot-options query viz-settings)
+         (field-ref-pivot-options query viz-settings))
+       {:column-sort-order (column-sort-order query viz-settings)}))))
 
 (mu/defn- column-mapping-for-subquery :- ::pivot-column-mapping
   [num-canonical-cols            :- ::lib.schema.common/int-greater-than-or-equal-to-zero
diff --git a/src/metabase/query_processor/pivot/postprocess.clj b/src/metabase/query_processor/pivot/postprocess.clj
index f57c94dc271..038ca24c9b2 100644
--- a/src/metabase/query_processor/pivot/postprocess.clj
+++ b/src/metabase/query_processor/pivot/postprocess.clj
@@ -221,14 +221,14 @@
 (defn- build-column-totals
   "Build column totals for a section."
   [section-path col-combos pivot-measures totals row-totals? ordered-formatters pivot-rows]
-  (let [totals-row (distinct (for [col-combo   col-combos
-                                   measure-key pivot-measures]
-                               (fmt (get ordered-formatters measure-key)
-                                    (get-in totals (concat
-                                                    [:column-totals]
-                                                    section-path
-                                                    col-combo
-                                                    [measure-key])))))]
+  (let [totals-row (for [col-combo   col-combos
+                         measure-key pivot-measures]
+                     (fmt (get ordered-formatters measure-key)
+                          (get-in totals (concat
+                                          [:column-totals]
+                                          section-path
+                                          col-combo
+                                          [measure-key]))))]
     (when (some #(and (some? %) (not= "" %)) totals-row)
       (concat
        (cons (format "Totals for %s" (fmt (get ordered-formatters (first pivot-rows)) (last section-path)))
@@ -255,6 +255,30 @@
      (fmt (get ordered-formatters measure-key)
           (get-in totals [:grand-total measure-key])))))
 
+(defn- sort-by-pivot-columns
+  "Sorts `rows` on their leading values; position `i` of each row is taken to hold the value of column
+  `(nth col-idxs i)`. A position sorts ascending unless `column-sort-order` maps its column to `:descending`."
+  [column-sort-order col-idxs rows]
+  (reduce
+   (fn [rows [position col-idx]]
+     (let [descending? (= :descending (get column-sort-order col-idx :ascending))
+           groups      (sort (group-by #(nth % position) rows))]
+       ;; rows keep their relative order inside a group, so sorting on the last position first and on the
+       ;; first position last yields a multi-key sort
+       (mapcat second (if descending? (reverse groups) groups))))
+   rows
+   (reverse (map vector (range) col-idxs))))
+
+(defn- sort-pivot-subsections
+  "Sorts the rows of `section` on the `:pivot-rows` columns of `config`, honouring its `:column-sort-order`."
+  [{:keys [pivot-rows column-sort-order]} section]
+  (sort-by-pivot-columns column-sort-order pivot-rows section))
+
+(defn- sort-column-combos
+  "Sorts `column-combos` on the `:pivot-cols` columns of `config`, honouring its `:column-sort-order`."
+  [{:keys [pivot-cols column-sort-order]} column-combos]
+  (sort-by-pivot-columns column-sort-order pivot-cols column-combos))
+
 (defn- append-totals-to-subsections
   [pivot section col-combos ordered-formatters]
   (let [{:keys [config
@@ -307,13 +331,17 @@
         {:keys [pivot-rows
                 pivot-cols
                 pivot-measures
+                column-sort-order
                 column-titles
                 row-totals?
                 col-totals?]}   config
+        sort-fns                (update-vals column-sort-order (fn [direction] (get {:ascending  identity
+                                                                                     :descending reverse} direction)))
         row-formatters          (mapv #(get ordered-formatters %) pivot-rows)
         col-formatters          (mapv #(get ordered-formatters %) pivot-cols)
         row-combos              (apply math.combo/cartesian-product (map row-values pivot-rows))
         col-combos              (apply math.combo/cartesian-product (map column-values pivot-cols))
+        col-combos              (sort-column-combos config col-combos)
         row-totals?             (and row-totals? (boolean (seq pivot-cols)))
         column-headers          (build-column-headers config col-combos col-formatters)
         headers                 (or (seq (build-headers column-headers config))
@@ -325,18 +353,21 @@
      (filter seq
              (apply concat
                     (let [sections-rows
-                          (for [section-row-combos (sort-by ffirst (vals (group-by first row-combos)))]
+                          (for [section-row-combos ((get sort-fns (first pivot-rows) identity) (sort-by ffirst (vals (group-by first row-combos))))]
                             (concat
                              (remove nil?
-                                     (for [row-combo (sort-by first section-row-combos)]
+                                     (for [row-combo section-row-combos]
                                        (build-row row-combo col-combos pivot-measures data totals row-totals? ordered-formatters row-formatters)))))]
                       (mapv
                        (fn [section-rows]
                          (->>
+                          section-rows
+                          (sort-pivot-subsections config)
                           ;; section rows are either enriched with column-totals rows or left as is
-                          (if col-totals?
-                            (append-totals-to-subsections pivot section-rows col-combos ordered-formatters)
-                            section-rows)
+                          ((fn [rows]
+                             (if (and col-totals? (> (count pivot-rows) 1))
+                               (append-totals-to-subsections pivot rows col-combos ordered-formatters)
+                               rows)))
                           ;; then, we apply the row-formatters to the pivot-rows portion of each row,
                           ;; filtering out any rows that begin with "Totals ..."
                           (mapv
diff --git a/test/metabase/api/downloads_exports_test.clj b/test/metabase/api/downloads_exports_test.clj
index 29a3f6cdd56..1451cac7ce0 100644
--- a/test/metabase/api/downloads_exports_test.clj
+++ b/test/metabase/api/downloads_exports_test.clj
@@ -280,13 +280,9 @@
         (testing "formatted"
           (is (= [[["Category" "2016" "2017" "2018" "2019" "Row totals"]
                    ["Doohickey" "$632.14" "$854.19" "$496.43" "$203.13" "$2,185.89"]
-                   ["Totals for Doohickey" "$632.14" "$854.19" "$496.43" "$203.13"]
                    ["Gadget" "$679.83" "$1,059.11" "$844.51" "$435.75" "$3,019.20"]
-                   ["Totals for Gadget" "$679.83" "$1,059.11" "$844.51" "$435.75"]
                    ["Gizmo" "$529.70" "$1,080.18" "$997.94" "$227.06" "$2,834.88"]
-                   ["Totals for Gizmo" "$529.70" "$1,080.18" "$997.94" "$227.06"]
                    ["Widget" "$987.39" "$1,014.68" "$912.20" "$195.04" "$3,109.31"]
-                   ["Totals for Widget" "$987.39" "$1,014.68" "$912.20" "$195.04"]
                    ["Grand totals" "$2,829.06" "$4,008.16" "$3,251.08" "$1,060.98" "$11,149.28"]]
                   #{:unsaved-card-download :card-download :dashcard-download
                     :alert-attachment :subscription-attachment
@@ -303,13 +299,9 @@
                     "2019-01-01T00:00:00Z"
                     "Row totals"]
                    ["Doohickey" "632.14" "854.19" "496.43" "203.13" "2185.89"]
-                   ["Totals for Doohickey" "632.14" "854.19" "496.43" "203.13"]
                    ["Gadget" "679.83" "1059.11" "844.51" "435.75" "3019.20"]
-                   ["Totals for Gadget" "679.83" "1059.11" "844.51" "435.75"]
                    ["Gizmo" "529.7" "1080.18" "997.94" "227.06" "2834.88"]
-                   ["Totals for Gizmo" "529.7" "1080.18" "997.94" "227.06"]
                    ["Widget" "987.39" "1014.68" "912.2" "195.04" "3109.31"]
-                   ["Totals for Widget" "987.39" "1014.68" "912.2" "195.04"]
                    ["Grand totals" "2829.06" "4008.16" "3251.08" "1060.98" "11149.28"]]
                   #{:unsaved-card-download :card-download :dashcard-download
                     :alert-attachment :subscription-attachment
diff --git a/test/metabase/query_processor/pivot_test.clj b/test/metabase/query_processor/pivot_test.clj
index 53987a92259..3720218e528 100644
--- a/test/metabase/query_processor/pivot_test.clj
+++ b/test/metabase/query_processor/pivot_test.clj
@@ -191,7 +191,7 @@
   (testing "`pivot-options` correctly generates pivot-rows and pivot-cols from a card's viz settings"
     (let [query         (api.pivots/pivot-query false)
           viz-settings  (:visualization_settings (api.pivots/pivot-card))
-          pivot-options {:pivot-rows [1 0], :pivot-cols [2] :pivot-measures nil}]
+          pivot-options {:pivot-rows [1 0], :pivot-cols [2] :pivot-measures nil :column-sort-order {}}]
       (is (= pivot-options
              (#'qp.pivot/pivot-options query viz-settings)))
       (are [num-breakouts expected] (= expected
@@ -214,7 +214,7 @@
                          {:rows    ["ID"]
                           :columns ["RATING"]}}
           pivot-options (#'qp.pivot/pivot-options query viz-settings)]
-      (is (= {:pivot-rows [], :pivot-cols [] :pivot-measures nil}
+      (is (= {:pivot-rows [], :pivot-cols [] :pivot-measures nil :column-sort-order {}}
              pivot-options))
       (is (= [[0 1] [1] [0] []]
              (#'qp.pivot/breakout-combinations 2 (:pivot-rows pivot-options) (:pivot-cols pivot-options)))))))
@@ -241,7 +241,7 @@
                                {:rows    ["CATEGORY"]
                                 :columns ["CREATED_AT"]}}
                 pivot-options (#'qp.pivot/pivot-options query viz-settings)]
-            (is (= {:pivot-rows [0], :pivot-cols [1] :pivot-measures nil}
+            (is (= {:pivot-rows [0], :pivot-cols [1] :pivot-measures nil :column-sort-order {}}
                    pivot-options))
             (is (= [[0 1] [1] [0] []]
                    (#'qp.pivot/breakout-combinations 2 (:pivot-rows pivot-options) (:pivot-cols pivot-options))))
-- 
GitLab