Commit c60efff0 authored by Douglas

Adds more explanation and removes SGD regression

- explanation about the preparation of data for scikit-learn fit methods
- removed SGD regression because its prediction is very bad and makes the plot really ugly
parent f7b2142b
@@ -32,6 +32,25 @@ ts = pd.read_csv(csv_file, sep=';',
squeeze=True,
index_col='date')
# All scikit-learn `fit` methods require a 2D array of shape
# [n_samples x n_features]. So we transform our DatetimeIndex
# into an array of time deltas describing how many days have
# passed since the first observation and read the `days` attribute
# to get plain day counts instead of Timedelta objects.
# With one observation per day, something like this:
#
# >>> [0, 1, 2, 3, ..., n_samples - 1]
#
# Then we reshape it with shape (-1, 1), where -1 tells numpy
# to infer that dimension from the array length. This reshape is
# needed because we have only 1 feature (the day number of the
# order). The result will be an array like below:
#
# >>> [[0], [1], [2], ..., [n_samples - 1]]
#
# If we had more features, each element of the outer array
# would have more than one element.
#
X = (ts.index - ts.index[0]).days.reshape(-1, 1)
y = ts.values
@@ -55,12 +74,6 @@ mtl = linear_model.ElasticNet(alpha=0.1, normalize=True)
mtl.fit(X_train, y_train)
plt.plot(X_test, mtl.predict(X_test), color='green', linewidth=2)
# SGD Regression
sgd = linear_model.SGDRegressor(shuffle=False, eta0=0.25)
sgd.fit(X_train, y_train)
plt.plot(X_test, sgd.predict(X_test), color='red', linewidth=1)
plt.xticks()
plt.yticks()
plt.show()
\ No newline at end of file
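
The data preparation described in the added comment can be tried on its own. Below is a minimal sketch, assuming a small made-up daily series in place of the CSV file (the dates and values are hypothetical and not taken from the repository). It converts the index with `np.asarray` before reshaping, which should work whether or not the pandas index object exposes `reshape` itself.

```python
import numpy as np
import pandas as pd

# Hypothetical daily series standing in for the CSV data read in the script.
index = pd.date_range("2016-01-01", periods=5, freq="D")
ts = pd.Series([10.0, 12.0, 11.0, 15.0, 14.0], index=index)

# Days elapsed since the first observation, as a plain 1D NumPy array.
days = np.asarray((ts.index - ts.index[0]).days)

# One column per feature, so the shape becomes (n_samples, 1).
X = days.reshape(-1, 1)
y = ts.to_numpy()

print(X.shape)
print(X.ravel().tolist())
```

The first day maps to 0 because each value counts days since the first observation. With this shape, `X` can be passed directly to a scikit-learn `fit` method alongside `y`.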