This notebook contains an excerpt from the Python Data Science Handbook by Jake VanderPlas; the content is available on GitHub.
The text is released under the CC-BY-NC-ND license, and code is released under the MIT license. If you find this content useful, please consider supporting the work by buying the book!
In Depth: Naive Bayes Classification
The previous four sections have given a general overview of the concepts of machine learning. In this section and the ones that follow, we will be taking a closer look at several specific algorithms for supervised and unsupervised learning, starting here with naive Bayes classification.
Naive Bayes models are a group of extremely fast and simple classification algorithms that are often suitable for very high-dimensional datasets. Because they are so fast and have so few tunable parameters, they end up being very useful as a quick-and-dirty baseline for a classification problem. This section will focus on an intuitive explanation of how naive Bayes classifiers work, followed by a couple of examples of them in action on some datasets.
Bayesian Classification
Naive Bayes classifiers are built on Bayesian classification methods. These rely on Bayes's theorem, which is an equation describing the relationship of conditional probabilities of statistical quantities. In Bayesian classification, we're interested in finding the probability of a label given some observed features, which we can write as $P(L~|~{\rm features})$. Bayes's theorem tells us how to express this in terms of quantities we can compute more directly:
$$ P(L~|~{\rm features}) = \frac{P({\rm features}~|~L)P(L)}{P({\rm features})} $$
If we are trying to decide between two labels (let's call them $L_1$ and $L_2$), then one way to make this decision is to compute the ratio of the posterior probabilities for each label:
$$ \frac{P(L_1~|~{\rm features})}{P(L_2~|~{\rm features})} = \frac{P({\rm features}~|~L_1)}{P({\rm features}~|~L_2)}\frac{P(L_1)}{P(L_2)} $$
All we need now is some model by which we can compute $P({\rm features}~|~L_i)$ for each label. Such a model is called a generative model because it specifies the hypothetical random process that generates the data. Specifying this generative model for each label is the main piece of the training of such a Bayesian classifier. The general version of such a training step is a very difficult task, but we can make it simpler through the use of some simplifying assumptions about the form of this model.
This is where the "naive" in "naive Bayes" comes in: if we make very naive assumptions about the generative model for each label, we can find a rough approximation of the generative model for each class, and then proceed with the Bayesian classification. Different types of naive Bayes classifiers rest on different naive assumptions about the data, and we will examine a few of these in the following sections.
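As a small numerical illustration of the posterior ratio above, the following sketch plugs in made-up values for the likelihoods and priors. The numbers are hypothetical and chosen only to show the arithmetic; they do not come from any dataset in this section.

```python
# Hypothetical values, chosen only to illustrate the arithmetic.
likelihood = {"L1": 0.02, "L2": 0.005}  # P(features | L_i)
prior = {"L1": 0.3, "L2": 0.7}          # P(L_i)

# Posterior ratio: the evidence P(features) cancels, so it is never needed.
posterior_ratio = (likelihood["L1"] / likelihood["L2"]) * (prior["L1"] / prior["L2"])

# The same answer via explicitly normalized posteriors.
unnormalized = {label: likelihood[label] * prior[label] for label in prior}
evidence = sum(unnormalized.values())   # P(features)
posterior = {label: value / evidence for label, value in unnormalized.items()}

print("posterior ratio L1 : L2 =", round(posterior_ratio, 3))
print("normalized posteriors   =", {k: round(v, 3) for k, v in posterior.items()})
```

A ratio greater than one favors $L_1$, and a ratio less than one favors $L_2$. Here the likelihood favors $L_1$ strongly enough to outweigh the smaller prior.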
We begin with the standard imports:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
Gaussian Naive Bayes
Perhaps the easiest naive Bayes classifier to understand is Gaussian naive Bayes. In this classifier, the assumption is that data from each label is drawn from a simple Gaussian distribution. Imagine that you have the following data:
from sklearn.datasets import make_blobs
X, y = make_blobs(100, 2, centers=2, random_state=2, cluster_std=1.5)
plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='RdBu');
One extremely fast way to create a simple model is to assume that the data is described by a Gaussian distribution with no covariance between dimensions. This model can be fit by simply finding the mean and standard deviation of the points within each label, which is all you need to define such a distribution. The result of this naive Gaussian assumption is shown in the following figure:
The ellipses here represent the Gaussian generative model for each label, with larger probability toward the center of the ellipses. With this generative model in place for each class, we have a simple recipe to compute the likelihood $P({\rm features}~|~L_i)$ of any data point under each label, and thus we can quickly compute the posterior ratio and determine which label is the most probable for a given point.
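The recipe just described can be sketched directly in NumPy. This is a minimal illustration that assumes independent Gaussians per feature and class priors estimated from label frequencies; the helper names `fit_gaussian_nb` and `log_joint` are hypothetical and not part of any library.

```python
import numpy as np
from sklearn.datasets import make_blobs

X, y = make_blobs(100, 2, centers=2, random_state=2, cluster_std=1.5)

def fit_gaussian_nb(X, y):
    """Estimate the per-class mean, variance, and prior (hypothetical helper)."""
    classes = np.unique(y)
    means = np.array([X[y == c].mean(axis=0) for c in classes])
    variances = np.array([X[y == c].var(axis=0) for c in classes])
    priors = np.array([np.mean(y == c) for c in classes])
    return classes, means, variances, priors

def log_joint(X, means, variances, priors):
    """Return log P(features | L_i) + log P(L_i), shape (n_samples, n_classes)."""
    diff = X[:, np.newaxis, :] - means[np.newaxis, :, :]
    # Features are treated as independent, so the log-likelihoods add up.
    log_like = -0.5 * np.sum(np.log(2 * np.pi * variances)[np.newaxis, :, :]
                             + diff ** 2 / variances[np.newaxis, :, :], axis=2)
    return log_like + np.log(priors)

classes, means, variances, priors = fit_gaussian_nb(X, y)
manual_pred = classes[np.argmax(log_joint(X, means, variances, priors), axis=1)]
print("training accuracy of the manual model:", np.mean(manual_pred == y))
```

Working with log probabilities avoids numerical underflow when many per-feature likelihoods are multiplied together.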
This procedure is implemented in Scikit-Learn's sklearn.naive_bayes.GaussianNB estimator:
from sklearn.naive_bayes import GaussianNB
model = GaussianNB()
model.fit(X, y);
Now let's generate some new data and predict the label:
rng = np.random.RandomState(0)
Xnew = [-6, -14] + [14, 18] * rng.rand(2000, 2)
ynew = model.predict(Xnew)
Now we can plot this new data to get an idea of where the decision boundary is:
plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='RdBu')
lim = plt.axis()
plt.scatter(Xnew[:, 0], Xnew[:, 1], c=ynew, s=20, cmap='RdBu', alpha=0.1)
plt.axis(lim);
We see a slightly curved boundary in the classifications—in general, the boundary in Gaussian naive Bayes is quadratic.
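To see where the quadratic form comes from, it may help to write out the log of the posterior ratio under the independent-Gaussian assumption. The symbols $\mu_{ij}$ and $\sigma_{ij}$ are notation introduced here for the mean and standard deviation of feature $x_j$ within label $L_i$; they are not used elsewhere in this section.

$$ \ln\frac{P(L_1~|~x)}{P(L_2~|~x)} = \ln\frac{P(L_1)}{P(L_2)} + \sum_j \left[ \ln\frac{\sigma_{2j}}{\sigma_{1j}} - \frac{(x_j - \mu_{1j})^2}{2\sigma_{1j}^2} + \frac{(x_j - \mu_{2j})^2}{2\sigma_{2j}^2} \right] $$

The decision boundary is the set of points where this expression equals zero, which is a quadratic equation in the features $x_j$. In the special case $\sigma_{1j} = \sigma_{2j}$ for every $j$, the squared terms cancel and the boundary becomes linear.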
A nice piece of this Bayesian formalism is that it naturally allows for probabilistic classification, which we can compute using the predict_proba method:
yprob = model.predict_proba(Xnew)
yprob[-8:].round(2)
array([[0.89, 0.11],
       [1.  , 0.  ],
       [1.  , 0.  ],
       [1.  , 0.  ],
       [1.  , 0.  ],
       [1.  , 0.  ],
       [0.  , 1.  ],
       [0.15, 0.85]])
The columns give the posterior probabilities of the first and second label, respectively. If you are looking for estimates of uncertainty in your classification, Bayesian approaches like this can be useful.
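As a rough sketch of how these probabilities relate to the hard labels, the following self-contained example rebuilds the same model and checks that each row sums to one and that the most probable column corresponds to the predicted label. The 0.9 threshold used to flag uncertain points is an arbitrary choice made for illustration.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.naive_bayes import GaussianNB

X, y = make_blobs(100, 2, centers=2, random_state=2, cluster_std=1.5)
model = GaussianNB().fit(X, y)

rng = np.random.RandomState(0)
Xnew = [-6, -14] + [14, 18] * rng.rand(2000, 2)
yprob = model.predict_proba(Xnew)

# Each row is a probability distribution over the labels.
row_sums = yprob.sum(axis=1)

# Taking the most probable column recovers the hard prediction.
hard_labels = model.classes_[np.argmax(yprob, axis=1)]

# Flag points where neither label reaches 0.9 (arbitrary threshold).
uncertain = np.max(yprob, axis=1) < 0.9
print("fraction of uncertain points:", uncertain.mean())
```

The uncertain points should cluster near the decision boundary, where the two generative models assign comparable probability.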
Of course, the final classification will only be as good as the model assumptions that lead to it, which is why Gaussian naive Bayes often does not produce very good results. Still, in many cases—especially as the number of features becomes large—this assumption is not detrimental enough to prevent Gaussian naive Bayes from being a useful method.
Multinomial Naive Bayes
The Gaussian assumption just described is by no means the only simple assumption that could be used to specify the generative distribution for each label. Another useful example is multinomial naive Bayes, where the features are assumed to be generated from a simple multinomial distribution. The multinomial distribution describes the probability of observing counts among a number of categories, and thus multinomial naive Bayes is most appropriate for features that represent counts or count rates.
The idea is precisely the same as before, except that instead of modeling the data distribution with the best-fit Gaussian, we model the data distribution with a best-fit multinomial distribution.
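To give a feel for what "best-fit multinomial" means in practice, here is a minimal sketch on a tiny, made-up count matrix. It compares the per-class feature probabilities stored by Scikit-Learn's MultinomialNB with a manual estimate that uses additive (Laplace) smoothing with alpha=1. The data are hypothetical.

```python
import numpy as np
from sklearn.naive_bayes import MultinomialNB

# Hypothetical counts: rows are documents, columns are vocabulary words.
X_counts = np.array([[3, 0, 1],
                     [2, 1, 0],
                     [0, 4, 1],
                     [1, 3, 2]])
y_counts = np.array([0, 0, 1, 1])

alpha = 1.0
clf = MultinomialNB(alpha=alpha).fit(X_counts, y_counts)

# Manual estimate: smoothed share of each word among all words in the class.
manual = []
for c in clf.classes_:
    totals = X_counts[y_counts == c].sum(axis=0)
    manual.append((totals + alpha) / (totals.sum() + alpha * X_counts.shape[1]))
manual = np.array(manual)

print("sklearn:", np.exp(clf.feature_log_prob_).round(3))
print("manual: ", manual.round(3))
```

The smoothing term keeps a word that never appears in a class from receiving zero probability, which would otherwise veto that class for any document containing the word.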
Example: Classifying Text
One place where multinomial naive Bayes is often used is in text classification, where the features are related to word counts or frequencies within the documents to be classified. We discussed the extraction of such features from text in Feature Engineering; here we will use the sparse word count features from the 20 Newsgroups corpus to show how we might classify these short documents into categories.
Let's download the data and take a look at the target names:
from sklearn.datasets import fetch_20newsgroups
data = fetch_20newsgroups()
data.target_names
For simplicity here, we will select just a few of these categories, and download the training and testing sets:
categories = ['talk.religion.misc', 'soc.religion.christian',
              'sci.space', 'comp.graphics']
train = fetch_20newsgroups(subset='train', categories=categories)
test = fetch_20newsgroups(subset='test', categories=categories)
Here is a representative entry from the data:
print(train.data[5])
In order to use this data for machine learning, we need to be able to convert the content of each string into a vector of numbers. For this we will use the TF-IDF vectorizer (discussed in Feature Engineering), and create a pipeline that attaches it to a multinomial naive Bayes classifier:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import make_pipeline
model = make_pipeline(TfidfVectorizer(), MultinomialNB())
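For a quick look at what such a pipeline does without downloading anything, the sketch below builds a separate toy pipeline on a tiny in-memory corpus. The documents, labels, and names (`toy_docs`, `toy_names`, `toy_model`) are hypothetical and exist only for this illustration.

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import make_pipeline

# Hypothetical miniature corpus with two categories.
toy_docs = ["the rocket launched into orbit",
            "the satellite reached orbit",
            "the gpu renders the image",
            "the image uses many pixels"]
toy_labels = [0, 0, 1, 1]
toy_names = ["space", "graphics"]

# The vectorizer alone maps each string to a sparse row of TF-IDF weights.
X_tfidf = TfidfVectorizer().fit_transform(toy_docs)
print("TF-IDF matrix shape (documents, vocabulary):", X_tfidf.shape)

# The pipeline applies the vectorizer and then the classifier in one step.
toy_model = make_pipeline(TfidfVectorizer(), MultinomialNB())
toy_model.fit(toy_docs, toy_labels)

queries = ["satellite in orbit", "pixels of an image"]
toy_pred = [toy_names[i] for i in toy_model.predict(queries)]
print(dict(zip(queries, toy_pred)))
```

Because the vectorizer is part of the pipeline, raw strings can be passed to both fit and predict, and the vocabulary learned during fitting is reused at prediction time.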
With this pipeline, we can apply the model to the training data, and predict labels for the test data:
model.fit(train.data, train.target)
labels = model.predict(test.data)
Now that we have predicted the labels for the test data, we can evaluate them to learn about the performance of the estimator. For example, here is the confusion matrix between the true and predicted labels for the test data:
from sklearn.metrics import confusion_matrix
mat = confusion_matrix(test.target, labels)
sns.heatmap(mat.T, square=True, annot=True, fmt='d', cbar=False,
            xticklabels=train.target_names, yticklabels=train.target_names)
plt.xlabel('true label')
plt.ylabel('predicted label');
Evidently, even this very simple classifier can successfully separate space talk from computer talk, but it gets confused between talk about religion and talk about Christianity. This is perhaps an expected area of confusion!
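As an aid to reading a confusion matrix, here is a small sketch on made-up labels. In the array returned by confusion_matrix, rows correspond to true labels and columns to predicted labels, which is why the plot above transposes it to put true labels on the horizontal axis. The label arrays are hypothetical.

```python
import numpy as np
from sklearn.metrics import confusion_matrix, accuracy_score

# Hypothetical true and predicted labels for three classes.
y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 0])

# Entry [i, j] counts samples with true label i that were predicted as j.
mat = confusion_matrix(y_true, y_pred)
print(mat)

# The diagonal holds the correct predictions.
accuracy = np.trace(mat) / mat.sum()
per_class_recall = np.diag(mat) / mat.sum(axis=1)
print("accuracy:", accuracy)
print("recall per true class:", per_class_recall)
```

Large off-diagonal entries identify the specific pairs of classes that the model tends to confuse.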
The very cool thing here is that we now have the tools to determine the category for any string, using the predict() method of this pipeline. Here's a quick utility function that will return the prediction for a single string:
def predict_category(s, train=train, model=model):
    pred = model.predict([s])
    return train.target_names[pred[0]]
Let's try it out:
predict_category('sending a payload to the ISS')
predict_category('discussing islam vs atheism')
predict_category('determining the screen resolution')
Remember that this is nothing more sophisticated than a simple probability model for the (weighted) frequency of each word in the string; nevertheless, the result is striking. Even a very naive algorithm, when used carefully and trained on a large set of high-dimensional data, can be surprisingly effective.
When to Use Naive Bayes
Because naive Bayesian classifiers make such stringent assumptions about data, they will generally not perform as well as a more complicated model. That said, they have several advantages:
- They are extremely fast for both training and prediction
- They provide straightforward probabilistic prediction
- They are often very easily interpretable
- They have very few (if any) tunable parameters
These advantages mean a naive Bayesian classifier is often a good choice as an initial baseline classification. If it performs suitably, then congratulations: you have a very fast, very interpretable classifier for your problem. If it does not perform well, then you can begin exploring more sophisticated models, with some baseline knowledge of how well they should perform.
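One way to obtain such a baseline number is with cross-validation. The sketch below uses the iris dataset bundled with Scikit-Learn purely as a stand-in, since it requires no download. The choice of dataset and of five folds is an assumption made for illustration.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.naive_bayes import GaussianNB

# Stand-in dataset; substitute your own features and labels.
X_iris, y_iris = load_iris(return_X_y=True)

# Five-fold cross-validated accuracy of the naive Bayes baseline.
baseline_scores = cross_val_score(GaussianNB(), X_iris, y_iris, cv=5)
print("baseline accuracy per fold:", baseline_scores.round(3))
print("mean baseline accuracy:", baseline_scores.mean().round(3))
```

Any more sophisticated model can then be evaluated with the same call and compared against this mean score.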
Naive Bayes classifiers tend to perform especially well in one of the following situations:
- When the naive assumptions actually match the data (very rare in practice)
- For very well-separated categories, when model complexity is less important
- For very high-dimensional data, when model complexity is less important
The last two points seem distinct, but they actually are related: as the dimension of a dataset grows, it is much less likely for any two points to be found close together (after all, they must be close in every single dimension to be close overall). This means that clusters in high dimensions tend to be more separated, on average, than clusters in low dimensions, assuming the new dimensions actually add information. For this reason, simplistic classifiers like naive Bayes tend to work as well as or better than more complicated classifiers as the dimensionality grows: once you have enough data, even a simple model can be very powerful.
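The effect of dimensionality can be illustrated with a small synthetic experiment. This is a sketch under strong assumptions: two Gaussian classes with unit variance whose means differ by the same small offset in every feature, so that each added dimension carries a little information. The helper `nb_accuracy` and all of its parameters are hypothetical.

```python
import numpy as np
from sklearn.naive_bayes import GaussianNB

def nb_accuracy(n_features, n_per_class=1000, offset=0.5, seed=0):
    """Test accuracy of GaussianNB on two synthetic classes (hypothetical helper)."""
    rng = np.random.RandomState(seed)

    def sample(n):
        # Class 0 is centered at the origin, class 1 is shifted in every feature.
        X0 = rng.randn(n, n_features)
        X1 = rng.randn(n, n_features) + offset
        return np.vstack([X0, X1]), np.repeat([0, 1], n)

    X_train, y_train = sample(n_per_class)
    X_test, y_test = sample(n_per_class)
    return GaussianNB().fit(X_train, y_train).score(X_test, y_test)

accuracies = {d: nb_accuracy(d) for d in (1, 10, 100)}
for d, acc in accuracies.items():
    print(f"{d:>3} features: accuracy = {acc:.3f}")
```

In one dimension the classes overlap heavily, but as informative dimensions accumulate the clusters move apart and the same simple model separates them much more reliably.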