How to Use Snowpark With Hex For Machine Learning

In recent years, machine learning has advanced rapidly, and organizations are seeking new ways to gain insights from their data. Snowpark, a developer framework from Snowflake with open-source client libraries, enables users to write code in their preferred programming language (Python, Java, or Scala) and run it inside the Snowflake Data Cloud.

When combined with Hex, a data notebooking platform, Snowpark provides an efficient and flexible way to build machine learning models and share the resulting data-driven insights. In this blog post, we explore the capabilities of Snowpark with Hex and show how to use these tools for your own machine learning projects, using a sample dataset available in Snowflake.

Why Hex & Snowpark?

Notebooking with Jupyter has advantages and disadvantages. Although Jupyter is incredibly popular, it has a few shortcomings when working with data. External databases are not easy to connect to natively, and Snowpark-compatible environments have to be built and maintained from scratch. There is also little built-in support for versioning or collaboration, not to mention the opaque, hidden global variable state.

Enter Hex.

What is Hex?

Hex is a powerful and flexible notebooking environment with a ready-built Snowpark Python kernel. Hex also provides a simple connector to the Snowflake Data Cloud, which makes it easy to analyze data, prototype, and deploy data logic that runs on Snowflake.

Hex notebooks are hosted, can be developed collaboratively, and integrate with GitHub for version tracking. Even better, Hex automatically tracks each cell's inputs and outputs and builds an execution DAG (directed acyclic graph) of the notebook. This makes it easy to rerun upstream or downstream cells.
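The idea behind an execution DAG can be sketched in plain Python. This is a toy model under our own assumptions, not Hex's implementation. Each hypothetical cell declares the variables it reads and writes. A cell depends on whichever cell writes a variable it reads, and editing one cell means rerunning it plus everything downstream of it, in topological order. The sketch uses the standard-library `graphlib` module (Python 3.9+).

```python
from graphlib import TopologicalSorter

# Hypothetical notebook: cell name -> (variables it reads, variables it writes).
CELLS = {
    "load": (set(), {"df"}),
    "team_input": (set(), {"team"}),
    "filter": ({"df", "team"}, {"team_df"}),
    "chart": ({"team_df"}, set()),
    "summary": ({"df"}, {"stats"}),
}


def build_dag(cells):
    """Map each cell to the cells it depends on (those that write what it reads)."""
    writer = {var: name for name, (_, writes) in cells.items() for var in writes}
    return {
        name: {writer[var] for var in reads if var in writer}
        for name, (reads, _) in cells.items()
    }


def cells_to_rerun(cells, changed):
    """Return `changed` plus every cell downstream of it, in a valid run order."""
    deps = build_dag(cells)
    dirty = {changed}
    grew = True
    while grew:  # keep adding cells that depend on something already dirty
        grew = False
        for name, parents in deps.items():
            if name not in dirty and parents & dirty:
                dirty.add(name)
                grew = True
    order = TopologicalSorter(deps).static_order()
    return [name for name in order if name in dirty]


print(cells_to_rerun(CELLS, "team_input"))  # ['team_input', 'filter', 'chart']
```

A real notebook tool would infer reads and writes by analyzing each cell's code. Here they are declared by hand to keep the sketch short.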

Now that we have a better understanding of what Hex is, let’s explore how to connect it to Snowflake!

How to Connect Hex to Snowflake

Connecting Hex to Snowflake is a straightforward process. In Hex, open the Data connections tab in the left sidebar.


For Snowflake connections, you’ll need:

  • Account Name
  • Warehouse
  • Database
  • Username & Password

Be sure to enable the Snowpark integration. Also enable Writeback if you'd like to push data from the notebook back to the warehouse.
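The code later in this post uses a Snowpark session named `hex_snowpark_session`, obtained through the Hex connection. To reproduce the same setup outside Hex, the connection fields listed above map onto the connection-parameter dictionary that Snowpark's `Session.builder.configs(...)` accepts. This is a minimal sketch. The environment-variable names (`SNOWFLAKE_ACCOUNT` and so on) are our own hypothetical convention, not something Hex or Snowflake requires.

```python
import os

# Hypothetical environment-variable names for the fields in the Hex form.
ENV_KEYS = {
    "account": "SNOWFLAKE_ACCOUNT",
    "user": "SNOWFLAKE_USER",
    "password": "SNOWFLAKE_PASSWORD",
    "warehouse": "SNOWFLAKE_WAREHOUSE",
    "database": "SNOWFLAKE_DATABASE",
}


def connection_parameters(env=None):
    """Build the dict that Snowpark's Session.builder.configs() expects."""
    env = os.environ if env is None else env
    missing = [var for var in ENV_KEYS.values() if not env.get(var)]
    if missing:
        raise ValueError(f"missing Snowflake settings: {', '.join(missing)}")
    return {key: env[var] for key, var in ENV_KEYS.items()}


def create_session(params):
    """Open a Snowpark session; needs the snowflake-snowpark-python package."""
    from snowflake.snowpark import Session  # imported lazily on purpose

    return Session.builder.configs(params).create()
```

Calling `create_session(connection_parameters())` opens a session against the same account, warehouse, and database as the Hex form. Optional keys such as `"schema"` and `"role"` can be added to the dictionary in the same way.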

Accessing Data for Training

The databases available to you depend on the credentials used in the data connection. To get started, you can use the sample data included with Snowflake or add databases from the Snowflake Marketplace. Some are freely available and fun to play around with, like the ThoughtSpot Fantasy Football Dataset.


Once databases are added to your Snowflake account, they can be explored in Hex with the Data sources tab.


Exploratory Data Analysis with Hex and Snowpark

Using the Snowpark DataFrame API, we can quickly explore the data. Here we're exploring the play-by-play table from the Fantasy Football dataset. All of these commands are translated into SQL and pushed down to the Snowflake warehouse. Each filter call adds to a SQL statement that is executed only when .show() is called. This is fast and efficient, because only the rows needed for display (10 by default) are returned to the notebook.
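To make the pushdown behavior concrete, here is a deliberately simplified toy model of lazy evaluation in plain Python. It is not Snowpark's implementation. The table and column names (`PLAY_BY_PLAY`, `TEAM`, `PLAY_TYPE`) are hypothetical placeholders. Each `filter` call only records a predicate and returns a new object. SQL is produced only when `show()` is called, with a `LIMIT` matching the number of rows displayed.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LazyFrame:
    """Toy lazily evaluated table; NOT Snowpark's DataFrame implementation."""

    table: str
    predicates: tuple = ()

    def filter(self, predicate: str) -> "LazyFrame":
        # No query runs here: return a new frame with one more WHERE condition.
        return LazyFrame(self.table, self.predicates + (predicate,))

    def to_sql(self, limit: int) -> str:
        sql = f"SELECT * FROM {self.table}"
        if self.predicates:
            sql += " WHERE " + " AND ".join(f"({p})" for p in self.predicates)
        return f"{sql} LIMIT {limit}"

    def show(self, n: int = 10) -> str:
        # The only step that would send a query to the warehouse.
        sql = self.to_sql(n)
        print(sql)
        return sql


plays = LazyFrame("PLAY_BY_PLAY").filter("TEAM = 'KC'").filter("PLAY_TYPE = 'pass'")
plays.show()  # SELECT * FROM PLAY_BY_PLAY WHERE (TEAM = 'KC') AND (PLAY_TYPE = 'pass') LIMIT 10
```

In Snowpark itself, the generated SQL for a real DataFrame can be inspected with its `queries` property or its `explain()` method before anything is shown.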


Leveraging Hex, we can make these explorations dynamic with no-code input cells. These cells allow us to easily change inputs to charts and avoid repeating display code unnecessarily. For example, let’s plot the touchdowns vs. touchdowns allowed calculated from the Fantasy Football tables.


Instead of hardcoding our choice of offense and defense stat, we can make it dynamic with Hex and browse any combination we’re curious about.
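In Hex, an input cell's value is exposed as a variable that downstream cells can read, so the charting logic is written once against whichever stat is selected. Here is a minimal sketch of that pattern. `offense_stat` and `defense_stat` stand in for the values of two hypothetical dropdown cells. The per-team numbers are made-up placeholders, not figures from the dataset.

```python
# Hypothetical values bound by two Hex dropdown input cells.
offense_stat = "TOUCHDOWNS"
defense_stat = "TOUCHDOWNS_ALLOWED"

# Placeholder per-team totals; made-up numbers, not taken from the dataset.
TEAM_STATS = [
    {"TEAM": "AAA", "TOUCHDOWNS": 50, "TOUCHDOWNS_ALLOWED": 30, "YARDS_GAINED": 6100},
    {"TEAM": "BBB", "TOUCHDOWNS": 35, "TOUCHDOWNS_ALLOWED": 45, "YARDS_GAINED": 5200},
    {"TEAM": "CCC", "TOUCHDOWNS": 42, "TOUCHDOWNS_ALLOWED": 38, "YARDS_GAINED": 5750},
]


def stat_pairs(rows, x_stat, y_stat):
    """Return (team, x, y) points for whichever two stats the inputs select."""
    for stat in (x_stat, y_stat):
        if stat not in rows[0]:
            raise KeyError(f"unknown stat: {stat}")
    return [(row["TEAM"], row[x_stat], row[y_stat]) for row in rows]


# A chart cell would plot `points`; changing either dropdown reruns the cells that read it.
points = stat_pairs(TEAM_STATS, offense_stat, defense_stat)
print(points)  # [('AAA', 50, 30), ('BBB', 35, 45), ('CCC', 42, 38)]
```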


We can also build charts without writing any Python. With Hex's chart cell, building a chart is straightforward. Let's make a bar chart of TOUCHDOWNS_ALLOWED by team, sorted by value.


Training ML With Snowpark Via Hex

Once your data has been explored and transformed into an ML-ready shape, you can train an ML model using a stored procedure. It’s possible to train ML models using the Snowpark UDF API, but it’s a more niche use case. 

To learn more about it, check out this blog on CPG forecasting with Snowpark.

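If you’re not set up to train a model with a stored procedure (maybe you’d like to use a GPU), no worries. A model trained elsewhere can still be uploaded to a Snowflake stage and served with the inference UDF described below.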

We’re going to train an XGBoost model to predict the point spread of NFL games in our dataset. Stored procedures are declared by decorating a Python function with @sproc, and within that function we can use any of our favorite data science packages. Note that we first create a stage in Snowflake to hold the trained model for inference later.

				
from snowflake.snowpark import Session
from snowflake.snowpark.functions import sproc
from snowflake.snowpark.types import Variant

# Work in the database/schema that holds the NFL feature table
hex_snowpark_session.sql("USE DATABASE USER_DB").show()
hex_snowpark_session.sql("USE SCHEMA USER_DB.NFL").show()
# Stage that stores the stored procedure code and, later, the trained model
hex_snowpark_session.sql("CREATE OR REPLACE STAGE model_stage").show()

# Packages the stored procedure needs on the Snowflake side
hex_snowpark_session.add_packages(
    "snowflake-snowpark-python", "scikit-learn", "joblib", "xgboost", "numpy"
)


@sproc(
    session=hex_snowpark_session,
    name="train_xgb_nfl",
    is_permanent=True,
    stage_location="@model_stage",
    replace=True,
)
def train_xgb_nfl(session: Session, features_table: str) -> Variant:
    # Imports inside the function run when the procedure executes in Snowflake
    import os

    import numpy as np
    import xgboost as xgb
    from joblib import dump
    from sklearn.metrics import mean_squared_error as MSE
    from sklearn.metrics import r2_score
    from sklearn.model_selection import train_test_split

    df_in = session.table(features_table)
    # …train the model…


The rest of the function constructs an XGBoost regressor and trains it. Once training finishes, the model is serialized with joblib and uploaded to the stage we created earlier.

				
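# Inside train_xgb_nfl, after `model` has been trained:
# write the model to local /tmp, then upload it to the stage.
# PUT gzips files by default, so it is stored as model.joblib.gz.
model_file = os.path.join("/tmp", "model.joblib")
dump(model, model_file)
session.file.put(model_file, "@model_stage", overwrite=True)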
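Snowflake's PUT (and therefore `session.file.put`, whose `auto_compress` defaults to true) gzips uploaded files, which is why the inference UDF imports `model.joblib.gz` rather than `model.joblib`. `joblib.load` detects the compression on its own. The sketch below reproduces that upload-and-load round trip locally with the standard library only. It uses `pickle` and `gzip` as stand-ins for joblib and the PUT command, and `fake_put` and `load_staged` are hypothetical helpers, not Snowpark APIs.

```python
import gzip
import os
import pickle
import shutil
import tempfile


def fake_put(local_path, stage_dir):
    """Hypothetical helper: copy a file into stage_dir gzipped, as PUT does by default."""
    staged = os.path.join(stage_dir, os.path.basename(local_path) + ".gz")
    with open(local_path, "rb") as src, gzip.open(staged, "wb") as dst:
        shutil.copyfileobj(src, dst)
    return staged


def load_staged(path):
    """Unpickle a file, detecting gzip by its magic bytes."""
    with open(path, "rb") as f:
        is_gzip = f.read(2) == b"\x1f\x8b"
    opener = gzip.open if is_gzip else open
    with opener(path, "rb") as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    stage_dir = os.path.join(tmp, "model_stage")
    os.makedirs(stage_dir)
    model_file = os.path.join(tmp, "model.pkl")
    with open(model_file, "wb") as f:
        pickle.dump({"weights": [0.5, -1.25]}, f)  # placeholder "model"
    staged = fake_put(model_file, stage_dir)
    restored = load_staged(staged)

print(os.path.basename(staged), restored)  # model.pkl.gz {'weights': [0.5, -1.25]}
```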


				
			

ML Inference With Snowpark Via Hex

Now that we have a trained model saved to a Snowflake stage, we can use the Snowpark UDF API to register a vectorized (pandas) UDF that loads the model and produces a prediction for each row of data.

				
import os
import sys

import cachetools
import pandas as pd
from joblib import load
from snowflake.snowpark import functions as F
from snowflake.snowpark.types import FloatType, PandasSeriesType


@cachetools.cached(cache={})
def read_file(filename):
    # Files listed in `imports` are copied into this directory on the warehouse
    import_dir = sys._xoptions.get("snowflake_import_directory")
    if import_dir:
        return load(os.path.join(import_dir, os.path.basename(filename)))


def predict_xgb_nfl_pandas(data: pd.DataFrame) -> pd.Series:
    import xgboost  # needed so the pickled XGBoost model can be unpickled
    import pandas as pd

    model = read_file("model.joblib.gz")  # cached after the first call
    return pd.Series(model.predict(data))


# `inference_types` holds one input type per feature column (25 in the query
# below); it is assumed to be defined elsewhere in the notebook (not shown here).
predict_xgb_nfl_pandas = F.pandas_udf(
    predict_xgb_nfl_pandas,
    name="predict_xgb_nfl_pandas",
    replace=True,
    is_permanent=True,
    stage_location="@model_stage",
    input_types=inference_types,
    parallel=10,
    max_batch_size=1000,
    return_type=PandasSeriesType(FloatType()),
    session=hex_snowpark_session,
    # The model uploaded by train_xgb_nfl (gzipped by PUT)
    imports=["@model_stage/model.joblib.gz"],
    # UDF-level packages replace session-level ones, so joblib is listed too
    packages=["xgboost==1.5.0", "cachetools", "joblib"],
)


For extra performance, the model-loading function is decorated with cachetools. Repeated calls of the UDF within the same Python process on the warehouse therefore don't need to reload the model. Once the UDF is registered, it can be called from SQL just like any other function.

				
# .sql() only builds a lazy DataFrame; .show() actually runs the query
hex_snowpark_session.sql("""
select user_db.nfl.predict_xgb_nfl_pandas(
    "YARDS_GAINED",
    "YARDS_ALLOWED",
    "YARDS_GAINED_AWAY",
    "YARDS_ALLOWED_AWAY",
    "AWAY_MONEYLINE",
    "HOME_MONEYLINE",
    div0("TOUCHDOWNS", "TOUCHDOWNS_AWAY"),
    div0("YARDS_GAINED", "YARDS_ALLOWED"),
    div0("YARDS_GAINED_AWAY", "YARDS_ALLOWED_AWAY"),
    "SACKS",
    "FUMBLES_LOST",
    "FUMBLES_RECOVERED",
    "FUMBLES_LOST_AWAY",
    "FUMBLES_RECOVERED_AWAY",
    "SACKS_FORCED",
    "SACKS_AWAY",
    "SACKS_FORCED_AWAY",
    "TOUCHDOWNS",
    "TOUCHDOWNS_AWAY",
    "TOUCHDOWNS_ALLOWED",
    "TOUCHDOWNS_ALLOWED_AWAY",
    "INTERCEPTIONS",
    "INTERCEPTIONS_FORCED",
    "HOME_TEAM_WIN_PERCENTAGE",
    "AWAY_TEAM_WIN_PERCENTAGE"
) from user_db.nfl.nfl_data
""").show()

				
			

Conclusion

Snowpark enables best-practice development in a familiar environment with excellent performance. Using Hex, we can quickly connect to data, explore and analyze it, develop machine learning models, and then operationalize them for inference.

Want to learn more? We’re hitting the road with Snowflake and giving hands-on labs around the US in spring 2023. Stay tuned to phData’s LinkedIn for more updates.

Can’t wait? Check out these resources and reach out to our Data Science and ML team!
