mirror of
https://github.com/TheAlgorithms/Python.git
synced 2024-11-30 16:31:08 +00:00
406 lines
151 KiB
Plaintext
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    import nltk\n",
    "except ModuleNotFoundError:\n",
    "    !pip install nltk\n",
    "    import nltk  # import again once the package is installed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "## This code downloads the required packages.\n",
    "## You can run `nltk.download('all')` to download everything.\n",
    "\n",
    "nltk_packages = [\n",
    "    (\"reuters\", \"corpora/reuters.zip\"),\n",
    "    # `nltk.word_tokenize` (used below) needs the Punkt tokenizer data;\n",
    "    # recent NLTK releases may look for \"punkt_tab\" instead.\n",
    "    (\"punkt\", \"tokenizers/punkt\"),\n",
    "]\n",
    "\n",
    "for pid, fid in nltk_packages:\n",
    "    try:\n",
    "        nltk.data.find(fid)\n",
    "    except LookupError:\n",
    "        nltk.download(pid)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting up corpus"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nltk.corpus import reuters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting up train/test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_documents, train_categories = zip(*[(reuters.raw(i), reuters.categories(i)) for i in reuters.fileids() if i.startswith('training/')])\n",
    "test_documents, test_categories = zip(*[(reuters.raw(i), reuters.categories(i)) for i in reuters.fileids() if i.startswith('test/')])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_categories = sorted(set(reuters.categories()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following cell defines a function **tokenize** that performs the following actions:\n",
    "- Receive a document as an argument\n",
    "- Tokenize the document using `nltk.word_tokenize()`\n",
    "- Use the `PorterStemmer` provided by `nltk` to remove morphological affixes from each token\n",
    "- Append each stemmed token to a list `stems`\n",
    "- Return the list `stems`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nltk.stem.porter import PorterStemmer\n",
    "\n",
    "# Create the stemmer once and reuse it for every token.\n",
    "stemmer = PorterStemmer()\n",
    "\n",
    "def tokenize(text):\n",
    "    tokens = nltk.word_tokenize(text)\n",
    "    stems = []\n",
    "    for item in tokens:\n",
    "        stems.append(stemmer.stem(item))\n",
    "    return stems"
   ]
  },
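  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small aside (not part of the original experiment), the next cell sketches what stemming does on its own. The sample words and the name `toy_stemmer` are assumptions chosen purely for illustration; the cell only needs `nltk` itself and no downloaded corpus."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nltk.stem.porter import PorterStemmer\n",
    "\n",
    "toy_stemmer = PorterStemmer()\n",
    "\n",
    "# Hypothetical sample words; stemming maps inflected forms towards a common stem.\n",
    "for word in ['prices', 'pricing', 'exported', 'exports']:\n",
    "    print(word, '->', toy_stemmer.stem(word))"
   ]
  },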
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To begin, I used TF-IDF for feature extraction on both the train and the test data using `TfidfVectorizer`.\n",
    "\n",
    "But first, what does `TfidfVectorizer` actually do?\n",
    "- `TfidfVectorizer` converts a collection of raw documents to a matrix of **TF-IDF** features.\n",
    "\n",
    "**TF-IDF**?\n",
    "- TF-IDF (short for *term frequency–inverse document frequency*) is a numerical statistic that is intended to reflect how important a word is to a document in a collection or corpus. [tf–idf](https://en.wikipedia.org/wiki/Tf%E2%80%93idf)\n",
    "\n",
    "**Why `TfidfVectorizer`**?\n",
    "- `TfidfVectorizer` scales down the impact of tokens that occur very frequently (e.g., “a”, “the”, and “of”) in a given corpus. [Feature Extraction and Transformation](https://spark.apache.org/docs/latest/mllib-feature-extraction.html#tf-idf)\n",
    "\n",
    "I gave the following two arguments to `TfidfVectorizer`:\n",
    "- tokenizer: `tokenize` function\n",
    "- stop_words\n",
    "\n",
    "Then I used `fit_transform` and `transform` on the train and test documents respectively.\n",
    "\n",
    "**Why `fit_transform` for training data while `transform` for test data**?\n",
    "\n",
    "To avoid data leakage, the vectorizer learns its vocabulary and IDF weights from the training data during `fit`, **stores them**, and applies the same values to the test data during `transform`. The test data therefore never takes part in the `fit` operation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "\n",
    "vectorizer = TfidfVectorizer(tokenizer=tokenize, stop_words='english')\n",
    "\n",
    "vectorised_train_documents = vectorizer.fit_transform(train_documents)\n",
    "vectorised_test_documents = vectorizer.transform(test_documents)"
   ]
  },
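  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the `fit_transform` / `transform` split more concrete, here is a minimal sketch on a made-up corpus. The toy sentences and the names `toy_train`, `toy_test` and `toy_vectorizer` are assumptions for illustration and have nothing to do with the Reuters data; the default tokenizer is used instead of `tokenize` to keep the cell self-contained."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "\n",
    "# Hypothetical toy corpus, unrelated to the Reuters data.\n",
    "toy_train = ['oil prices rise', 'wheat prices fall']\n",
    "toy_test = ['gold prices rise']\n",
    "\n",
    "toy_vectorizer = TfidfVectorizer()\n",
    "toy_train_matrix = toy_vectorizer.fit_transform(toy_train)  # learns vocabulary and IDF weights\n",
    "toy_test_matrix = toy_vectorizer.transform(toy_test)  # reuses them unchanged\n",
    "\n",
    "# 'gold' never appeared during fit, so it has no column of its own.\n",
    "print(sorted(toy_vectorizer.vocabulary_))\n",
    "# Both matrices share the same number of columns (the training vocabulary).\n",
    "print(toy_train_matrix.shape, toy_test_matrix.shape)"
   ]
  },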
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Many machine learning algorithms **require all input and output variables to be numeric**. This means that categorical data must be converted to a numerical form.\n",
    "\n",
    "For this purpose, I used `MultiLabelBinarizer` from `sklearn.preprocessing`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import MultiLabelBinarizer\n",
    "\n",
    "mlb = MultiLabelBinarizer()\n",
    "train_labels = mlb.fit_transform(train_categories)\n",
    "test_labels = mlb.transform(test_categories)"
   ]
  },
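  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell below is a small sketch of what `MultiLabelBinarizer` produces. The label sets in `toy_categories` are invented for illustration (they merely resemble Reuters category names), and `toy_mlb` is a separate instance so the fitted `mlb` above is left untouched."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import MultiLabelBinarizer\n",
    "\n",
    "# Hypothetical label sets for three documents.\n",
    "toy_categories = [['grain', 'wheat'], ['crude'], ['crude', 'grain']]\n",
    "\n",
    "toy_mlb = MultiLabelBinarizer()\n",
    "toy_labels = toy_mlb.fit_transform(toy_categories)\n",
    "\n",
    "print(toy_mlb.classes_)  # one column per distinct label, in sorted order\n",
    "print(toy_labels)  # one row per document, 1 where the label applies\n",
    "print(toy_mlb.inverse_transform(toy_labels))  # back to tuples of label names"
   ]
  },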
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, to **train** the classifier, I used `LinearSVC` in combination with `OneVsRestClassifier` from the scikit-learn package.\n",
    "\n",
    "The `OneVsRestClassifier` strategy consists of **fitting one classifier per label**. It is computationally efficient, and its outputs are easy to interpret: since each label is represented by **one and only one classifier**, it is possible to gain knowledge about a label by inspecting its corresponding classifier. [OneVsRestClassifier](http://scikit-learn.org/stable/modules/multiclass.html#one-vs-the-rest)\n",
    "\n",
    "I combined `LinearSVC` with `OneVsRestClassifier` because `LinearSVC` on its own supports **multi-class** classification, while we want to perform **multi-label** classification."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "from sklearn.multiclass import OneVsRestClassifier\n",
    "from sklearn.svm import LinearSVC\n",
    "\n",
    "classifier = OneVsRestClassifier(LinearSVC())\n",
    "classifier.fit(vectorised_train_documents, train_labels)"
   ]
  },
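  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the one-classifier-per-label idea, the following sketch fits `OneVsRestClassifier(LinearSVC())` on a tiny invented data set. The arrays `toy_X` and `toy_Y` and the name `toy_classifier` are assumptions made up for this example; the real `classifier` above is not modified."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.multiclass import OneVsRestClassifier\n",
    "from sklearn.svm import LinearSVC\n",
    "\n",
    "# Hypothetical data: 6 samples, 2 features, 2 independent binary labels.\n",
    "toy_X = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.1, 0.9], [0.9, 0.1]])\n",
    "toy_Y = np.array([[0, 0], [0, 1], [1, 0], [1, 1], [0, 1], [1, 0]])\n",
    "\n",
    "toy_classifier = OneVsRestClassifier(LinearSVC())\n",
    "toy_classifier.fit(toy_X, toy_Y)\n",
    "\n",
    "# One fitted LinearSVC per label column.\n",
    "print(len(toy_classifier.estimators_))\n",
    "# Each predicted row can switch on zero, one, or several labels.\n",
    "print(toy_classifier.predict(toy_X))"
   ]
  },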
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After fitting the classifier, I decided to use `cross_val_score` to **measure the score** of the classifier by **cross-validation** on the training data. The only problem was that I wanted to **shuffle** the data, and `cross_val_score` has no shuffle argument.\n",
    "\n",
    "So, I decided to pass a `KFold` splitter to `cross_val_score`, as `KFold` supports shuffling the data.\n",
    "\n",
    "I also set `random_state`, which makes the result reproducible: with a fixed `random_state`, the pseudorandom number generator produces the same sequence each time, so the shuffled splits are identical on every run.\n",
    "\n",
    "Why **42**?\n",
    "- [Why '42' is the preferred number when indicating something random?](https://softwareengineering.stackexchange.com/questions/507/why-42-is-the-preferred-number-when-indicating-something-random)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "from sklearn.model_selection import KFold, cross_val_score\n",
    "\n",
    "kf = KFold(n_splits=10, random_state=42, shuffle=True)\n",
    "scores = cross_val_score(classifier, vectorised_train_documents, train_labels, cv=kf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cross-validation scores: [0.83655084 0.86743887 0.8043758 0.83011583 0.83655084 0.81724582\n",
      " 0.82754183 0.8030888 0.80694981 0.82731959]\n",
      "Cross-validation accuracy: 0.8257 (+/- 0.0368)\n"
     ]
    }
   ],
   "source": [
    "print('Cross-validation scores:', scores)\n",
    "print('Cross-validation accuracy: {:.4f} (+/- {:.4f})'.format(scores.mean(), scores.std() * 2))"
   ]
  },
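  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The effect of `shuffle` and `random_state` can be checked on a toy array. This is only a sketch: `toy_data`, `toy_test_folds` and the choice of 5 splits are assumptions for illustration, not part of the experiment above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.model_selection import KFold\n",
    "\n",
    "toy_data = np.arange(10)  # hypothetical stand-in for the training documents\n",
    "\n",
    "def toy_test_folds(seed):\n",
    "    # Return the test indices of every fold for a shuffled 5-fold split.\n",
    "    toy_kf = KFold(n_splits=5, shuffle=True, random_state=seed)\n",
    "    return [test_idx.tolist() for _, test_idx in toy_kf.split(toy_data)]\n",
    "\n",
    "# The same seed reproduces exactly the same folds.\n",
    "print(toy_test_folds(42) == toy_test_folds(42))\n",
    "# Without shuffling, the folds are simply consecutive blocks of indices.\n",
    "print([test_idx.tolist() for _, test_idx in KFold(n_splits=5).split(toy_data)])"
   ]
  },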
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, I used several metric functions provided by scikit-learn (`accuracy_score`, `precision_score`, `recall_score`, `f1_score` and `confusion_matrix`) **to evaluate** the classifier on the test data, reporting both *macro-* and *micro-averages* for precision, recall and F1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix\n",
    "\n",
    "predictions = classifier.predict(vectorised_test_documents)\n",
    "\n",
    "accuracy = accuracy_score(test_labels, predictions)\n",
    "\n",
    "macro_precision = precision_score(test_labels, predictions, average='macro')\n",
    "macro_recall = recall_score(test_labels, predictions, average='macro')\n",
    "macro_f1 = f1_score(test_labels, predictions, average='macro')\n",
    "\n",
    "micro_precision = precision_score(test_labels, predictions, average='micro')\n",
    "micro_recall = recall_score(test_labels, predictions, average='micro')\n",
    "micro_f1 = f1_score(test_labels, predictions, average='micro')\n",
    "\n",
    "# `confusion_matrix` expects one class per sample, so each multi-label row is\n",
    "# reduced to the index of its first active label (0 if none): a simplification.\n",
    "cm = confusion_matrix(test_labels.argmax(axis=1), predictions.argmax(axis=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.8099\n",
      "Precision:\n",
      "- Macro: 0.6076\n",
      "- Micro: 0.9471\n",
      "Recall:\n",
      "- Macro: 0.3708\n",
      "- Micro: 0.7981\n",
      "F1-measure:\n",
      "- Macro: 0.4410\n",
      "- Micro: 0.8662\n"
     ]
    }
   ],
   "source": [
    "print(\"Accuracy: {:.4f}\\nPrecision:\\n- Macro: {:.4f}\\n- Micro: {:.4f}\\nRecall:\\n- Macro: {:.4f}\\n- Micro: {:.4f}\\nF1-measure:\\n- Macro: {:.4f}\\n- Micro: {:.4f}\".format(accuracy, macro_precision, micro_precision, macro_recall, micro_recall, macro_f1, micro_f1))"
   ]
  },
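  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Macro and micro averages can differ a lot when some labels are rare, which may be one reason for the gap between the two recall figures above. The sketch below shows the mechanism on invented arrays; `toy_true` and `toy_pred` are assumptions for illustration, not results from the Reuters classifier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics import recall_score\n",
    "\n",
    "# Hypothetical multi-label ground truth: label 0 is common, label 1 is rare.\n",
    "toy_true = np.array([[1, 0], [1, 0], [1, 0], [1, 1]])\n",
    "# The toy predictions get the common label right and always miss the rare one.\n",
    "toy_pred = np.array([[1, 0], [1, 0], [1, 0], [1, 0]])\n",
    "\n",
    "# Macro: mean of the per-label recalls, (1.0 + 0.0) / 2 = 0.5.\n",
    "print(recall_score(toy_true, toy_pred, average='macro'))\n",
    "# Micro: pooled over all label decisions, 4 true positives out of 5 = 0.8.\n",
    "print(recall_score(toy_true, toy_pred, average='micro'))"
   ]
  },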
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the cell below, I used `matplotlib.pyplot` together with the `heatmap` function of `seaborn` to **plot the confusion matrix** (only the *first rows*, to keep the plot readable)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "(base64-encoded PNG of the confusion-matrix heatmap; image data truncated in this copy)",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x24d8cf39f28>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sb\n",
    "import pandas as pd\n",
    "\n",
    "cm_plt = pd.DataFrame(cm[:73])\n",
    "\n",
    "plt.figure(figsize=(25, 25))\n",
    "ax = plt.axes()\n",
    "\n",
    "sb.heatmap(cm_plt, annot=True)\n",
    "\n",
    "ax.xaxis.set_ticks_position('top')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pipeline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, I took a passage from [Coconut - Wikipedia](https://en.wikipedia.org/wiki/Coconut) to check whether the classifier is able to **correctly** predict the label(s).\n",
    "\n",
    "And here is the output:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Example labels: [('coconut', 'oilseed')]\n"
     ]
    }
   ],
   "source": [
    "example_text = '''The coconut tree (Cocos nucifera) is a member of the family Arecaceae (palm family) and the only species of the genus Cocos.\n",
    "The term coconut can refer to the whole coconut palm or the seed, or the fruit, which, botanically, is a drupe, not a nut.\n",
    "The spelling cocoanut is an archaic form of the word.\n",
    "The term is derived from the 16th-century Portuguese and Spanish word coco meaning \"head\" or \"skull\", from the three indentations on the coconut shell that resemble facial features.\n",
    "Coconuts are known for their versatility ranging from food to cosmetics.\n",
    "They form a regular part of the diets of many people in the tropics and subtropics.\n",
    "Coconuts are distinct from other fruits for their endosperm containing a large quantity of water (also called \"milk\"), and when immature, may be harvested for the potable coconut water.\n",
    "When mature, they can be used as seed nuts or processed for oil, charcoal from the hard shell, and coir from the fibrous husk.\n",
    "When dried, the coconut flesh is called copra.\n",
    "The oil and milk derived from it are commonly used in cooking and frying, as well as in soaps and cosmetics.\n",
    "The husks and leaves can be used as material to make a variety of products for furnishing and decorating.\n",
    "The coconut also has cultural and religious significance in certain societies, particularly in India, where it is used in Hindu rituals.'''\n",
    "\n",
    "example_preds = classifier.predict(vectorizer.transform([example_text]))\n",
    "example_labels = mlb.inverse_transform(example_preds)\n",
    "print(\"Example labels: {}\".format(example_labels))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}