Skip to content

Commit

Permalink
Add regression metric reporter and update requirements
Browse files Browse the repository at this point in the history
  • Loading branch information
risenW committed Dec 24, 2019
1 parent 9b7a0d7 commit dfb9781
Show file tree
Hide file tree
Showing 8 changed files with 57 additions and 256 deletions.
1 change: 0 additions & 1 deletion datasist/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,3 @@
from . import timeseries
from . import visualizations
from . import model
from . import nlp
51 changes: 50 additions & 1 deletion datasist/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
'''
import platform

from sklearn.metrics import roc_curve, confusion_matrix, precision_score, accuracy_score, recall_score, f1_score, make_scorer
from sklearn.metrics import roc_curve, confusion_matrix, precision_score, accuracy_score, recall_score, f1_score, make_scorer, mean_absolute_error, mean_squared_error, r2_score, mean_squared_log_error
from sklearn.model_selection import KFold, cross_val_score
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -232,6 +232,55 @@ def get_classification_report(y_train=None, prediction=None, show_roc_plot=True,
plt.savefig("roc_plot.png")


def get_regression_report(y_true=None, prediction=None, show_r2_plot=True, save_plot=False):
'''
Generates performance report for a regression problem.
Parameters:
------------------
y_true: Array, series, list.
The truth/ground value from the train data set.
prediction: Array, series, list.
The predicted value by a trained model.
show_r2_plot: Bool, default True.
Show the r-squared curve.
save_plot: Bool, default True.
Save the plot to the current working directory.
'''
mae = mean_absolute_error(y_true, prediction)
mse = mean_squared_error(y_true, prediction)
msle = precision_score(y_true, prediction)
r2 = r2_score(y_true, prediction)

print("Mean Absolute Error: ", round(mae, 5))
print("Mean Squared Error: ", round(mse, 5))
print("Mean Squared Log Error: ", round(msle, 5))
print("R-squared Error: ", round(r2, 5))
print("*" * 100)

if show_r2_plot:
plt.scatter(y_true,prediction)
plt.xlabel('Truth values')
plt.ylabel('Predicted values')
plt.plot(np.unique(y_true), np.poly1d(np.polyfit(y_true, y_true, 1))(np.unique(y_true)))
plt.text(0.7, 0.2, 'R-squared = %0.2f' % r2)
plt.show()

if save_plot:
plt.savefig("r2_plot.png")





def compare_model(models_list=None, x_train=None, y_train=None, scoring_metric=None, scoring_cv=3, silenced=True, plot=True):
"""
Train multiple user-defined model and display report based on defined metric. Enables user to pick the best base model for a problem.
Expand Down
2 changes: 2 additions & 0 deletions datasist/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from datasist import model


def test_compare_model_classification():
x_train, y_train = make_classification(
n_samples=50,
Expand All @@ -24,6 +25,7 @@ def test_compare_model_classification():
assert type(model_scores) is list
assert hasattr(fitted_model[0], "predict")


def test_compare_model_regression():
x_train, y_train = make_classification(
n_samples=50,
Expand Down
11 changes: 0 additions & 11 deletions datasist/tests/test_nlp.py

This file was deleted.

8 changes: 2 additions & 6 deletions docs/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ <h1 class="title">Module <code>datasist</code></h1>
from . import timeseries
from . import visualizations
from . import model
from . import nlp</code></pre>
</code></pre>
</details>
</section>
<section>
Expand All @@ -43,10 +43,7 @@ <h2 class="section-title" id="header-submodules">Sub-modules</h2>
<dd>
<section class="desc"><p>This module contains all functions relating to modeling in using sklearn library.</p></section>
</dd>
<dt><code class="name"><a title="datasist.nlp" href="nlp.html">datasist.nlp</a></code></dt>
<dd>
<section class="desc"><p>This module contains all functions relating to nlp</p></section>
</dd>
<dd></dd>
<dt><code class="name"><a title="datasist.structdata" href="structdata.html">datasist.structdata</a></code></dt>
<dd>
<section class="desc"><p>This module contains all functions relating to the cleaning and exploration of structured data sets; mostly in pandas format</p></section>
Expand Down Expand Up @@ -80,7 +77,6 @@ <h1><img src="datasist.png" alt="logo"></h1>
<ul>
<li><code><a title="datasist.feature_engineering" href="feature_engineering.html">datasist.feature_engineering</a></code></li>
<li><code><a title="datasist.model" href="model.html">datasist.model</a></code></li>
<li><code><a title="datasist.nlp" href="nlp.html">datasist.nlp</a></code></li>
<li><code><a title="datasist.structdata" href="structdata.html">datasist.structdata</a></code></li>
<li><code><a title="datasist.timeseries" href="timeseries.html">datasist.timeseries</a></code></li>
<li><code><a title="datasist.visualizations" href="visualizations.html">datasist.visualizations</a></code></li>
Expand Down
230 changes: 0 additions & 230 deletions docs/nlp.html

This file was deleted.

4 changes: 1 addition & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,4 @@ pandas
matplotlib
seaborn
scikit-learn
numpy
spacy
en
numpy
Loading

0 comments on commit dfb9781

Please sign in to comment.