.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/bicluster/plot_bicluster_newsgroups.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_bicluster_plot_bicluster_newsgroups.py>`
        to download the full example code or to run this example in your browser via JupyterLite or Binder

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_bicluster_plot_bicluster_newsgroups.py:


================================================================
Biclustering documents with the Spectral Co-clustering algorithm
================================================================

This example demonstrates the Spectral Co-clustering algorithm on the
twenty newsgroups dataset. The 'comp.os.ms-windows.misc' category is
excluded because it contains many posts consisting of nothing but data.

The posts are vectorized with TF-IDF to form a sparse document-word
matrix, which is then biclustered using Dhillon's Spectral
Co-Clustering algorithm. The resulting document-word biclusters
indicate subsets of words that are used more often in those subsets of
documents.
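
As a minimal sketch of what the algorithm returns (a toy block matrix,
not the newsgroups data), ``SpectralCoclustering`` assigns every row
and every column to exactly one bicluster:

.. code-block:: python

    import numpy as np

    from sklearn.cluster import SpectralCoclustering

    # A matrix with two obvious row/column blocks on the diagonal.
    X_toy = np.array(
        [
            [5, 5, 0, 0],
            [4, 6, 0, 0],
            [0, 0, 7, 3],
            [0, 0, 2, 8],
        ]
    )
    model = SpectralCoclustering(n_clusters=2, random_state=0).fit(X_toy)
    print(model.row_labels_)  # e.g. [0 0 1 1]; cluster ids may be permuted
    print(model.column_labels_)  # e.g. [0 0 1 1]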

For a few of the best biclusters, their most common document
categories and their ten most important words are printed. The best
biclusters are determined by their normalized cut. The best words are
determined by comparing their sums inside and outside the bicluster.
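
In symbols: for a bicluster with document set :math:`R` and word set
:math:`C` of the data matrix :math:`X`, the script ranks biclusters by
the ratio computed in ``bicluster_ncut`` below (cut weight over
within-bicluster weight; smaller is better):

.. math::

   \mathrm{ncut}(R, C) = \frac{\sum_{i \notin R,\, j \in C} X_{ij}
   + \sum_{i \in R,\, j \notin C} X_{ij}}{\sum_{i \in R,\, j \in C} X_{ij}}

and scores each word :math:`j \in C` by
:math:`\sum_{i \in R} X_{ij} - \sum_{i \notin R} X_{ij}`.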

For comparison, the documents are also clustered using
MiniBatchKMeans. The document clusters derived from the biclusters
achieve a better V-measure than clusters found by MiniBatchKMeans.
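
For reference, ``v_measure_score`` is the harmonic mean of homogeneity
:math:`h` and completeness :math:`c` (with its default :math:`\beta = 1`):

.. math::

   v = \frac{(1 + \beta)\, h\, c}{\beta\, h + c}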

.. GENERATED FROM PYTHON SOURCE LINES 25-173




.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Vectorizing...
    Coclustering...
    Done in 1.40s. V-measure: 0.4415
    MiniBatchKMeans...
    Done in 2.31s. V-measure: 0.3015

    Best biclusters:
    ----------------
    bicluster 0 : 8 documents, 6 words
    categories   : 100% talk.politics.mideast
    words        : cosmo, angmar, alfalfa, alphalpha, proline, benson

    bicluster 1 : 1948 documents, 4325 words
    categories   : 23% talk.politics.guns, 18% talk.politics.misc, 17% sci.med
    words        : gun, guns, geb, banks, gordon, clinton, pitt, cdt, surrender, veal

    bicluster 2 : 1259 documents, 3534 words
    categories   : 27% soc.religion.christian, 25% talk.politics.mideast, 25% alt.atheism
    words        : god, jesus, christians, kent, sin, objective, belief, christ, faith, moral

    bicluster 3 : 775 documents, 1623 words
    categories   : 30% comp.windows.x, 25% comp.sys.ibm.pc.hardware, 20% comp.graphics
    words        : scsi, nada, ide, vga, esdi, isa, kth, s3, vlb, bmug

    bicluster 4 : 2180 documents, 2802 words
    categories   : 18% comp.sys.mac.hardware, 16% sci.electronics, 16% comp.sys.ibm.pc.hardware
    words        : voltage, shipping, circuit, receiver, processing, scope, mpce, analog, kolstad, umass







|

.. code-block:: default


    import operator
    import sys
    from collections import defaultdict
    from time import time

    import numpy as np

    from sklearn.cluster import MiniBatchKMeans, SpectralCoclustering
    from sklearn.datasets import fetch_20newsgroups
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.metrics.cluster import v_measure_score


    def number_normalizer(tokens):
        """Map all numeric tokens to a placeholder.

        For many applications, tokens that begin with a number are not directly
        useful, but the fact that such a token exists can be relevant.  By applying
        this form of dimensionality reduction, some methods may perform better.
        """
        return ("#NUMBER" if token[0].isdigit() else token for token in tokens)


    class NumberNormalizingVectorizer(TfidfVectorizer):
        def build_tokenizer(self):
            tokenize = super().build_tokenizer()
            return lambda doc: list(number_normalizer(tokenize(doc)))
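
    # Illustrative check (not part of the original pipeline): with the default
    # token pattern, any token that starts with a digit collapses to "#NUMBER":
    #   NumberNormalizingVectorizer().build_tokenizer()("the 486dx and pentium chips")
    #   -> ['the', '#NUMBER', 'and', 'pentium', 'chips']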


    # exclude 'comp.os.ms-windows.misc'
    categories = [
        "alt.atheism",
        "comp.graphics",
        "comp.sys.ibm.pc.hardware",
        "comp.sys.mac.hardware",
        "comp.windows.x",
        "misc.forsale",
        "rec.autos",
        "rec.motorcycles",
        "rec.sport.baseball",
        "rec.sport.hockey",
        "sci.crypt",
        "sci.electronics",
        "sci.med",
        "sci.space",
        "soc.religion.christian",
        "talk.politics.guns",
        "talk.politics.mideast",
        "talk.politics.misc",
        "talk.religion.misc",
    ]
    newsgroups = fetch_20newsgroups(categories=categories)
    y_true = newsgroups.target

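    # min_df=5 drops words that appear in fewer than five posts; the built-in
    # English stop word list removes very common words.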
    vectorizer = NumberNormalizingVectorizer(stop_words="english", min_df=5)
    cocluster = SpectralCoclustering(
        n_clusters=len(categories), svd_method="arpack", random_state=0
    )
    kmeans = MiniBatchKMeans(
        n_clusters=len(categories), batch_size=20000, random_state=0, n_init=3
    )
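
    # Both clusterers use one cluster per category, so the V-measures reported
    # below are directly comparable.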

    print("Vectorizing...")
    X = vectorizer.fit_transform(newsgroups.data)

    print("Coclustering...")
    start_time = time()
    cocluster.fit(X)
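    # Each document belongs to exactly one bicluster, so the row labels give a
    # flat clustering of the documents.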
    y_cocluster = cocluster.row_labels_
    print(
        "Done in {:.2f}s. V-measure: {:.4f}".format(
            time() - start_time, v_measure_score(y_cocluster, y_true)
        )
    )

    print("MiniBatchKMeans...")
    start_time = time()
    y_kmeans = kmeans.fit_predict(X)
    print(
        "Done in {:.2f}s. V-measure: {:.4f}".format(
            time() - start_time, v_measure_score(y_kmeans, y_true)
        )
    )

    feature_names = vectorizer.get_feature_names_out()
    document_names = [newsgroups.target_names[i] for i in newsgroups.target]


    def bicluster_ncut(i):
        rows, cols = cocluster.get_indices(i)
        if len(rows) == 0 or len(cols) == 0:
            # An empty bicluster cannot be scored; rank it last.
            return sys.float_info.max
        row_complement = np.nonzero(np.logical_not(cocluster.rows_[i]))[0]
        col_complement = np.nonzero(np.logical_not(cocluster.columns_[i]))[0]
        # Note: the following is identical to X[rows[:, np.newaxis],
        # cols].sum() but much faster in scipy <= 0.16
        weight = X[rows][:, cols].sum()
        cut = X[row_complement][:, cols].sum() + X[rows][:, col_complement].sum()
        return cut / weight


    def most_common(d):
        """Items of a defaultdict(int) with the highest values.

        Equivalent to Counter.most_common from the standard library.
        """
        return sorted(d.items(), key=operator.itemgetter(1), reverse=True)


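    # Rank all biclusters by normalized cut; smaller values indicate more
    # coherent biclusters, so keep the five smallest.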
    bicluster_ncuts = [bicluster_ncut(i) for i in range(cocluster.n_clusters)]
    best_idx = np.argsort(bicluster_ncuts)[:5]

    print()
    print("Best biclusters:")
    print("----------------")
    for idx, cluster in enumerate(best_idx):
        n_rows, n_cols = cocluster.get_shape(cluster)
        cluster_docs, cluster_words = cocluster.get_indices(cluster)
        if not len(cluster_docs) or not len(cluster_words):
            continue

        # categories
        counter = defaultdict(int)
        for i in cluster_docs:
            counter[document_names[i]] += 1
        cat_string = ", ".join(
            "{:.0f}% {}".format(float(c) / n_rows * 100, name)
            for name, c in most_common(counter)[:3]
        )

        # words
        out_of_cluster_docs = cocluster.row_labels_ != cluster
        out_of_cluster_docs = np.where(out_of_cluster_docs)[0]
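        # Score each word by its total tf-idf weight inside the bicluster minus
        # its weight in all other documents (see the formula above).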
        word_col = X[:, cluster_words]
        word_scores = np.array(
            word_col[cluster_docs, :].sum(axis=0)
            - word_col[out_of_cluster_docs, :].sum(axis=0)
        )
        word_scores = word_scores.ravel()
        important_words = [
            feature_names[cluster_words[i]] for i in word_scores.argsort()[:-11:-1]
        ]

        print("bicluster {} : {} documents, {} words".format(idx, n_rows, n_cols))
        print("categories   : {}".format(cat_string))
        print("words        : {}\n".format(", ".join(important_words)))


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 6.314 seconds)


.. _sphx_glr_download_auto_examples_bicluster_plot_bicluster_newsgroups.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example


    .. container:: binder-badge

      .. image:: images/binder_badge_logo.svg
        :target: https://mybinder.org/v2/gh/scikit-learn/scikit-learn/1.3.X?urlpath=lab/tree/notebooks/auto_examples/bicluster/plot_bicluster_newsgroups.ipynb
        :alt: Launch binder
        :width: 150 px



    .. container:: lite-badge

      .. image:: images/jupyterlite_badge_logo.svg
        :target: ../../lite/lab/?path=auto_examples/bicluster/plot_bicluster_newsgroups.ipynb
        :alt: Launch JupyterLite
        :width: 150 px

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_bicluster_newsgroups.py <plot_bicluster_newsgroups.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_bicluster_newsgroups.ipynb <plot_bicluster_newsgroups.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_