Predicting Google’s Stock Price using Linear Regression

What is Linear Regression?

Let’s forget the term ‘linear regression’ for a while. Instead, think back to your high school math class. You have probably plotted the graph of a linear equation in coordinate geometry. Let’s recap what we did there.

We were given an equation y = 2x + 3, where 2 is the coefficient of x and 3 is a constant (i.e., the intercept on the y-axis). What we used to do was:

  1. Take a value of x (say x = 0).
  2. Find the corresponding value of y by substituting x = 0 into the equation.
  3. Store the (x, y) pair in a table.
  4. Repeat the process for as many values of x as we want.
  5. Plot the points on the graph to obtain the straight line.
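Steps 1 to 4 amount to tabulating the equation. A minimal Python sketch (the x values 0 to 4 are an arbitrary choice, and step 5, the plotting, is left out):

```python
def line(x):
    """The equation from the example: y = 2x + 3."""
    return 2 * x + 3

# Steps 1-4: pick x values, compute y, and store each (x, y) pair in a table
table = [(x, line(x)) for x in range(0, 5)]

for x, y in table:
    print(f"x = {x}, y = {y}")
```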

Now, we will just do the reverse of the above method.

  1. We have a set of points (x1, y1), (x2, y2), (x3, y3), and so on up to (xn, yn).
  2. We use these points to find the coefficient a and the constant b such that the line y = ax + b fits them as closely as possible.
  3. Once we have the equation, we can find the approximate value of y for any value of x.

Basically, we found a relationship, or pattern, between the values of x and y and expressed it as an equation y = ax + b. You just did linear regression without even knowing it.
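When the points do not lie exactly on one line, “finding a and b” means choosing the line that fits them best. The usual criterion, and the one scikit-learn’s LinearRegression uses, is ordinary least squares: pick a and b so that the sum of squared differences between each yi and a*xi + b is as small as possible. With one explanatory variable this has a closed-form solution. The small pure-Python sketch below (the function name fit_line is just for illustration) recovers a = 2 and b = 3 from points tabulated from y = 2x + 3:

```python
def fit_line(xs, ys):
    """Ordinary least squares for y = a*x + b with one explanatory variable.

    a = sum((x - x_mean) * (y - y_mean)) / sum((x - x_mean) ** 2)
    b = y_mean - a * x_mean
    """
    n = len(xs)
    x_mean = sum(xs) / n
    y_mean = sum(ys) / n
    sxy = sum((x - x_mean) * (y - y_mean) for x, y in zip(xs, ys))
    sxx = sum((x - x_mean) ** 2 for x in xs)
    if sxx == 0:
        raise ValueError("need at least two distinct x values")
    a = sxy / sxx
    b = y_mean - a * x_mean
    return a, b

# Points taken from y = 2x + 3: the fit recovers the coefficient and the constant
xs = [0, 1, 2, 3, 4]
ys = [3, 5, 7, 9, 11]
a, b = fit_line(xs, ys)
print(a, b)  # 2.0 3.0
```

Once a and b are known, the approximate y for any new x is simply a * x + b.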

Let’s see the formal definition of linear regression (from Wikipedia):

In statistics, linear regression is an approach for modeling the relationship between a scalar dependent variable y and one or more explanatory variables (or independent variables) denoted X. The case of one explanatory variable is called simple linear regression. For more than one explanatory variable, the process is called multiple linear regression.
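In this article’s notation (a for a coefficient, b for the constant), the two cases differ only in how many explanatory variables appear on the right-hand side:

```latex
\text{simple:}\quad y = a x + b
\qquad\qquad
\text{multiple:}\quad y = a_1 x_1 + a_2 x_2 + \dots + a_k x_k + b
```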

So, in our case, x is the independent (explanatory) variable and y is the dependent variable, since its value depends on x. Now, let us implement simple linear regression in Python to see a real-life application of the method.

We will predict the future price of Google’s stock using simple linear regression. The data is real data obtained from Google Finance and saved to a CSV file, google.csv.

Date  Open
26    708.58
25    700.01
24    688.92
23    701.45
22    707.45
19    695.03
18    710
17    699
16    692.98
12    690.26
11    675
10    686.86
9     672.32
8     667.85
5     703.87
4     722.81
3     770.22
2     784.5
1     750.46

The dataset above lists the price at which Google stock opened on each trading day from February 1 to February 26, 2016 (the Date column is the day of the month). Using this data, we will try to predict the price at which the stock will open on February 29, 2016. We will use the scikit-learn, csv, NumPy and matplotlib packages to implement and visualize simple linear regression.

First, we import the modules listed above.

The csv module is used to read the data from the file google.csv, NumPy for array processing and conversion, scikit-learn (sklearn) to implement linear regression, and matplotlib to plot the data points on a graph.

Next, let’s define a function to read the data from google.csv.

Don’t worry if you are not familiar with reading CSV files in Python. Just read our previous article, Interacting with CSV files using Python, which has well-explained, easy-to-follow examples.
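A minimal sketch of such a reader, using only the csv module. The name get_data and the file layout are assumptions: a header row, the date in the first column written like 26-Feb-16, and the opening price in the second column. It keeps only the day of the month, because that integer is what the prediction step below works with:

```python
import csv

def get_data(lines):
    """Read (day, opening price) pairs from CSV text.

    Hypothetical layout: a header row, then rows like '26-Feb-16,708.58'.
    `lines` can be an open file or any iterable of CSV lines.
    """
    dates, prices = [], []
    reader = csv.reader(lines)
    next(reader, None)  # skip the header row (column names)
    for row in reader:
        if not row:
            continue  # ignore blank lines
        dates.append(int(row[0].split("-")[0]))  # '26-Feb-16' -> 26
        prices.append(float(row[1]))
    return dates, prices

# Usage with the real file (not run here):
# with open("google.csv", newline="") as f:
#     dates, prices = get_data(f)
```

Returning the two lists, rather than filling global variables, keeps the function easy to reuse and test.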

Now, let’s define a function to predict the price of Google’s stock on a given date.

The function predict_price takes three arguments:

– dates: the list of dates as integers (the day of the month)
– prices: the opening price of the stock on each corresponding date
– x: the date for which we want to predict the price (here, 29)

Inside predict_price, the fit method of the linear model fits the dates and prices (the x’s and y’s) to find the coefficient and the constant of the regression equation. The model’s predict method then finds the price (y) for the given date (x), and predict_price returns the predicted price together with the coefficient and the constant of the relationship equation.
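A minimal sketch of predict_price with scikit-learn’s LinearRegression, matching the three arguments and three return values described above; the body is an assumption, not the article’s original listing. scikit-learn expects the inputs as a 2-D array with one row per sample, hence the reshape:

```python
import numpy as np
from sklearn.linear_model import LinearRegression

def predict_price(dates, prices, x):
    """Fit price = coefficient * date + constant, then predict the price for date x.

    Returns (predicted price, coefficient, constant).
    """
    # scikit-learn wants features of shape (n_samples, n_features): one column here
    X = np.reshape(dates, (len(dates), 1))
    y = np.asarray(prices, dtype=float)

    model = LinearRegression()
    model.fit(X, y)  # learns model.coef_ (coefficient) and model.intercept_ (constant)

    predicted = model.predict(np.array([[x]]))[0]
    return predicted, model.coef_[0], model.intercept_


# Toy check on points from y = 2x + 3 (not the stock data)
price, coefficient, constant = predict_price([1, 2, 3, 4], [5, 7, 9, 11], 10)
print(price, coefficient, constant)
```

With the stock data, the call would be predict_price(dates, prices, 29).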

To understand the concept of regression better, we can use the matplotlib module to plot the data points and the relationship found between them.

Note: the show_plot function draws the graph using matplotlib. Don’t worry if you don’t fully understand the code below; the graph that follows it matters more. The show_plot code is commented to help you follow it.
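A minimal sketch of a show_plot function along these lines, assuming it receives the same dates and prices lists and fits its own LinearRegression. Unlike a version that calls plt.show() itself, this one returns the matplotlib Figure, so the caller can show it, save it, or inspect it. Colors follow the description of the plot: yellow data points and a blue regression line.

```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import LinearRegression

def show_plot(dates, prices):
    """Scatter the data points and draw the fitted regression line."""
    X = np.reshape(dates, (len(dates), 1))
    model = LinearRegression().fit(X, prices)

    fig, ax = plt.subplots()
    # The original data points: one yellow dot per (date, price) pair
    ax.scatter(dates, prices, color="yellow", edgecolors="black", label="Data")
    # The regression line, drawn between the smallest and largest date
    line_x = np.array([[min(dates)], [max(dates)]])
    ax.plot(line_x.ravel(), model.predict(line_x), color="blue", label="Linear model")
    ax.set_xlabel("Date (day of the month)")
    ax.set_ylabel("Opening price")
    ax.set_title("Linear regression")
    ax.legend()
    return fig

# Usage: fig = show_plot(dates, prices); plt.show()
```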

Figure: Plot of the data points (yellow) and the regression line (blue)

The yellow dots in the plot above show the data points, one for each date and price (i.e., the initial dataset).

The blue line is the equation found by the fit method of the linear model (see the predict_price function above).

Now, when we give the date February 29 (x = 29) to the regression model, it simply uses the equation of the blue straight line in the plot above and finds the corresponding value on the y-axis.
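This can be checked directly. The sketch below, assuming scikit-learn, fits LinearRegression on the day-of-month and opening-price pairs from the dataset above (directly, not through predict_price) and compares the model’s prediction for day 29 with the line y = ax + b evaluated by hand:

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Day of the month and opening price, February 2016 (the dataset above)
days = [26, 25, 24, 23, 22, 19, 18, 17, 16, 12, 11, 10, 9, 8, 5, 4, 3, 2, 1]
opens = [708.58, 700.01, 688.92, 701.45, 707.45, 695.03, 710.00, 699.00,
         692.98, 690.26, 675.00, 686.86, 672.32, 667.85, 703.87, 722.81,
         770.22, 784.50, 750.46]

model = LinearRegression().fit(np.reshape(days, (-1, 1)), opens)
a = model.coef_[0]    # coefficient (slope) of the blue line
b = model.intercept_  # constant (intercept) of the blue line

# What the model returns for February 29 ...
from_model = model.predict(np.array([[29]]))[0]
# ... is just the line y = a*x + b evaluated at x = 29
by_hand = a * 29 + b

print(f"y = {a:.4f}x + {b:.4f}")
print(f"model: {from_model:.2f}, by hand: {by_hand:.2f}")
```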

See the full program code below.

The program above produces the following output:

The last line of the output shows the equation of the blue line in the plot. I hope the concept of linear regression is clear now.

Congrats! You just learnt a fundamental yet powerful machine learning technique. The dataset, code and plot are available on GitHub. Your questions are welcome in the comments below.

