Create a Shiny app in R, using the bslib package, for market basket analysis.
Drop files here
or click to upload
import streamlit as st
import pandas as pd
import numpy as np
from mlxtend.frequent_patterns import apriori, association_rules
from mlxtend.preprocessing import TransactionEncoder
import plotly.graph_objects as go
import plotly.express as px
# Page chrome: browser-tab title, wide layout, then the main heading.
st.set_page_config(layout="wide", page_title="Market Basket Analysis")
st.title("Market Basket Analysis Tool")
# Sample data generation
def generate_sample_data(n_transactions=1000, max_items=5, seed=42):
    """Build a reproducible synthetic set of grocery transactions.

    Each transaction holds between 1 and ``max_items`` distinct products drawn
    from a fixed catalogue of ten grocery items.

    Args:
        n_transactions: Number of baskets to generate (default 1000).
        max_items: Largest basket size (default 5); capped at the catalogue size.
        seed: Seed for a private random generator, so results are reproducible.

    Returns:
        A ``(transactions, products)`` tuple: ``transactions`` is a list of
        lists of product names (plain ``str``), ``products`` is the catalogue.

    Raises:
        ValueError: If ``max_items`` is less than 1.
    """
    products = ['Bread', 'Milk', 'Cheese', 'Eggs', 'Yogurt', 'Butter', 'Juice', 'Coffee', 'Tea', 'Sugar']
    if max_items < 1:
        raise ValueError("max_items must be at least 1")
    max_items = min(max_items, len(products))
    # A private RandomState seeded with the same value reproduces the baskets
    # of the former ``np.random.seed(42)`` call without resetting NumPy's
    # global RNG for the rest of the process.
    rng = np.random.RandomState(seed)
    transactions = []
    for _ in range(n_transactions):
        n_items = rng.randint(1, max_items + 1)  # upper bound is exclusive
        basket = rng.choice(products, size=n_items, replace=False)
        # Plain str instead of numpy.str_ for cleaner downstream display.
        transactions.append([str(item) for item in basket])
    return transactions, products
# Main function for market basket analysis
def perform_mba(transactions, min_support, min_confidence):
    """Mine frequent itemsets and association rules from transactions.

    Args:
        transactions: List of baskets, each a list of item names.
        min_support: Minimum support (fraction of baskets) an itemset needs
            to be kept by Apriori.
        min_confidence: Minimum confidence a rule needs to be kept.

    Returns:
        ``(frequent_itemsets, rules)`` DataFrames from mlxtend's ``apriori``
        and ``association_rules``. When there are no transactions, or no
        itemset reaches ``min_support``, empty frames are returned instead;
        the empty ``rules`` frame still has the ``antecedents``,
        ``consequents``, ``support``, ``confidence`` and ``lift`` columns so
        callers can check ``rules.empty`` rather than crash.
    """
    # Columns callers read from the rules frame; used for the empty result.
    rule_columns = ["antecedents", "consequents", "support", "confidence", "lift"]
    if not transactions:
        return pd.DataFrame(columns=["support", "itemsets"]), pd.DataFrame(columns=rule_columns)

    # One-hot encode: one boolean column per distinct item, one row per basket.
    te = TransactionEncoder()
    te_ary = te.fit(transactions).transform(transactions)
    basket_df = pd.DataFrame(te_ary, columns=te.columns_)

    # Generate frequent itemsets
    frequent_itemsets = apriori(basket_df, min_support=min_support, use_colnames=True)
    if frequent_itemsets.empty:
        # mlxtend's association_rules rejects an empty itemset frame, so
        # short-circuit with an empty (but correctly shaped) rules frame.
        return frequent_itemsets, pd.DataFrame(columns=rule_columns)

    # Generate association rules
    rules = association_rules(frequent_itemsets, metric="confidence", min_threshold=min_confidence)
    return frequent_itemsets, rules
# Sidebar controls
st.sidebar.header("Parameters")
min_support = st.sidebar.slider("Minimum Support", 0.01, 1.0, 0.02, 0.01)
min_confidence = st.sidebar.slider("Minimum Confidence", 0.01, 1.0, 0.2, 0.01)

# Generate or upload data
data_option = st.sidebar.radio("Choose Data Source", ["Use Sample Data", "Upload Data"])
if data_option == "Use Sample Data":
    transactions, products = generate_sample_data()
else:
    uploaded_file = st.sidebar.file_uploader("Upload CSV file", type=['csv'])
    if uploaded_file is None:
        st.warning("Please upload a CSV file or select 'Use Sample Data'")
        st.stop()

    # Expected layout: one transaction per row, one item per cell; pandas pads
    # rows shorter than the header with NaN.
    # NOTE(review): read_csv treats the first row as a header, so a header-less
    # basket file loses its first transaction — confirm the intended format.
    try:
        # dtype=str keeps item codes such as "101" from turning into 101.0.
        df = pd.read_csv(uploaded_file, dtype=str)
    except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError) as exc:
        st.error(f"Could not read the uploaded CSV: {exc}")
        st.stop()

    # Drop NaN padding and blank cells: otherwise TransactionEncoder counts NaN
    # as an item and fails when sorting a mix of str and float values.
    transactions = []
    for row in df.values.tolist():
        basket = [str(item).strip() for item in row if pd.notna(item)]
        basket = [item for item in basket if item]
        if basket:
            transactions.append(basket)
    if not transactions:
        st.warning("The uploaded CSV does not contain any items.")
        st.stop()
    products = sorted({item for basket in transactions for item in basket})

# Perform market basket analysis
frequent_itemsets, rules = perform_mba(transactions, min_support, min_confidence)
# Display results in tabs
tab1, tab2, tab3 = st.tabs(["Frequent Itemsets", "Association Rules", "Visualizations"])


def _format_itemset(items):
    """Render an itemset (an iterable of items) as a sorted, comma-separated label."""
    return ", ".join(sorted(str(item) for item in items))


# Display copies with readable labels. The frozenset cells mlxtend produces are
# hard to read in tables and are not plain JSON types for Plotly hover data.
# ``rules`` itself is left untouched for the network graph.
itemsets_view = frequent_itemsets.assign(
    itemsets=frequent_itemsets["itemsets"].apply(_format_itemset)
).sort_values("support", ascending=False)
rules_view = rules.assign(
    antecedents=rules["antecedents"].apply(_format_itemset),
    consequents=rules["consequents"].apply(_format_itemset),
)

with tab1:
    st.header("Frequent Itemsets")
    st.dataframe(itemsets_view)

with tab2:
    st.header("Association Rules")
    st.dataframe(rules_view)

with tab3:
    st.header("Visualizations")

    if rules_view.empty:
        st.info("No association rules meet the current thresholds; "
                "try lowering the minimum support or confidence.")
    else:
        # Scatter plot of support vs confidence
        fig1 = px.scatter(rules_view, x="support", y="confidence",
                          size="lift", color="lift",
                          hover_data=["antecedents", "consequents"],
                          title="Support vs Confidence (size and color represent Lift)")
        st.plotly_chart(fig1)

    if itemsets_view.empty:
        st.info("No itemsets reach the minimum support.")
    else:
        # Bar chart of top 10 frequent itemsets, labelled by their items
        top_itemsets = itemsets_view.nlargest(10, 'support')
        fig2 = px.bar(top_itemsets, x="itemsets", y='support',
                      title="Top 10 Frequent Itemsets by Support")
        st.plotly_chart(fig2)
# Network graph of association rules
def create_network_graph(rules, threshold):
    """Draw association rules as a network of items.

    Every item appearing in a rule with ``lift >= threshold`` becomes a node.
    Each (antecedent item, consequent item) pair of such a rule becomes an
    edge; duplicate pairs from different rules are drawn once. Nodes are
    spaced evenly on a circle and coloured by the number of distinct edges
    touching them.

    Args:
        rules: Association-rules DataFrame with iterable ``antecedents`` /
            ``consequents`` columns and a numeric ``lift`` column.
        threshold: Minimum lift (inclusive) for a rule to be drawn.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    # Inclusive, so the top position of a [min lift, max lift] slider still
    # shows the strongest rule(s).
    filtered_rules = rules[rules['lift'] >= threshold]

    edge_set = set()
    for antecedents, consequents in zip(filtered_rules['antecedents'], filtered_rules['consequents']):
        for ant in antecedents:
            for cons in consequents:
                edge_set.add((str(ant), str(cons)))
    edges = sorted(edge_set)  # deterministic drawing order

    # Every rule has at least one antecedent and one consequent, so the edge
    # endpoints cover every item of every drawn rule.
    nodes = sorted({node for edge in edges for node in edge})

    # Circular layout; edges below use the same coordinates as the markers.
    angles = np.linspace(0.0, 2.0 * np.pi, num=len(nodes), endpoint=False)
    position = {node: (float(np.cos(a)), float(np.sin(a))) for node, a in zip(nodes, angles)}

    degree = dict.fromkeys(nodes, 0)
    edge_x = []
    edge_y = []
    for ant, cons in edges:
        (x0, y0), (x1, y1) = position[ant], position[cons]
        # None breaks the polyline so each edge is a separate segment.
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])
        degree[ant] += 1
        degree[cons] += 1

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=0.5, color='#888'),
        hoverinfo='none',
        mode='lines')

    connections = [degree[node] for node in nodes]
    node_trace = go.Scatter(
        x=[position[node][0] for node in nodes],
        y=[position[node][1] for node in nodes],
        mode='markers+text',
        hoverinfo='text',
        text=nodes,
        hovertext=[f"{node}: {count} connection(s)" for node, count in zip(nodes, connections)],
        textposition="top center",
        marker=dict(
            showscale=True,
            colorscale='YlOrRd',
            color=connections,
            size=20,
            colorbar=dict(
                thickness=15,
                # colorbar.title.side replaces the removed ``titleside`` property.
                title=dict(text='Node Connections', side='right'),
                xanchor='left',
            ),
        ))

    hidden_axis = dict(showgrid=False, zeroline=False, showticklabels=False)
    fig = go.Figure(data=[edge_trace, node_trace],
                    layout=go.Layout(
                        showlegend=False,
                        hovermode='closest',
                        margin=dict(b=20, l=5, r=5, t=40),
                        title="Network Graph of Association Rules",
                        xaxis=hidden_axis,
                        yaxis=dict(hidden_axis, scaleanchor="x"),
                    ))
    return fig
# Re-enter the Visualizations tab so the network graph renders below the
# other charts there.
with tab3:
    if rules.empty:
        # min/max/mean of an empty lift column are NaN, which the slider rejects.
        st.info("No association rules to draw in the network graph.")
    else:
        lift_min = float(rules['lift'].min())
        lift_max = float(rules['lift'].max())
        if lift_min < lift_max:
            lift_threshold = st.slider("Lift Threshold for Network Graph",
                                       min_value=lift_min,
                                       max_value=lift_max,
                                       value=float(rules['lift'].mean()))
        else:
            # All rules share one lift value: a slider would have an empty range.
            lift_threshold = lift_min
        network_fig = create_network_graph(rules, lift_threshold)
        st.plotly_chart(network_fig)
Hi! I can help you with any questions about Streamlit and Python. What would you like to know?