Skip to content
Snippets Groups Projects
Commit 09776296 authored by Cam Saül's avatar Cam Saül Committed by GitHub
Browse files

Merge pull request #3802 from metabase/sql-expression-aggregations

Expression aggregations for SQL drivers
parents 472335d6 406e1249
No related branches found
No related tags found
No related merge requests found
......@@ -303,6 +303,7 @@
:standard-deviation-aggregations
:foreign-keys
:expressions
:expression-aggregations
:native-parameters}
(set-timezone-sql driver) (conj :set-timezone)))
......
......@@ -6,12 +6,14 @@
[clojure.tools.logging :as log]
(honeysql [core :as hsql]
[format :as hformat]
[helpers :as h])
[helpers :as h]
types)
(metabase [config :as config]
[driver :as driver])
[metabase.driver.generic-sql :as sql]
[metabase.query-processor :as qp]
metabase.query-processor.interface
(metabase.query-processor [annotate :as annotate]
interface)
[metabase.util :as u]
[metabase.util.honeysql-extensions :as hx])
(:import java.sql.Timestamp
......@@ -54,14 +56,16 @@
(or (get-in *query* [:query :expressions (keyword expression-name)]) (:expressions (:query *query*))
(throw (Exception. (format "No expression named '%s'." (name expression-name))))))
;; TODO - maybe this fn should be called `->honeysql` instead.
;; Protocol for converting a value into the HoneySQL form that represents it.
;; It is private to this namespace; concrete implementations are attached with `extend-protocol`.
(defprotocol ^:private IGenericSQLFormattable
(formatted [this]
"Return an appropriate HoneySQL form for an object."))
(extend-protocol IGenericSQLFormattable
nil (formatted [_] nil)
Number (formatted [this] this)
String (formatted [this] this)
nil (formatted [_] nil)
Number (formatted [this] this)
String (formatted [this] this)
honeysql.types.SqlCall (formatted [this] this)
Expression
(formatted [{:keys [operator args]}]
......@@ -113,31 +117,54 @@
;;; ## Clause Handlers
(defn- aggregation->honeysql
  "Generate the HoneySQL form for an aggregation.

  AGGREGATION-TYPE must be a keyword (enforced by the precondition). When FIELD is nil the only legal
  aggregation is `:count`, which is returned as the keyword `:%count.*`. Otherwise the result is a HoneySQL
  function call applied to `(formatted field)`; the function used for `:stddev` is whatever
  `(sql/stddev-fn driver)` returns, and `:distinct` maps to `:distinct-count`.

  Throws an `IllegalArgumentException` naming the offending type when AGGREGATION-TYPE is not one of
  `:avg`, `:count`, `:distinct`, `:stddev`, `:sum`, `:min` or `:max`."
  [driver aggregation-type field]
  {:pre [(keyword? aggregation-type)]}
  (if-not field
    ;; aggregation clauses w/o a field
    (do (assert (= aggregation-type :count)
                (format "Aggregations of type '%s' must specify a field." aggregation-type))
        :%count.*)
    ;; aggregation clauses w/ a Field
    (hsql/call (case aggregation-type
                 :avg      :avg
                 :count    :count
                 :distinct :distinct-count
                 :stddev   (sql/stddev-fn driver)
                 :sum      :sum
                 :min      :min
                 :max      :max
                 ;; Without a default, `case` fails with an opaque "No matching clause" message.
                 ;; Throw the same exception type, but say what was actually wrong.
                 (throw (IllegalArgumentException.
                         (format "Unsupported aggregation type '%s'." aggregation-type))))
               (formatted field))))
(defn- expression-aggregation->honeysql
  "Generate the HoneySQL form for an expression aggregation.

  Each element of EXPRESSION's `:args` is converted first, and the rewritten expression is then passed to
  `formatted`:

  *  numbers are kept as-is
  *  anything with an `:aggregation-type` goes through `aggregation->honeysql`
  *  anything with an `:operator` is treated as a nested expression and converted recursively

  An arg matching none of these becomes nil."
  [driver expression]
  (let [arg->honeysql (fn [arg]
                        (cond
                          (number? arg)
                          arg

                          (:aggregation-type arg)
                          (aggregation->honeysql driver (:aggregation-type arg) (:field arg))

                          (:operator arg)
                          (expression-aggregation->honeysql driver arg)))]
    (formatted (update expression :args (partial map arg->honeysql)))))
(defn- apply-expression-aggregation
  "Merge the expression aggregation EXPRESSION into the select clause of HONEYSQL-FORM as a
  `[form alias]` pair. The alias is the name from `annotate/expression-aggregation-name`, passed through
  `hx/escape-dots`."
  [driver honeysql-form expression]
  (let [select-form  (expression-aggregation->honeysql driver expression)
        column-alias (hx/escape-dots (annotate/expression-aggregation-name expression))]
    (h/merge-select honeysql-form [select-form column-alias])))
(defn apply-aggregation
"Apply a `aggregation` clauses to HONEYSQL-FORM. Default implementation of `apply-aggregation` for SQL drivers."
([driver honeysql-form {aggregations :aggregation}]
(loop [form honeysql-form, [{:keys [aggregation-type field]} & more] aggregations]
(let [form (apply-aggregation driver form aggregation-type (formatted field))]
(loop [form honeysql-form, [ag & more] aggregations]
(let [form (if (instance? Expression ag)
(apply-expression-aggregation driver form ag)
(let [{:keys [aggregation-type field]} ag]
(apply-aggregation driver form aggregation-type field)))]
(if-not (seq more)
form
(recur form more)))))
([driver honeysql-form aggregation-type field]
(h/merge-select honeysql-form [(if-not field
;; aggregation clauses w/o a field
(do (assert (= aggregation-type :count)
(format "Aggregations of type '%s' must specify a field." aggregation-type))
:%count.*)
;; aggregation clauses w/ a Field
(hsql/call (case aggregation-type
:avg :avg
:count :count
:distinct :distinct-count
:stddev (sql/stddev-fn driver)
:sum :sum
:min :min
:max :max)
field))
(h/merge-select honeysql-form [(aggregation->honeysql driver aggregation-type field)
;; the column alias is always the same as the ag type except for `:distinct` which is called `:count` (WHY?)
(if (= aggregation-type :distinct)
:count
......
......@@ -98,7 +98,6 @@
"Return an appropriate name for an expression aggregation, e.g. `sum + count`."
^String [ag]
(cond
;;
(instance? Expression ag) (let [{:keys [operator args]} ag]
(str/join (str " " (name operator) " ")
(for [arg args]
......@@ -106,7 +105,6 @@
(str "(" (expression-aggregation-name arg) ")")
(expression-aggregation-name arg)))))
(:aggregation-type ag) (name (:aggregation-type ag))
;; a constant like
:else ag))
(defn- expression-aggregate-field-info [expression]
......
......@@ -181,9 +181,9 @@
;; Schema for the operator of an `Expression`: exactly one of the keywords `:+`, `:-`, `:*` or `:/`.
;; `s/named` attaches the label "Valid expression operator" to the enum schema.
(def ^:private ExpressionOperator (s/named (s/enum :+ :- :* :/) "Valid expression operator"))
(s/defrecord Expression [operator :- ExpressionOperator
args :- [(s/cond-pre (s/recursive #'RValue)
(s/recursive #'Aggregation))]])
;; Record describing an arithmetic expression.
;; `operator` must satisfy `ExpressionOperator`. `args` is a sequence whose elements must each satisfy
;; either the `RValue` or the `Aggregation` schema; both are referenced through `s/recursive` on their vars.
(s/defrecord Expression [operator :- ExpressionOperator
args :- [(s/cond-pre (s/recursive #'RValue)
(s/recursive #'Aggregation))]])
(def AnyField
"Schema for a `FieldPlaceholder`, `AgRef`, or `Expression`."
......
......@@ -65,7 +65,7 @@
;;; +------------------------------------------------------------------------------------------------------------------------+
;;; | MATH AGGREGATIONS |
;;; | EXPRESSION AGGREGATIONS |
;;; +------------------------------------------------------------------------------------------------------------------------+
(defmacro ^:private druid-query-returning-rows {:style/indent 0} [& body]
......
(ns metabase.query-processor-test.expression-aggregations-test
"Tests for expression aggregations."
(:require [expectations :refer :all]
[metabase.query-processor.expand :as ql]
[metabase.query-processor-test :refer :all]
[metabase.test.data :as data]
[metabase.test.data.datasets :as datasets, :refer [*engine*]]
[metabase.util :as u]))
;; sum, *
(datasets/expect-with-engines (engines-that-support :expression-aggregations)
  ;; expected rows
  [[1 1211] [2 5710] [3 1845] [4 1476]]
  ;; actual: sum(id * price), broken out by price
  (->> (data/run-query venues
         (ql/aggregation (ql/sum (ql/* $id $price)))
         (ql/breakout $price))
       rows
       (format-rows-by [int int])))
;; min, +
(datasets/expect-with-engines (engines-that-support :expression-aggregations)
  ;; expected rows
  [[1 10] [2 4] [3 4] [4 20]]
  ;; actual: min(id + price), broken out by price
  (->> (data/run-query venues
         (ql/aggregation (ql/min (ql/+ $id $price)))
         (ql/breakout $price))
       rows
       (format-rows-by [int int])))
;; max, /
(datasets/expect-with-engines (engines-that-support :expression-aggregations)
  ;; expected rows
  [[1 94] [2 50] [3 26] [4 20]]
  ;; actual: max(id / price), broken out by price
  (->> (data/run-query venues
         (ql/aggregation (ql/max (ql// $id $price)))
         (ql/breakout $price))
       rows
       (format-rows-by [int int])))
;; avg, *
(datasets/expect-with-engines (engines-that-support :expression-aggregations)
  ;; expected rows: H2 is expected to return different values for the 2nd and 3rd rows than the other engines
  (if (= *engine* :h2)
    [[1 55] [2 97] [3 142] [4 246]]
    [[1 55] [2 96] [3 141] [4 246]])
  ;; actual: avg(id * price), broken out by price
  (->> (data/run-query venues
         (ql/aggregation (ql/avg (ql/* $id $price)))
         (ql/breakout $price))
       rows
       (format-rows-by [int int])))
;; post-aggregation math w/ 2 args: count + sum
(datasets/expect-with-engines (engines-that-support :expression-aggregations)
  ;; expected rows
  [[1 44] [2 177] [3 52] [4 30]]
  ;; actual: count(id) + sum(price), broken out by price
  (->> (data/run-query venues
         (ql/aggregation (ql/+ (ql/count $id)
                               (ql/sum $price)))
         (ql/breakout $price))
       rows
       (format-rows-by [int int])))
;; post-aggregation math w/ 3 args: count + sum + count
(datasets/expect-with-engines (engines-that-support :expression-aggregations)
  ;; expected rows
  [[1 66] [2 236] [3 65] [4 36]]
  ;; actual: count(id) + sum(price) + count(price), broken out by price
  (->> (data/run-query venues
         (ql/aggregation (ql/+ (ql/count $id)
                               (ql/sum $price)
                               (ql/count $price)))
         (ql/breakout $price))
       rows
       (format-rows-by [int int])))
;; post-aggregation math w/ a constant: count * 10
(datasets/expect-with-engines (engines-that-support :expression-aggregations)
  ;; expected rows
  [[1 220] [2 590] [3 130] [4 60]]
  ;; actual: count(id) * 10, broken out by price
  (->> (data/run-query venues
         (ql/aggregation (ql/* (ql/count $id)
                               10))
         (ql/breakout $price))
       rows
       (format-rows-by [int int])))
;; nested post-aggregation math: count + (count * sum)
(datasets/expect-with-engines (engines-that-support :expression-aggregations)
  ;; expected rows
  [[1 506] [2 7021] [3 520] [4 150]]
  ;; actual: count(id) + (count(id) * sum(price)), broken out by price
  (->> (data/run-query venues
         (ql/aggregation (ql/+ (ql/count $id)
                               (ql/* (ql/count $id)
                                     (ql/sum $price))))
         (ql/breakout $price))
       rows
       (format-rows-by [int int])))
;; post-aggregation math w/ avg: count + avg
(datasets/expect-with-engines (engines-that-support :expression-aggregations)
  ;; expected rows: H2 is expected to return a different value for the last row than the other engines
  (if (= *engine* :h2)
    [[1 77] [2 107] [3 60] [4 68]]
    [[1 77] [2 107] [3 60] [4 67]])
  ;; actual: count(id) + avg(id), broken out by price
  (->> (data/run-query venues
         (ql/aggregation (ql/+ (ql/count $id)
                               (ql/avg $id)))
         (ql/breakout $price))
       rows
       (format-rows-by [int int])))
;; post aggregation math + math inside aggregations: max(venue_price) + min(venue_price - id)
(datasets/expect-with-engines (engines-that-support :expression-aggregations)
  ;; expected rows
  [[1 -92] [2 -96] [3 -74] [4 -73]]
  ;; actual: max(price) + min(price - id), broken out by price
  (->> (data/run-query venues
         (ql/aggregation (ql/+ (ql/max $price)
                               (ql/min (ql/- $price $id))))
         (ql/breakout $price))
       rows
       (format-rows-by [int int])))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment