November 10, 2021

How to Implement Association Rule Mining in Snowpark

By Charlie Isaksson

Have you ever wondered how recommendations pop up as you click on products while shopping online? Think about that “frequently bought together” window you see when you go to check out.

If so, you’re in luck. In this article, we introduce the algorithm commonly used by large e-commerce companies like Amazon to uncover associations between products—known as the association rule.

Throughout this post, we’ll provide a step-by-step implementation of the simplest versions of association rule mining in Snowpark (from the Snowflake Data Cloud) for market basket analysis. The idea is to give you the building blocks to implement more advanced association rules in a framework that allows you to scale your algorithm—all while still using the same infrastructure as your Snowflake data warehouse. 

What is Association Rule Mining?

Association Rule Mining is a technique used to analyze retail baskets or transaction data. The rules are expressed as if/then statements. They aid in discovering frequent patterns, associations, and correlations in datasets found in various kinds of databases.

The rule can be split into two parts:

  • Antecedent (if): an item found in the dataset
  • Consequent (then): an item found in combination with the antecedent

A typical example of an association rule on market-basket data is the beer and diapers story. That is: if a customer purchases diapers, there is an 80 percent chance they will also purchase beer. This can help retailers make more intelligent placements of their products in order to increase revenue and improve the customer experience, even if items aren’t associated on the surface level.

Association rules are created by identifying the frequency of these if/then patterns. Two measures are used to evaluate them:

  • Support: Indicates how frequently the if/then relationship occurs in the dataset. Mathematically, we can define it as Equation 1 below.
  • Confidence: Indicates how often the relationship has been found to be true, defined as Equation 2 below.

Consequently, in a given transaction with multiple items, Association Rule Mining tries to find the rules that govern how or why such products/items are often bought together. For example, peanut butter and jelly are frequently purchased together because a lot of people like to make PB&J sandwiches.

Over the years, data scientists have introduced various statistical algorithms to implement Association Rule Mining. 

There are many methods to perform association rule mining (this paper provides a good overview of different association rule approaches). One such algorithm is Apriori, often considered one of the simplest association rule algorithms. The clue is in its name: “Apriori” refers to the fact that the algorithm uses prior knowledge of frequent itemset properties to prune its search.

Exploratory Data Analysis

The dataset used in this blog post is the publicly-available online retail transactional data (available here). The dataset ranges from 2010-12-01 to 2011-12-09, about one year. We can see from Figure 1 that October, November, and December have the highest sales, with November experiencing the highest peak at an 88.6 percent increase over the other months.


Figure 1: Shows the sales per month across all countries.

The next insight, from Figure 2, is that the highest average amount paid by customers is in the Netherlands, Australia, and Lebanon, while the lowest is in Saudi Arabia.


Figure 2: Shows the average amount paid by customers across all countries.

Now for the association rules: in Figure 3, we only look at highly frequent items (specifically, items appearing in more than 1,000 transactions). The highlighted items are the most frequent ones, along with their linked associations. In practice, the items would be targeted based on the customer’s interests.


Figure 3: Shows the association between different items.  

Association Rules for Online Retailer on Snowflake with Snowpark

For large sets of data, there can be hundreds of items across hundreds of thousands of transactions. The Apriori algorithm tries to extract rules for each possible association of items.

For example, using a “brute force” approach to association rules would require:

  1. Get a list of all frequent itemsets of all sizes
  2. Compare every itemset with every other itemset
  3. If one itemset in a pair is a subset of the other (a subset/superset relation), then we can form an association rule

As you can see from the above example, this process can be extremely slow due to the number of combinations.
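To make the brute-force idea concrete, here is a minimal sketch of steps 2 and 3 in plain Scala (not Snowpark), with illustrative itemset values: every subset/superset pair of itemsets becomes a candidate rule, and each candidate still has to have its support counted against every transaction.

// Brute-force pairing: form a rule whenever one itemset is a subset of another.
val itemsets = Seq(Set(365), Set(934), Set(365, 934), Set(365, 934, 1037))

val candidateRules = for {
  antecedent <- itemsets
  superset   <- itemsets
  if antecedent != superset && antecedent.subsetOf(superset)
} yield (antecedent, superset -- antecedent)   // rule: antecedent => (superset minus antecedent)

// Every candidate rule still requires a scan over all transactions to measure support and
// confidence, which is what makes the brute-force approach blow up combinatorially.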

Two important properties improve the efficiency of level-wise generation of frequent itemsets and reduce the search space:

  1. Upward closure: If an itemset I is infrequent, then any superset of itemset I can not be frequent.
  2. Downward closure: If an itemset I is frequent, then all subsets of itemset I are also frequent.

For instance, if [Bread, Peanut Butter] is a frequent itemset, then Bread and Peanut Butter must individually be frequent items. Conversely, if either Bread or Peanut Butter is an infrequent item, then the itemset [Bread, Peanut Butter] cannot be frequent.
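In code, that downward-closure check is just a membership test. Here is a minimal sketch in plain Scala, assuming the frequent single items from the example above:

// Downward closure as a pruning test: skip a candidate if any of its items is not individually frequent.
val frequentItems = Set("Bread", "Peanut Butter", "Jelly")   // assumed level-1 frequent items
val candidate     = Set("Bread", "Milk")

val canBeFrequent = candidate.forall(frequentItems.contains)
// canBeFrequent == false, so [Bread, Milk] can be discarded before any support counting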

Association Rule Mining can be described in two steps:

  1. Find all possible frequent itemsets (an itemset is a set of items that occurs in a shopping basket)
  2. Find all possible association rules

Find All Possible Frequent Itemsets

Before using any rule mining algorithm, we need to load the data by first creating a database, then a table in Snowflake. The code below creates a table with the following schema:

				
CREATE OR REPLACE TABLE <database name>.<schema name>.RETAIL(
  InvoiceNo INT NOT NULL,
  StockCode INT NOT NULL,
  Description VARCHAR(500) NOT NULL,
  Quantity INT,
  InvoiceDate VARCHAR(200),
  UnitPrice DOUBLE,
  CustomerID DOUBLE,
  Country VARCHAR(200)
);
				
			

Once loaded, the RETAIL table should yield the output shown in Figure 4.


Figure 4: Shows a sample from the RETAIL table in the Snowflake database.

Now that we have the data in Snowflake, we can turn to our step-by-step implementation of the Apriori algorithm.  

How to Implement the Apriori Algorithm in Snowpark

Step 1

Prepare the retail transaction data by suppressing duplicate combinations of invoiceno and stockcode. The SQL code below drops the table “rtl_transaction_unique_table” (if it exists), then recreates it with one row per unique invoiceno and stockcode combination:

				
DROP TABLE IF EXISTS <database name>.<schema name>.rtl_transaction_unique_table;
CREATE TABLE rtl_transaction_unique_table as
SELECT INVOICENO, STOCKCODE, count(*) as num_rows 
FROM <database name>.<schema name>.RETAIL
GROUP BY INVOICENO, STOCKCODE;
				
			

The above code creates the output shown in Figure 5:


Figure 5: Shows a sample from rtl_transaction_unique_table in the Snowflake database.

Step 2

Narrow down to the relevant associations with two substeps:

  1. Find the frequent itemsets at the first level
  2. Filter to high-frequency items that appear in more than 1,000 transactions

The code below starts by removing the table “product_baskets_table” if it exists, then creates a new table with each unique stockcode (as product_id) and its frequency (as num_baskets):

				
DROP TABLE IF EXISTS <database name>.<schema name>.product_baskets_table;
CREATE TABLE <database name>.<schema name>.product_baskets_table AS
SELECT STOCKCODE as product_id, count(*) as num_baskets 
FROM <database name>.<schema name>.rtl_transaction_unique_table
GROUP BY STOCKCODE;

SELECT * FROM <database name>.<schema name>.product_baskets_table;
				
			

Figure 6 shows the output from the above SQL code: 


Figure 6: Shows a sample from product_baskets_table in the Snowflake database.

The second SQL snippet keeps only the items whose frequency, based on the num_baskets attribute, is greater than 1,000, and stores them in a new table, freq_itemsets_table:

				
DROP TABLE IF EXISTS <database name>.<schema name>.freq_itemsets_table;
CREATE TABLE <database name>.<schema name>.freq_itemsets_table as 
SELECT to_varchar(product_id) as itemset, 1 as size, num_baskets as support
FROM <database name>.<schema name>.product_baskets_table
WHERE num_baskets > 1000;
				
			

We can see that Figure 7 contains the itemsets at the first level, with the level denoted as size and the num_baskets variable renamed to support.


Figure 7: Shows a sample from freq_itemsets_table in the Snowflake database.

Step 3

Transform retail transaction data in order to find frequent itemsets at a higher level, as below:

				
DROP TABLE IF EXISTS <database name>.<schema name>.rtl_basket_transformed_table;
CREATE TABLE <database name>.<schema name>.rtl_basket_transformed_table AS
SELECT a.INVOICENO, array_to_string(array_agg(a.STOCKCODE),'_') as basket
FROM (SELECT INVOICENO, STOCKCODE FROM rtl_transaction_unique_table ORDER BY INVOICENO, STOCKCODE) a
GROUP BY a.INVOICENO;

SELECT * FROM <database name>.<schema name>.rtl_basket_transformed_table;
				
			

Snowflake provides a considerable collection of useful functions. One of them is array_agg, which creates a list of all the items in each unique transaction. The array_to_string function then converts each list to a string so we can work with it in a Snowpark UDF later. Figure 8 shows the output from the above SQL code.


Figure 8: Shows a sample from rtl_basket_transformed_table in the Snowflake database.

Step 4

Prune the transformed retail transaction baskets using the first-level frequent itemsets.

In general, the combinations at every level need to undergo pruning. Pruning removes infrequent itemsets before finding their actual support. For instance, let’s say we’ve generated combinations at level three. Any given combination at the third level cannot be frequent if any of its second-level subsets is infrequent. Such combinations can be removed from the list without aggregating and finding actual support, which substantially improves performance by reducing the search space.

For example, suppose T1 = [2,3,4,66,88] is a transformed retail transaction basket and [2,3,4,88,89,99] is the list of first-level frequent itemsets; the pruned transformed basket is then [2,3,4,88], which is simply the intersection of the two lists. Snowflake provides exactly such a function, array_intersection, which below takes the basket p.basket and the first-level frequent items p.size1_array (built earlier).

One thing to highlight from the below SQL code: the use of CROSS JOIN, which is less than ideal. This can be avoided once Snowflake provides a broadcast functionality similar to Apache Spark. Broadcast enables sharing of variables across executors.

				
DROP TABLE IF EXISTS <database name>.<schema name>.pruned_basket_table;
CREATE TABLE <database name>.<schema name>.pruned_basket_table as
SELECT p.INVOICENO, array_intersection(p.basket, p.size1_array) as pruned_basket, array_size(pruned_basket) as pruned_basket_size, array_size(p.basket) as original_basket_size, p.basket as original_basket
FROM (
SELECT t.INVOICENO, split(t.BASKET,'_') as basket, fq.size1_array as size1_array
FROM <database name>.<schema name>.rtl_basket_transformed_table t
CROSS JOIN 
(
SELECT array_agg(to_varchar(itemset)) as size1_array FROM <database name>.<schema name>.freq_itemsets_table 
WHERE size = 1) fq) p;

SELECT * FROM <database name>.<schema name>.pruned_basket_table;
				
			

We can see in Figure 9 that the pruned_basket attribute sometimes contains an empty list; this is due to the pruning performed in the above code snippet.


Figure 9: Shows a sample from pruned_basket_table in the Snowflake database.

Step 5

The pruned basket can be used for finding all higher level combinations. Until now, we’ve been using SQL right in Snowflake to find frequent items and pruned baskets. That is feasible up to first level frequent itemsets. However, higher level combinations can be efficiently performed in Snowpark. This is due to the recursive nature of the combination algorithm used.

What follows is a series of Scala snippets that compute the higher-level combinations. We start by computing the second-level combinations:

				
val freqItemSets = session.read.table("freq_itemsets_table")
var prunedBasket = session.read.table("pruned_basket_table")

var prunedBasketFreqItem1: DataFrame =  prunedBasket.crossJoin(freqItemSets.filter(col("SIZE") === 1)
 .agg(array_agg(col("ITEMSET")).as("freq_itemsets")))
 .select(Seq[String]("INVOICENO", "PRUNED_BASKET", "FREQ_ITEMSETS"))

def sortFreqItemList = udf((fItems: Array[String]) => {
 val freqItemsTmp: Array[String] = fItems.map(_.toInt).sorted.map(_.toString)
 freqItemsTmp
})
def sortPrunedBasketList = udf((pBasket: Array[String]) => {
 val combTmp: Array[String] = pBasket.map(_.toInt).sorted.map(_.toString)
 combTmp
})

prunedBasketFreqItem1 = prunedBasketFreqItem1.withColumn("FREQ_ITEMSETS",
 sortFreqItemList(col("FREQ_ITEMSETS")))

prunedBasketFreqItem1 = prunedBasketFreqItem1.withColumn("PRUNED_BASKET",
 sortPrunedBasketList(col("PRUNED_BASKET")))

prunedBasketFreqItem1.write.mode(SaveMode.Overwrite).saveAsTable("PRUNED_BASKET_2")
				
			

The code above reads the frequent itemsets and pruned baskets from Snowflake as Snowpark DataFrames. Then we use the crossJoin function to combine the two DataFrames. Finally, we sort each list in the freq_itemsets and pruned_basket attributes and save the output as a table in Snowflake: pruned_basket_2.

The next code segment starts by reading the pruned_basket_2 table that we saved in Snowflake above. We then use a user-defined function (UDF), transLevelTwoUDF, to generate the second-level combinations from each basket. Next, we filter out any rows with an empty combination list, such as those produced by the empty pruned baskets from Step 4 (table pruned_basket_table). We then save the output to the Snowflake table pruned_basket_2_0.

				
val prunedBasketFreqItem1 = session.read.table("PRUNED_BASKET_2")

val transLevelTwoUDF = udf((pBasket: Array[String], R: String) => {
 val v: List[Int] = pBasket.map(_.toInt).sorted.toList

 val accumulator = new Accumulator[Int]
 val nonRepeatingCombPull = new NonRepeatingCombPull[Int](v, R.toInt)
 PermComb.constructPushGenerator(accumulator.push, nonRepeatingCombPull.iterator)

 accumulator.listAccu.map(_.mkString("_")).toArray
})

val levelTwoDf = prunedBasketFreqItem1.withColumn("combination",
 transLevelTwoUDF(col("PRUNED_BASKET"), lit("2")))

val levelTwoDf_clean = levelTwoDf.filter(array_size(col("COMBINATION")) > 0)
levelTwoDf_clean.write.mode(SaveMode.Overwrite).saveAsTable("PRUNED_BASKET_2_0")
				
			

From Figure 10, we can see the output from the above code. Notice the column combination, which holds the second-level combination pairs created from the column pruned_basket.


Figure 10: Shows a sample from pruned_basket_2_0 in the Snowflake database.
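The Accumulator, NonRepeatingCombPull, and PermComb classes used in transLevelTwoUDF come from a combinatorics helper library that isn’t shown in this post. If you don’t have it available, a minimal equivalent of the second-level combination generation can be sketched with Scala’s built-in combinations method:

// Sketch of what transLevelTwoUDF produces, using the standard library instead of the helper classes.
val pBasket = Array("2476", "365", "934")            // an illustrative pruned basket
val pairs: Array[String] = pBasket.map(_.toInt).sorted.toList
  .combinations(2)                                   // all unique 2-item combinations of the sorted basket
  .map(_.mkString("_"))
  .toArray
// pairs: Array("365_934", "365_2476", "934_2476")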

Similar to step 4, we need to prune the second level of transformed retail transactions.

Notice that we again use a UDF to get the intersection of the combination and freq_itemsets columns. Let us reiterate the most important point: in the code above, we needed a crossJoin to merge in freq_itemsets. This can be avoided once Snowflake adds broadcast functionality to share the freq_itemsets DataFrame across executors. We then remove any empty lists that didn’t form any associations. Finally, we save the output to the Snowflake table pruned_basket_2_1.

				
val levelTwoDf = session.read.table("PRUNED_BASKET_2_0")

val transLevelTwoPrunedUDF = udf((fItem: Array[String], comb: Array[String], R: String) => {
 val combTmp: Array[ArrayBuffer[Int]] = comb.map(_.split("_").map(_.toInt).sorted.to[ArrayBuffer])
 val fItemTmp: Array[Int] = fItem.map(_.toInt).sorted

 val res:Array[ArrayBuffer[Int]] = combTmp.filter(x => {
    x.intersect(fItemTmp).length.==(R.toInt)
 })
 res.map(_.mkString("_"))
})

val levelTwoPrunedDf = levelTwoDf.withColumn("FREQ_ITEMSETS_FILTER",
 transLevelTwoPrunedUDF(col("FREQ_ITEMSETS"), col("COMBINATION"), lit("2")))

val levelTwoPrunedDf_clean = levelTwoPrunedDf.filter(array_size(col("FREQ_ITEMSETS_FILTER")) > 0)

levelTwoPrunedDf_clean.write.mode(SaveMode.Overwrite).saveAsTable("PRUNED_BASKET_2_1")
				
			

The above code will create a new column freq_itemsets_filter with all the pruned items. (See the output from Figure 11).


Figure 11: Shows a sample from pruned_basket_2_1 in the Snowflake database.

The only step left to complete the level-two retail transactions is to put them in a form that aligns with the freq_itemsets_table we created in Step 2. To do that, we use the flatten function to unpack each row’s list into a single column, to which we add a frequency of one. Then we use a groupBy operation to sum the frequencies of all the distinct itemsets.

We also add a new column to indicate the second level. Finally, we merge the prepared level-two DataFrame freqItemLevelTwo with the existing level-one DataFrame freqItemSets and save the merged freqItemsLevelTwo DataFrame to the Snowflake table freq_itemsets_table.

				
val levelTwoPrunedDf_clean = session.read.table("PRUNED_BASKET_2_1")

val freqItemLevelTwo: DataFrame =  levelTwoPrunedDf_clean
 .flatten(col("FREQ_ITEMSETS_FILTER"))
 .select( trim(col("VALUE"), lit("\"")).as("ITEMSET"))
 .withColumn("feq", lit(1))
 .groupBy("ITEMSET").agg(sum(col("feq")).as("SUPPORT"))
 .withColumn("SIZE", lit(2))
 .select(col("ITEMSET"), col("SIZE"), col("SUPPORT"))

val freqItemsLevelTwo:DataFrame = freqItemSets.union(freqItemLevelTwo)
freqItemsLevelTwo.write.mode(SaveMode.Overwrite).saveAsTable("freq_itemsets_table")
				
			

Similar to Figure 7, we have merged the level-two itemsets into the existing level-one table freq_itemsets_table. (See Figure 12 for the output from the above code.)


Figure 12: Shows a sample from freq_itemsets_table in the Snowflake database.

We repeat the above steps for the third level and so on, until the highest level k is reached. The code below is similar to the code above, with the exception of generating higher-order combinations.

				
def sortPrunedBasket = udf((pBasket: Array[String]) => {
 val combTmp: Array[String] = pBasket.map(_.toInt).sorted.map(_.toString)

 combTmp
})
prunedBasket = prunedBasket.withColumn("PRUNED_BASKET", sortPrunedBasket(col("PRUNED_BASKET")))

var prunedBasketFreqItem: DataFrame =  prunedBasket.crossJoin(freqItemSets.filter(col("SIZE") === 2)
 .agg(array_agg(col("ITEMSET")).as("freq_itemsets")))
 .select(Seq[String]("INVOICENO", "PRUNED_BASKET", "FREQ_ITEMSETS"))

def sortFreqItemSets = udf((fItems: Array[String]) => {
 val freqItemsTmp: Array[String] = fItems.map(_.split("_").map(_.toInt).sorted).map(_.mkString("_"))
 freqItemsTmp
})

prunedBasketFreqItem = prunedBasketFreqItem.withColumn("FREQ_ITEMSETS",
 sortFreqItemSets(col("FREQ_ITEMSETS")))

prunedBasketFreqItem.write.mode(SaveMode.Overwrite).saveAsTable("PRUNED_BASKET_3_0")
				
			

The above code segment merges the pruned basket and frequent itemsets DataFrames, then sorts both attributes before saving the results into the Snowflake table pruned_basket_3_0. Figure 13 shows the output from the above code.


Figure 13: Shows a sample from pruned_basket_3_0 in the Snowflake database.

Next, we generate the third level combinations based on the pruned_basket attribute. Again, we filter out empty lists and save the output to the Snowflake table pruned_basket_3_1.

				
val prunedBasketFreqItem = session.read.table("PRUNED_BASKET_3_0")

val transformationUDF = udf((pBasket: Array[String], R: String) => {
 val v: List[Int] = pBasket.map(_.toInt).sorted.toList

 val accumulator = new Accumulator[Int]
 val nonRepeatingCombPull = new NonRepeatingCombPull[Int](v, R.toInt)
 PermComb.constructPushGenerator(accumulator.push, nonRepeatingCombPull.iterator)

 accumulator.listAccu.map(_.mkString("_")).toArray
})

val tmp = prunedBasketFreqItem.withColumn("combination",
 transformationUDF(col("PRUNED_BASKET"), lit("3") ))

val tmp_clean = tmp.filter(array_size(col("COMBINATION")) > 0)
tmp_clean.write.mode(SaveMode.Overwrite).saveAsTable("PRUNED_BASKET_3_1")
				
			

The output from the above code is shown in Figure 14, with a new column combination in which each row contains the level-three combinations.


Figure 14: Shows a sample from pruned_basket_3_1 in the Snowflake database.

Step 6

Filter or select frequent itemsets of level n that satisfy required support level.

The UDF takes both the freq_itemsets and combination attributes (see Figure 14). Both contain lists of underscore-delimited itemsets, which we need to iterate over and split before computing the intersection: a level-three combination is kept only if all of its level-two subsets appear in freq_itemsets. We remove any empty lists and save the outcome to the Snowflake table pruned_basket_3_2.

				
val tmp = session.read.table("PRUNED_BASKET_3_1")

val transformation2UDF = udf((fItem: Array[String], comb: Array[String], R: String) => {
 val combTmp: Array[String] = comb.map(_.split("_").map(_.toInt).sorted.mkString("_"))
 val fItemTmp: Array[String] = fItem.map(_.split("_").map(_.toInt).sorted.mkString("_"))

 val res:Array[String] = combTmp.filter(x => {
   val xx: List[Int] = x.split("_").map(_.toInt).sorted.toList

   val accumulator = new Accumulator[Int]
   val nonRepeatingCombPull = new NonRepeatingCombPull[Int](xx, R.toInt)
   PermComb.constructPushGenerator(accumulator.push, nonRepeatingCombPull.iterator)

   val combinationList: ArrayBuffer[String] = accumulator.listAccu.map(_.mkString("_")).to[ArrayBuffer]
   combinationList.intersect(fItemTmp).length.==(R.toInt+1)
 })
 res
})

val tmp2 = tmp.withColumn("FREQ_ITEMSETS_FILTER",
 transformation2UDF(col("FREQ_ITEMSETS"), col("COMBINATION"), lit("2")))

val pruned_basket_level3 = tmp2.filter(array_size(col("FREQ_ITEMSETS_FILTER")) > 0)
pruned_basket_level3.write.mode(SaveMode.Overwrite).saveAsTable("PRUNED_BASKET_3_2")
				
			

We look at Figure 15 to see the result of the freq_itemsets_filter attribute with the third-level combinations. This process repeats for every level n. The above implementation could be optimized further; however, we have intentionally kept it expanded in favor of interpretability.

One improvement would be to discard combination pairs as they are created, rather than first generating all the combinations and then pruning away the non-intersecting itemsets.


Figure 15: Shows a sample from pruned_basket_3_2 in the Snowflake database.
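Taking a step back, the per-level steps above generalize into a single level-wise loop. The sketch below shows that loop in plain, in-memory Scala (not Snowpark) on a toy basket list, purely to make the control flow explicit; the baskets and minimum support are illustrative:

// In-memory sketch of the level-wise Apriori loop that this post performs table-by-table in Snowpark.
val baskets: Seq[Set[Int]] = Seq(Set(1, 2, 3), Set(1, 2), Set(2, 3), Set(1, 2, 3, 4))
val minSupport = 2

// Level 1: frequent individual items and their support counts.
var frequent: Map[Set[Int], Int] = baskets
  .flatMap(_.map(item => Set(item)))
  .groupBy(identity)
  .map { case (itemset, occurrences) => itemset -> occurrences.size }
  .filter { case (_, support) => support >= minSupport }

var current = frequent.keySet
var level = 2
while (current.nonEmpty) {
  // Build level-k candidates by merging level-(k-1) itemsets, then prune with downward closure:
  // a candidate survives only if every one of its (k-1)-subsets is already frequent.
  val candidates = current
    .flatMap(a => current.map(b => a ++ b))
    .filter(c => c.size == level && c.subsets(level - 1).forall(frequent.contains))

  // Count actual support by scanning the baskets and keep only candidates above the threshold.
  val counted = candidates
    .map(c => c -> baskets.count(basket => c.subsetOf(basket)))
    .toMap
    .filter { case (_, support) => support >= minSupport }

  frequent ++= counted
  current = counted.keySet
  level += 1
}
// frequent now maps every frequent itemset, across all levels, to its support.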

Finally, we read table pruned_basket_3_2 from above and then put it in a form that aligns with table freq_itemsets_table that we created in step 2.

				
val pruned_basket:DataFrame = session.read.table("PRUNED_BASKET_3_2")
 .flatten(col("FREQ_ITEMSETS_FILTER"))
 .select( trim(col("VALUE"), lit("\"")).as("ITEMSET"))
 .withColumn("feq", lit(2))
 .groupBy("ITEMSET").count()
 .withColumn("SIZE", lit(3))
 .select(col("ITEMSET"), col("SIZE"),
   col("COUNT").as("SUPPORT"))

val freqItems:DataFrame = freqItemSets.union(pruned_basket)
				
			

To help understand the strength of an association, we need to compute its confidence. This measure indicates how often the rule has been found to be true: of the transactions that contain the antecedent X, the fraction that also contain Y. The first part of Equation 2 is to count the transactions containing both X and Y, and the code below computes exactly that. With a crossJoin between levels two and three, we can compute the array_intersection between the superset_array and itemset_array attributes, then compare the size of the intersection with the size of itemset_array.

				
val a = freqItems.filter(col("SIZE") === lit("2")).select(Seq("ITEMSET", "SUPPORT"))
val b = freqItems.filter(col("SIZE") === lit("3")).select(Seq("ITEMSET", "SUPPORT"))

val A = a.withColumn("itemset_array", split(col("ITEMSET"), lit("_")))
 .withColumn("itemset_support", col("SUPPORT"))

val B = b.withColumn("superset_array", split(col("ITEMSET"), lit("_")))
 .withColumn("superset_support", col("SUPPORT"))

val fin = A.crossJoin(B)
val fin_out = fin.where((array_size(array_intersection(fin.col("SUPERSET_ARRAY"), fin.col("ITEMSET_ARRAY"))) - array_size(fin.col("ITEMSET_ARRAY"))) === 0 )
				
			

The below output is a sample retrieved from the above code: 

				
-------------------------------------------------------------------------------
|"ITEMSET_ARRAY"  |"ITEMSET_SUPPORT"  |"SUPERSET_ARRAY"  |"SUPERSET_SUPPORT"  |
-------------------------------------------------------------------------------
|[                |161                |[                 |98                  |
|  "365",         |                   |  "365",          |                    |
|  "934"          |                   |  "934",          |                    |
|]                |                   |  "1037"          |                    |
|                 |                   |]                 |                    |
|[                |161                |[                 |51                  |
|  "365",         |                   |  "365",          |                    |
|  "934"          |                   |  "934",          |                    |
|]                |                   |  "2476"          |                    |
|                 |                   |]                 |                    |
|[                |161                |[                 |51                  |
|  "365",         |                   |  "365",          |                    |
|  "934"          |                   |  "934",          |                    |
|]                |                   |  "1281"          |                    |
|                 |                   |]                 |                    |
|[                |161                |[                 |47                  |
|  "365",         |                   |  "365",          |                    |
|  "934"          |                   |  "934",          |                    |
|]                |                   |  "1986"          |                    |
|                 |                   |]                 |                    |
|[                |161                |[                 |60                  |
|  "365",         |                   |  "103",          |                    |
|  "934"          |                   |  "365",          |                    |
|]                |                   |  "934"           |                    |
|                 |                   |]                 |                    |
-------------------------------------------------------------------------------
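Reading the first row of this sample: the itemset [365, 934] appears in 161 baskets while the superset [365, 934, 1037] appears in 98, so the confidence of the rule {365, 934} ⇒ {1037} works out to 98 / 161 ≈ 0.609 once we divide superset_support by itemset_support in the next step.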
				
			

Next, we compute the confidence by dividing superset_support by itemset_support:

				
val assoDf = fin_out.withColumn("confidence", col("SUPERSET_SUPPORT")/col("ITEMSET_SUPPORT"))
 .select(array_to_string(col("SUPERSET_ARRAY"), lit("_")).as("superset"),
   array_to_string(col("ITEMSET_ARRAY"), lit("_") ).as("itemset"), col("confidence"))
 .sort(col("CONFIDENCE").desc)
				
			

The above code yields the following output:

				
---------------------------------------------
|"SUPERSET"      |"ITEMSET"  |"CONFIDENCE"  |
---------------------------------------------
|934_1207_3559   |934_1207   |0.830357      |
|1207_1752_3559  |1207_1752  |0.798561      |
|1207_1226_3559  |1207_1226  |0.789588      |
|836_1207_3559   |836_1207   |0.785714      |
|1207_1467_3559  |1207_1467  |0.775148      |
---------------------------------------------
				
			

Find All Possible Association Rules

Once all frequent itemsets are found, you can compute the association rules. But before we dive into the association rule formulation, we want to clarify one thing:

Some text in a square, dotted line that presents a possible association rule

The above figure (see the full list in Figure 16) presents a possible association rule.

It consists of an itemset (the antecedent) and a consequent_itemset (the consequent), which together contain all the items in the superset. The table in Figure 16 also contains a confidence column that helps us understand the strength of the association between the two. The confidence is 83.04%, which indicates the conditional probability of jumbo bag red retrospot occurring given {paper chain kit 50’s christmas, jumbo bag pink polkadot}. With this information, the store owner can place the different types of bags (and even paper chain kits) together to increase sales.

To generate all these rules, we use the code below with the help of a UDF. The function takes in both the superset and the itemset and returns the items in the superset after removing all the items in the itemset.

				
val assoRuleUDF = udf((superSetStr: String, itemSetStr: String) => {
 val superSetList = superSetStr.split("_").toList.map(x=>x.toInt)
 val itemSetList = itemSetStr.split("_").toList.map(x=>x.toInt)
 val superSet = superSetList.sortWith(_ < _).toSet
 val itemSet = itemSetList.sortWith(_<_).toSet

 val diffList = superSet.diff(itemSet).toList

 diffList.mkString("_")
})

val assoRuleDF = assoDf.withColumn("CONSEQUENT_ITEMSET",
 assoRuleUDF(col("SUPERSET"), col("ITEMSET")))

assoRuleDF.sort(Seq(col("CONFIDENCE").desc )).write.mode(SaveMode.Overwrite).saveAsTable("AssoRuleResult_2_3")
				
			

Figure 16 shows the output from the above code. The table includes a confidence column that indicates the likelihood of the consequent (consequent_itemset) appearing in the cart, given that the cart already contains the antecedent items (itemset) from the superset.


Figure 16: Shows a sample of the possible association rules.

Summary

In this post, we have learned how to perform “Market Basket Analysis” in Snowpark, how to implement the Apriori algorithm for association rule mining, and how to interpret the results.

The motivation behind this post is to demonstrate the full implementation using Snowpark and to show how we can efficiently handle thousands of different products. Generally, just 10 products can generate 57,000 rules, a number which increases exponentially with the increase in the number of items. We have shown that we can filter out infrequent items to reduce the number of items to generate combinations, thus improving the efficiency.  
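The roughly 57,000 figure follows from the standard count of possible association rules over d items:

\[
R = 3^{d} - 2^{d+1} + 1
\]

For d = 10, this gives 3^10 - 2^11 + 1 = 59,049 - 2,048 + 1 = 57,002 rules, and the count grows exponentially as more items are added.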

If your team is interested in learning more about ML models running in Snowflake, feel free to reach out to the phData ML team. We’re here to help! 

Special thanks to Mandar Kale for his contributions to this post!
