Python: An analysis of a dataset from Kaggle
AbdulSamad Olagunju / July 15, 2021
14 min read
Tutorials
I won't lie. The first time I looked at some Python, it seemed daunting. As you know, Python is a programming language. There are many resources online that can help you gain experience with Python. I suggest starting here: PyForNeuro
Take your time working through more and more problems with code, and it will become easier over time just like any other skill.
So, back to what this article is actually about. We're going to take a look at a Kaggle dataset about medical insurance.
Let's begin!
Setting Yourself Up
I like to use Spyder to perform my analyses. Spyder is an application you can use to write and run your code. To get Spyder, I just downloaded Anaconda, which includes it. This video explains it pretty well for Windows 10 users: Install Anaconda Python, Jupyter Notebook And Spyder on Windows 10
So, the first thing I like to do is download the dataset I would like to look into. For now, I mainly use CSV files.
CSV stands for comma-separated values. A CSV file is a plain text file that stores your data as a table. You can open CSV files with Excel. To easily access CSV files in my code, I created a repository in my GitHub account that contains the CSV files I downloaded from Kaggle.
To create a repository on GitHub, first you must create an account on GitHub. It should be smooth sailing from there, and I don't think I need to go further into that. Email me if you have any problems.
Now, look through Kaggle and download the dataset you like (unfortunately, GitHub only allows you to upload CSV files that are less than 25 MB). If this is your first time using Kaggle, just click on the link below called Medical Insurance Kaggle Dataset so you can easily get started.
Here is the link to the dataset: Medical Insurance Kaggle Dataset
Now that you have Spyder and the CSV file you need in your repository, click on the CSV file on GitHub. Then, click Raw. Copy the URL, and we are ready to get to work.
This video explains getting the raw CSV URL if you want very detailed instructions: Create URL LINK for your DATA from CSV files
So, I will briefly explain how I will analyze this insurance data. First, I will write some Python code to describe the variables in the dataset (means, standard deviations, etc.). Then I will use graphs such as histograms, bar plots, and correlation matrices to see what the data reveals. To follow along easily, you should have Spyder open and ready to go on your computer.
Before I start analyzing the data, I need to import the libraries I will use. In programming, you don't reinvent the wheel every time you want to accomplish a task. You use the tools already available, and that is what I will do here.
Here is the first bit of code. In Spyder, I created a new file called insurance.py. A file that ends in .py contains Python code.
#import libraries needed
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from scipy import stats
import random
#don't want to see warnings on the console
#about older functions like distplot
import warnings
warnings.filterwarnings('ignore')
#read the CSV file from GitHub into a dataframe (df); use the raw link
df = pd.read_csv(".../csv_files/main/insurance.csv")
The first line is a comment (it starts with #). Comments are not run by the Python interpreter (the program that turns the English-like text we write on the screen into instructions the computer understands). It is good to use comments so the people who read your code understand what is going on.
The import statements allow me to use code others have already written in my own program. For example, I use pandas (shortened to pd in my code) to read the data from the CSV file. I use seaborn (shortened to sns) to generate most of my plots, and stats to do some hypothesis testing. You are free to choose which libraries you use, so feel free to google how to incorporate them into your code.
I create a variable called df (short for dataframe), and I read the CSV file containing insurance information from Kaggle into df. To get the URL, watch this video: Create URL LINK for your DATA from CSV files
You see how in my imports I used import pandas as pd? pandas is the library I use to read information from CSV files into my variable df (pd is just the short alias I give pandas with that import statement). pandas is a data analysis library that makes working with tabular data much easier. pd.read_csv() takes the URL (or file path) where your CSV file is located and reads it into your variable.
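For reference, here is a quick sketch of the two usual ways to point pd.read_csv at the data. The URL is only a placeholder; swap in the raw link from your own repository.
#1) the raw GitHub link you copied (placeholder shown, use your own):
#df = pd.read_csv("https://raw.githubusercontent.com/<your-username>/<your-repo>/main/insurance.csv")
#2) a local copy saved in the same folder as insurance.py:
#df = pd.read_csv("insurance.csv")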
Now that I have the data, I need to think about how I will analyze it. Firstly, I want to know the labels of the columns and the rows. I need to have an idea of what my data table looks like, and what it represents.
#let's see total number of (rows,columns)
print ("Number of Rows : " , df.shape[0])
print ("Number of Columns : " , df.shape[1])
#take a look at column names
print ("\nColumn Labels : \n", df.columns.tolist())
print("\n")
#let's quickly look at a summary of the dataset
print(df.info())
#lets use .nunique to see categorical and numerical data
print ("\nUnique values : \n", df.nunique())
#let's take a quick look at a total description of the dataset
print(df.describe(include="all"))
#take a look at first 5 rows of data
print(df.head())
- Now, using df.shape, we know there are 1338 rows and 7 columns.
- df.columns.tolist() gives us the names of the columns as a list.
- df.info() shows us a table that gives us more detail about the kind of data we have.
- df.nunique() shows the number of unique values for each column. This helps tell us whether a column holds a categorical or a numerical variable (e.g. smoker is probably categorical since it has 2 unique values, while bmi is numerical with 548). A small sketch of this split follows below.
- df.describe() gives us a very detailed table about the attributes of our data. df.head() shows us the first 5 rows of the data.
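If you want to double-check that split between categorical and numerical columns in code, here is a small sketch (not part of the main script) using pandas' select_dtypes:
#columns stored as text are the categorical ones here (sex, smoker, region)
categorical_cols = df.select_dtypes(include="object").columns.tolist()
#columns stored as numbers are the numerical ones (age, bmi, children, charges)
#note: children is stored as a number even though I plot it like a category later
numerical_cols = df.select_dtypes(include="number").columns.tolist()
print("Categorical columns:", categorical_cols)
print("Numerical columns:", numerical_cols)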
We know that the data contains information from people about their age, sex, insurance charges, bmi, etc. From this, we can start asking ourselves questions about how we would like to analyze the data. Questions like: Do people with higher bmi values have higher insurance charges? Do certain variables, like whether or not someone is a smoker, have a significant effect on their insurance charges? What kind of bmi makes somebody obese, and what effect does this have on their insurance charges?
We want to start thinking about these questions before we deal with the data. This will allow us to think about what kind of figures we would like to create from the data.
A Little Bit of Cleanup
#check for missing data
print("\nMissing Values: " )
print(df.isnull().sum())
#lets check for duplicate rows
number_dup_rows = len(df) - len(df.drop_duplicates(keep="first"))
if (number_dup_rows != 0):
    print(f"\nNumber of duplicate rows found : {number_dup_rows}\nDuplicates removed!")
    df.drop_duplicates(keep="first",inplace=True)
else:
    print("No duplicate entries found!")
Before you analyze your data, you need to make sure it is clean. You don't want any missing values or duplicate rows. This is very basic data pre-processing. With other datasets, you often need to clean them up much more before you can use them (taking care of outliers, etc.).
This code will take a look at whether or not we have any missing values, and whether we have duplicate rows. df.isnull().sum() will show if there are any missing values in any of the columns.
We then compare len(df) before and after df.drop_duplicates removes duplicate rows. Using this, we see that there is 1 duplicate row to remove.
It is important to clean up your data before you analyze it further, and the more complicated the dataset, the more cleanup it will probably require.
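This dataset happens to be clean, but as a rough sketch of the kind of extra cleanup I mean, here is what handling missing values and bmi outliers could look like (illustrative only, none of these steps are needed here):
#drop rows that have any missing value...
df_no_missing = df.dropna()
#...or fill numeric gaps with each column's median instead
df_filled = df.fillna(df.median(numeric_only=True))
#flag possible bmi outliers with a simple 1.5*IQR rule
q1, q3 = df["bmi"].quantile([0.25, 0.75])
iqr = q3 - q1
bmi_outliers = df[(df["bmi"] < q1 - 1.5 * iqr) | (df["bmi"] > q3 + 1.5 * iqr)]
print("Possible bmi outliers:", len(bmi_outliers))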
Let's Make Some Figures
In the following section, I will show you how you can create some simple figures to help you with your data analysis. There are many more tools, but I will keep the tools I use for this analysis very simple.
#take a deeper look at statistics for each age
#groupby shows mean of different variables, separating them by age
print(df.groupby("age").mean())
#which 20 people got charged the most?
print(df.sort_values("charges", ascending = False).head(20))
I wanted to show you the df.groupby operation. I use it to see the mean of the different variables (bmi, charges, etc.) at each age. It is a valuable operation. I also wanted to show you the sort_values operation. Combined with .head(20), it shows the 20 people who got charged the most. You can use both of these functions to get a better idea of certain subsets of your data.
You should definitely play around with these on your computer if you want to understand how you can apply them. Active learning and thinking about how you would solve different problems will really help you out in data analysis.
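As a small extra sketch of those same two operations (these lines are not in the main script), you can group by any column and aggregate however you like:
#mean charges for smokers vs non smokers
print(df.groupby("smoker")["charges"].mean())
#mean and median charges for each region
print(df.groupby("region")["charges"].agg(["mean", "median"]))
#the 5 people with the lowest charges
print(df.sort_values("charges", ascending=True).head())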
#lets make count plots for the categorical variables
for variable in ['sex', 'children', 'smoker', 'region']:
    #random.randint picks one of three colour palettes for each count plot
    sns.catplot(x=variable, data=df, kind="count",
                palette="Set" + str(random.randint(1,3)),
                height=3.5, aspect=2, ci="sd")
#lets make a pairplot to get some idea of what's going on in the data
#show scatter plots of all variables
plt.figure()
#positive correlation between charges and age? what are these insurance companies doing
sns.pairplot(df)
I use seaborn (sns) to create my plots. seaborn allows you to easily create figures using Python, and Spyder will display them on your screen. You can also use matplotlib.pyplot (plt) to make figures, but I prefer seaborn.
I want to create several count plots (plots that show the number of males vs. females, the number of people in each region, etc.), so I use a for loop. It goes through each of the variables sex, children, smoker, and region, and uses sns.catplot (which makes categorical plots) to generate a count plot for each one. I set what goes on the x-axis and what my dataset is, and I use random.randint to pick different colours for the bars. To learn more about functions like sns.catplot, take a look at the seaborn website.
I will also introduce you to sns.pairplot, as it will create a grid of plots between your different variables. This can help clarify what is going on in your data, and lead to further questions. For example, I see that charges and age have a positive correlation.
These plots are simple to make, and will give you a general idea about your data.
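If you want to put a number on the age and charges relationship the pairplot hints at, here is a quick sketch (not in the main script) using the scipy stats import from earlier:
#Pearson correlation between age and charges (r close to 1 means a strong positive relationship)
r_value, p_value = stats.pearsonr(df["age"], df["charges"])
print(f"Correlation between age and charges: r = {r_value:.2f}, p = {p_value:.3g}")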
plt.figure()
#why are insurance companies dealing mainly with young people?
sns.distplot(df["age"])
plt.figure()
sns.distplot(df["bmi"]) #nice normal distribution
plt.figure()
sns.distplot(df["charges"])
#lets take a look at the heatmap
plt.figure()
#so charges and age look to be the most correlated
sns.heatmap(df.corr(), annot=True, linewidths=.5, fmt= '.2f')
plt.title('Correlation Heatmap')
#let's make a scatter plot to take a closer look at the charges and age relationship
plt.figure()
sns.scatterplot(data = df, x = "age", y = "charges", hue = "bmi")
plt.title('Age and Charges Scatter Plot')
plt.figure()
#are smokers getting charged more
sns.boxplot(x="smoker", y="charges", data = df)
plt.figure()
#which region is getting charged the most, and where the outliers are
sns.boxplot(x="region", y="charges", data = df)
In the code above, you see how you can use sns.distplot. This will create a histogram with a density curve superimposed on top. sns.heatmap will show you a heatmap of your data's correlation matrix. You can also add titles and other attributes to your figures.
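One note: newer versions of seaborn have deprecated sns.distplot (that is why I filter the warnings at the top). If your version complains, sns.histplot is the replacement; a minimal sketch:
plt.figure()
#histogram of charges with the density curve drawn on top
sns.histplot(df["charges"], kde=True)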
sns.boxplot will create a boxplot, and sns.scatterplot will show you a scatter plot. To fully understand what you should input in these functions, you should look at the documentation page on seaborn.
In the code above, I use plt.figure() before each plot so that it opens as a new figure in Spyder. Otherwise, Spyder draws the plots on top of each other.
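If you would rather save a figure to a file than view it in Spyder's plots pane, matplotlib's savefig does that. A small sketch (the filename is just an example):
#save the most recent figure to disk; call this right after the plotting line
plt.savefig("region_charges_boxplot.png", dpi=150, bbox_inches="tight")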
#lets look at some comparisons using violin plots
#charges for male vs female
plt.figure()
sns.violinplot(data=df, x="sex", y="charges")
#charges for smoker vs non smoker
plt.figure()
sns.violinplot(data=df, x="smoker", y="charges")
#charges depending on number of kids
plt.figure()
sns.violinplot(data=df, x="children", y="charges")
#charges depending on location
plt.figure()
sns.violinplot(data=df, x="region", y="charges")
You can also make violin plots using sns.violinplot.
I showed you examples of violin plots, boxplots, scatterplots, and histograms with density curves superimposed (distplot). You can play around with the variables to look for interesting insights in the data. For example, by using sns.boxplot with x = region, I see that the insurance charges for each region are more or less the same. By using sns.boxplot with x = smoker, I see that the insurance charges are higher for smokers.
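To back up that smoker observation with a quick hypothesis test (this is the kind of thing the scipy stats import is good for), here is a rough sketch that is not part of the main script; it assumes the smoker column is coded as "yes"/"no", as in this dataset:
#split charges into smokers and non smokers
smoker_charges = df.loc[df["smoker"] == "yes", "charges"]
non_smoker_charges = df.loc[df["smoker"] == "no", "charges"]
#Welch's t-test: a small p value suggests the difference in mean charges is not due to chance
t_stat, p_value = stats.ttest_ind(smoker_charges, non_smoker_charges, equal_var=False)
print(f"t = {t_stat:.2f}, p = {p_value:.3g}")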
To learn more about using seaborn to make plots with Python, check this out: Seaborn Tutorial
So, with this code you can create histograms, boxplots, scatterplots, and violin plots, clean up duplicate rows, and describe your dataset pretty adequately.
I learned how to do these analyses from these Kaggle pages:
Enjoy!
All of the Code:
#import libraries needed
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from scipy import stats
import random
#don't want to see warnings on the console
#about older functions like distplot
import warnings
warnings.filterwarnings('ignore')
#read the CSV file from GitHub into a dataframe (df); use the raw link
df = pd.read_csv(".../csv_files/main/insurance.csv")
#let's see total number of (rows,columns)
print ("Number of Rows : " , df.shape[0])
print ("Number of Columns : " , df.shape[1])
#take a look at column names
print ("\nColumn Labels : \n", df.columns.tolist())
print("\n")
#let's quickly look at a summary of the dataset
print(df.info())
#lets use .nunique to see categorical and numerical data
print ("\nUnique values : \n", df.nunique())
#let's take a quick look at a total description of the dataset
print(df.describe(include="all"))
#take a look at first 5 rows of data
print(df.head())
#check for missing data
print("\nMissing Values: " )
print(df.isnull().sum())
#lets check for duplicate rows
number_dup_rows = len(df) - len(df.drop_duplicates(keep="first"))
if (number_dup_rows != 0):
    print(f"\nNumber of duplicate rows found : {number_dup_rows}\nDuplicates removed!")
    df.drop_duplicates(keep="first",inplace=True)
else:
    print("No duplicate entries found!")
#take a deeper look at statistics for each age
#groupby shows mean of different variables, separating them by age
print(df.groupby("age").mean())
#which 20 people got charged the most?
print(df.sort_values("charges", ascending = False).head(20))
#lets make count plots for the categorical variables
for variable in ['sex', 'children', 'smoker', 'region']:
    #random.randint picks one of three colour palettes for each count plot
    sns.catplot(x=variable, data=df, kind="count",
                palette="Set" + str(random.randint(1,3)),
                height=3.5, aspect=2, ci="sd")
#lets make a pairplot to get some idea of what's going on in the data
#show scatter plots of all variables
plt.figure()
#positive correlation between charges and age? what are these insurance companies doing
sns.pairplot(df)
plt.figure()
#why are insurance companies dealing mainly with young people?
sns.distplot(df["age"])
plt.figure()
sns.distplot(df["bmi"]) #nice normal distribution
plt.figure()
sns.distplot(df["charges"])
#lets take a look at the heatmap
plt.figure()
#so charges and age look to be the most correlated
sns.heatmap(df.corr(), annot=True, linewidths=.5, fmt= '.2f')
plt.title('Correlation Heatmap')
#let's make a scatter plot to take a closer look at the charges and age relationship
plt.figure()
sns.scatterplot(data = df, x = "age", y = "charges", hue = "bmi")
plt.title('Age and Charges Scatter Plot')
plt.figure()
#are smokers getting charged more
sns.boxplot(x="smoker", y="charges", data = df)
plt.figure()
#which region is getting charged the most, and where the outliers are
sns.boxplot(x="region", y="charges", data = df)
#lets look at some comparisons using violin plots
#charges for male vs female
plt.figure()
sns.violinplot(data=df, x="sex", y="charges")
#charges for smoker vs non smoker
plt.figure()
sns.violinplot(data=df, x="smoker", y="charges")
#charges depending on number of kids
plt.figure()
sns.violinplot(data=df, x="children", y="charges")
#charges depending on location
plt.figure()
sns.violinplot(data=df, x="region", y="charges")