Python/machine_learning/Reuters - OneVsRestClassifier.ipynb

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"try:\n",
"    import nltk\n",
"except ModuleNotFoundError:\n",
"    !pip install nltk\n",
"    import nltk"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"## This code downloads the required NLTK packages.\n",
"## You can run `nltk.download('all')` to download everything.\n",
"\n",
"nltk_packages = [\n",
"    (\"reuters\", \"corpora/reuters.zip\"),\n",
"    # Tokenizer models used by nltk.word_tokenize; recent NLTK releases may need 'punkt_tab' instead.\n",
"    (\"punkt\", \"tokenizers/punkt\")\n",
"]\n",
"\n",
"for pid, fid in nltk_packages:\n",
"    try:\n",
"        nltk.data.find(fid)\n",
"    except LookupError:\n",
"        nltk.download(pid)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setting up corpus"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from nltk.corpus import reuters"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setting up train/test data"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"train_documents, train_categories = zip(*[(reuters.raw(i), reuters.categories(i)) for i in reuters.fileids() if i.startswith('training/')])\n",
"test_documents, test_categories = zip(*[(reuters.raw(i), reuters.categories(i)) for i in reuters.fileids() if i.startswith('test/')])"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"all_categories = sorted(list(set(reuters.categories())))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following cell defines a function **tokenize** that performs the following actions:\n",
"- Receives a document as its argument\n",
"- Tokenizes the document using `nltk.word_tokenize()`\n",
"- Uses the `PorterStemmer` provided by `nltk` to strip morphological affixes from each token\n",
"- Appends each stemmed token to a list `stems`\n",
"- Returns the list `stems`"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from nltk.stem.porter import PorterStemmer\n",
"\n",
"# Create the stemmer once instead of once per token.\n",
"stemmer = PorterStemmer()\n",
"\n",
"def tokenize(text):\n",
"    tokens = nltk.word_tokenize(text)\n",
"    stems = []\n",
"    for item in tokens:\n",
"        stems.append(stemmer.stem(item))\n",
"    return stems"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To begin, I used TF-IDF for feature extraction on both the train and the test data using `TfidfVectorizer`.\n",
"\n",
"But first, what does `TfidfVectorizer` actually do?\n",
"- `TfidfVectorizer` converts a collection of raw documents to a matrix of **TF-IDF** features.\n",
"\n",
"**TF-IDF**?\n",
"- TF-IDF (short for *term frequency–inverse document frequency*) is a numerical statistic that is intended to reflect how important a word is to a document in a collection or corpus. [tfidf](https://en.wikipedia.org/wiki/Tf%E2%80%93idf)\n",
"\n",
"**Why `TfidfVectorizer`**?\n",
"- `TfidfVectorizer` scales down the impact of tokens that occur very frequently (e.g., “a”, “the”, and “of”) in a given corpus. [Feature Extraction and Transformation](https://spark.apache.org/docs/latest/mllib-feature-extraction.html#tf-idf)\n",
"\n",
"I gave the following two arguments to `TfidfVectorizer`:\n",
"- `tokenizer`: the `tokenize` function defined above\n",
"- `stop_words`: `'english'`, the built-in English stop-word list\n",
"\n",
"Then I used `fit_transform` and `transform` on the train and test documents respectively.\n",
"\n",
"**Why `fit_transform` for training data while `transform` for test data**?\n",
"\n",
"To avoid data leakage, the vectorizer learns its vocabulary and IDF weights from the train data during `fit`, **stores them**, and reuses them on the test data during `transform`. This keeps the test data out of the `fit` operation."
]
},
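{
"cell_type": "markdown",
"metadata": {},
"source": [
"The fit/transform split described above can be illustrated with a minimal pure-Python sketch. It is not the scikit-learn implementation: the helper names `fit_idf` and `transform_doc` are hypothetical, whitespace splitting stands in for `tokenize`, and the weighting assumes the smoothed IDF `ln((1 + n) / (1 + df)) + 1` followed by L2 normalisation, which scikit-learn documents as the `TfidfVectorizer` defaults.\n",
"\n",
"```python\n",
"import math\n",
"from collections import Counter\n",
"\n",
"def fit_idf(train_docs):\n",
"    # Learn the vocabulary and smoothed IDF weights from the training documents only.\n",
"    n_docs = len(train_docs)\n",
"    doc_freq = Counter()\n",
"    for doc in train_docs:\n",
"        doc_freq.update(set(doc.split()))\n",
"    return {term: math.log((1 + n_docs) / (1 + df)) + 1 for term, df in doc_freq.items()}\n",
"\n",
"def transform_doc(doc, idf):\n",
"    # Weight term counts by the stored IDF; terms unseen during fit are dropped.\n",
"    counts = Counter(term for term in doc.split() if term in idf)\n",
"    weights = {term: tf * idf[term] for term, tf in counts.items()}\n",
"    norm = math.sqrt(sum(w * w for w in weights.values())) or 1.0\n",
"    return {term: w / norm for term, w in weights.items()}\n",
"\n",
"train_docs = ['wheat prices rise', 'oil prices fall', 'wheat harvest grows']\n",
"idf = fit_idf(train_docs)\n",
"test_vector = transform_doc('coconut oil prices', idf)\n",
"print(sorted(test_vector))  # ['oil', 'prices']\n",
"```\n",
"\n",
"The word `coconut` never appears in the toy training documents, so it gets no weight in the test vector: the vocabulary and IDF values are fixed at fit time."
]
},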
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"\n",
"vectorizer = TfidfVectorizer(tokenizer = tokenize, stop_words = 'english')\n",
"\n",
"vectorised_train_documents = vectorizer.fit_transform(train_documents)\n",
"vectorised_test_documents = vectorizer.transform(test_documents)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Many machine learning algorithms **require all input and output variables to be numeric**, which also allows an **efficient implementation**. This means that categorical data must be converted to a numerical form.\n",
"\n",
"For this purpose, I used `MultiLabelBinarizer` from `sklearn.preprocessing`, which encodes the categories of each document as a row of 0/1 indicators with one column per label."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.preprocessing import MultiLabelBinarizer\n",
"\n",
"mlb = MultiLabelBinarizer()\n",
"train_labels = mlb.fit_transform(train_categories)\n",
"test_labels = mlb.transform(test_categories)"
]
},
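{
"cell_type": "markdown",
"metadata": {},
"source": [
"A toy illustration of this encoding, as a sketch only: the three label lists below are made up, and the `toy_` names are not part of the notebook.\n",
"\n",
"```python\n",
"from sklearn.preprocessing import MultiLabelBinarizer\n",
"\n",
"# Three hypothetical documents with one or two categories each.\n",
"toy_categories = [['grain', 'wheat'], ['crude'], ['grain']]\n",
"\n",
"toy_mlb = MultiLabelBinarizer()\n",
"toy_labels = toy_mlb.fit_transform(toy_categories)\n",
"\n",
"print(toy_mlb.classes_)  # column order: the label names, sorted\n",
"print(toy_labels)        # one row per document, one 0/1 column per label\n",
"print(toy_mlb.inverse_transform(toy_labels))  # back to tuples of label names\n",
"```\n",
"\n",
"The same `inverse_transform` call is what turns a predicted indicator row back into readable category names."
]
},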
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, to **train** the classifier, I used `LinearSVC` in combination with the `OneVsRestClassifier` class from the scikit-learn package.\n",
"\n",
"The strategy of `OneVsRestClassifier` is to **fit one classifier per label**. This is computationally efficient and the outputs are easy to interpret: since each label is represented by **one and only one classifier**, it is possible to gain knowledge about the label by inspecting its corresponding classifier. [OneVsRestClassifier](http://scikit-learn.org/stable/modules/multiclass.html#one-vs-the-rest)\n",
"\n",
"The reason I combined `LinearSVC` with `OneVsRestClassifier` is that `LinearSVC` on its own supports only **multi-class** classification (exactly one label per document), while we want to perform **multi-label** classification (any number of labels per document)."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"from sklearn.multiclass import OneVsRestClassifier\n",
"from sklearn.svm import LinearSVC\n",
"\n",
"classifier = OneVsRestClassifier(LinearSVC())\n",
"classifier.fit(vectorised_train_documents, train_labels)"
]
},
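{
"cell_type": "markdown",
"metadata": {},
"source": [
"The one-classifier-per-label idea can be checked on a tiny made-up dataset. This is a sketch under stated assumptions: `toy_X`, `toy_y` and `toy_clf` are hypothetical names, and the four samples are invented for illustration.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.multiclass import OneVsRestClassifier\n",
"from sklearn.svm import LinearSVC\n",
"\n",
"# Four toy documents with two features; each column of toy_y is one label.\n",
"toy_X = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])\n",
"toy_y = np.array([[1, 0], [1, 0], [0, 1], [1, 1]])\n",
"\n",
"toy_clf = OneVsRestClassifier(LinearSVC()).fit(toy_X, toy_y)\n",
"\n",
"print(len(toy_clf.estimators_))            # one fitted LinearSVC per label column\n",
"print(toy_clf.estimators_[0].coef_.shape)  # weight vector of the first label's classifier\n",
"```\n",
"\n",
"In the notebook itself, the analogous attribute would be `classifier.estimators_`, whose entries should follow the column order of `mlb.classes_`."
]
},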
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After fitting the classifier, I decided to use `cross_val_score` to **measure the score** of the classifier by **cross-validation** on the training data. With multi-label targets, the default score is subset accuracy: a document counts as correct only if all of its labels are predicted correctly. The only problem was that I wanted to **shuffle** the data, but `cross_val_score` does not have a `shuffle` argument.\n",
"\n",
"So, I decided to pass a `KFold` splitter to `cross_val_score`, as `KFold` supports shuffling the data.\n",
"\n",
"I also set `random_state`, so that the pseudorandom number generator produces the same sequence each time and the data is therefore split into the same folds in every run.\n",
"\n",
"Why **42**?\n",
"- [Why '42' is the preferred number when indicating something random?](https://softwareengineering.stackexchange.com/questions/507/why-42-is-the-preferred-number-when-indicating-something-random)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"from sklearn.model_selection import KFold, cross_val_score\n",
"\n",
"kf = KFold(n_splits=10, random_state = 42, shuffle = True)\n",
"scores = cross_val_score(classifier, vectorised_train_documents, train_labels, cv = kf)"
]
},
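{
"cell_type": "markdown",
"metadata": {},
"source": [
"The effect of `random_state` on the folds can be verified on a small array. This is a minimal sketch; the helper `fold_test_indices` and the ten-row `toy_X` are assumptions made for illustration.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.model_selection import KFold\n",
"\n",
"toy_X = np.arange(10).reshape(-1, 1)\n",
"\n",
"def fold_test_indices(seed):\n",
"    # Collect the test indices of every fold for a given random_state.\n",
"    splitter = KFold(n_splits=5, shuffle=True, random_state=seed)\n",
"    return [test_idx.tolist() for _, test_idx in splitter.split(toy_X)]\n",
"\n",
"# The same seed reproduces exactly the same shuffled folds.\n",
"print(fold_test_indices(42) == fold_test_indices(42))\n",
"```\n",
"\n",
"Every sample still lands in exactly one test fold; the seed only fixes which one."
]
},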
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cross-validation scores: [0.83655084 0.86743887 0.8043758 0.83011583 0.83655084 0.81724582\n",
" 0.82754183 0.8030888 0.80694981 0.82731959]\n",
"Cross-validation accuracy: 0.8257 (+/- 0.0368)\n"
]
}
],
"source": [
"print('Cross-validation scores:', scores)\n",
"print('Cross-validation accuracy: {:.4f} (+/- {:.4f})'.format(scores.mean(), scores.std() * 2))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the end, I used several metrics (`accuracy_score`, `precision_score`, `recall_score`, `f1_score` and `confusion_matrix`) provided by scikit-learn **to evaluate** the classifier on the test data, reporting both *macro-* and *micro-averages* for precision, recall and F1."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix\n",
"\n",
"predictions = classifier.predict(vectorised_test_documents)\n",
"\n",
"accuracy = accuracy_score(test_labels, predictions)\n",
"\n",
"macro_precision = precision_score(test_labels, predictions, average='macro')\n",
"macro_recall = recall_score(test_labels, predictions, average='macro')\n",
"macro_f1 = f1_score(test_labels, predictions, average='macro')\n",
"\n",
"micro_precision = precision_score(test_labels, predictions, average='micro')\n",
"micro_recall = recall_score(test_labels, predictions, average='micro')\n",
"micro_f1 = f1_score(test_labels, predictions, average='micro')\n",
"\n",
"cm = confusion_matrix(test_labels.argmax(axis = 1), predictions.argmax(axis = 1))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy: 0.8099\n",
"Precision:\n",
"- Macro: 0.6076\n",
"- Micro: 0.9471\n",
"Recall:\n",
"- Macro: 0.3708\n",
"- Micro: 0.7981\n",
"F1-measure:\n",
"- Macro: 0.4410\n",
"- Micro: 0.8662\n"
]
}
],
"source": [
"print(\"Accuracy: {:.4f}\\nPrecision:\\n- Macro: {:.4f}\\n- Micro: {:.4f}\\nRecall:\\n- Macro: {:.4f}\\n- Micro: {:.4f}\\nF1-measure:\\n- Macro: {:.4f}\\n- Micro: {:.4f}\".format(accuracy, macro_precision, micro_precision, macro_recall, micro_recall, macro_f1, micro_f1))"
]
},
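{
"cell_type": "markdown",
"metadata": {},
"source": [
"The gap between the macro and micro figures above can be understood with a small hand calculation. The counts below are invented for illustration and are not taken from the Reuters results; they only show how one poorly predicted rare label pulls the macro average down while barely moving the micro average.\n",
"\n",
"```python\n",
"# Per-label (true positives, false positives) for two hypothetical labels:\n",
"# a frequent label that is predicted well and a rare label that is not.\n",
"counts = {'frequent': (90, 5), 'rare': (1, 4)}\n",
"\n",
"per_label_precision = {label: tp / (tp + fp) for label, (tp, fp) in counts.items()}\n",
"\n",
"# Macro: average the per-label precisions, so every label counts equally.\n",
"macro_precision = sum(per_label_precision.values()) / len(per_label_precision)\n",
"\n",
"# Micro: pool the counts first, so frequent labels dominate.\n",
"total_tp = sum(tp for tp, _ in counts.values())\n",
"total_fp = sum(fp for _, fp in counts.values())\n",
"micro_precision = total_tp / (total_tp + total_fp)\n",
"\n",
"print(round(macro_precision, 4), round(micro_precision, 4))  # 0.5737 0.91\n",
"```\n",
"\n",
"A similar imbalance between frequent and rare categories may be behind the lower macro scores reported for this classifier."
]
},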
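{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the cell below, I used the `heatmap` function of `seaborn` together with `matplotlib.pyplot` to **plot the confusion matrix** (only the first 73 rows, to keep the plot readable). Note that the matrix is computed from `argmax` over each label row, so it compares only the first active label of every document rather than the full label set."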
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x24d8cf39f28>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import seaborn as sb\n",
"import pandas as pd\n",
"\n",
"cm_plt = pd.DataFrame(cm[:73])\n",
"\n",
"plt.figure(figsize = (25, 25))\n",
"ax = plt.axes()\n",
"\n",
"sb.heatmap(cm_plt, annot=True)\n",
"\n",
"ax.xaxis.set_ticks_position('top')\n",
"\n",
"plt.show()"
]
},
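{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a complement to the single-label view plotted above, a per-label breakdown can be sketched with `multilabel_confusion_matrix`, which is available in more recent scikit-learn releases than the one this notebook was originally run with. The `toy_true` and `toy_pred` arrays are made up for illustration.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.metrics import multilabel_confusion_matrix\n",
"\n",
"# Three toy documents, three labels (0/1 indicator rows).\n",
"toy_true = np.array([[1, 0, 1], [0, 1, 0], [1, 1, 0]])\n",
"toy_pred = np.array([[1, 0, 0], [0, 1, 0], [1, 0, 0]])\n",
"\n",
"# Shape (n_labels, 2, 2); each 2x2 block is [[TN, FP], [FN, TP]] for one label.\n",
"per_label_cm = multilabel_confusion_matrix(toy_true, toy_pred)\n",
"print(per_label_cm.shape)  # (3, 2, 2)\n",
"print(per_label_cm[0])     # first label: 1 true negative, 2 true positives\n",
"```\n",
"\n",
"Applied to `test_labels` and `predictions`, this would give one small matrix per Reuters category instead of a single large one."
]
},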
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Pipeline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, I took some text from [Coconut - Wikipedia](https://en.wikipedia.org/wiki/Coconut) to check whether the classifier is able to **correctly** predict the label(s) of unseen text.\n",
"\n",
"And here is the output:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Example labels: [('coconut', 'oilseed')]\n"
]
}
],
"source": [
"example_text = '''The coconut tree (Cocos nucifera) is a member of the family Arecaceae (palm family) and the only species of the genus Cocos.\n",
"The term coconut can refer to the whole coconut palm or the seed, or the fruit, which, botanically, is a drupe, not a nut.\n",
"The spelling cocoanut is an archaic form of the word.\n",
"The term is derived from the 16th-century Portuguese and Spanish word coco meaning \"head\" or \"skull\", from the three indentations on the coconut shell that resemble facial features.\n",
"Coconuts are known for their versatility ranging from food to cosmetics.\n",
"They form a regular part of the diets of many people in the tropics and subtropics.\n",
"Coconuts are distinct from other fruits for their endosperm containing a large quantity of water (also called \"milk\"), and when immature, may be harvested for the potable coconut water.\n",
"When mature, they can be used as seed nuts or processed for oil, charcoal from the hard shell, and coir from the fibrous husk.\n",
"When dried, the coconut flesh is called copra.\n",
"The oil and milk derived from it are commonly used in cooking and frying, as well as in soaps and cosmetics.\n",
"The husks and leaves can be used as material to make a variety of products for furnishing and decorating.\n",
"The coconut also has cultural and religious significance in certain societies, particularly in India, where it is used in Hindu rituals.'''\n",
"\n",
"example_preds = classifier.predict(vectorizer.transform([example_text]))\n",
"example_labels = mlb.inverse_transform(example_preds)\n",
"print(\"Example labels: {}\".format(example_labels))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}