Overview
In this notebook, you'll learn to use the embeddings produced by the PaLM API to train a classifier that predicts which topic (class) a newsgroup post belongs to.
Setup
First, download and install the PaLM API Python library.
pip install -q google-generativeai
import google.generativeai as palm
import re
import tqdm
import keras
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from keras import layers
from matplotlib.ticker import MaxNLocator
from sklearn.datasets import fetch_20newsgroups
import sklearn.metrics as skmetrics
Get an API Key
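To get started, you'll need to create an API key. Pass the key to palm.configure, then select a model that supports text embeddings (the embedText generation method).
palm.configure(api_key='YOUR_API_KEY')
# Keep only the models that can generate text embeddings, and use the first one.
models = [m for m in palm.list_models() if 'embedText' in m.supported_generation_methods]
model = models[0]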
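Hard-coding the key in the notebook makes it easy to leak when you share the file. A minimal sketch, assuming you store the key in an environment variable named PALM_API_KEY (a hypothetical name; any name works), reads it at runtime and fails with a clear message if it is missing. The helpers get_api_key and first_embedding_model are illustrative, not part of the google.generativeai library; first_embedding_model only relies on the supported_generation_methods attribute used above, and the model records in the usage example are made-up stand-ins.

```python
import os
from types import SimpleNamespace
from typing import Iterable


def get_api_key(var_name: str = "PALM_API_KEY") -> str:
    """Read the API key from an environment variable (hypothetical variable name)."""
    key = os.environ.get(var_name, "").strip()
    if not key:
        raise RuntimeError(f"Set the {var_name} environment variable to your API key.")
    return key


def first_embedding_model(models: Iterable):
    """Return the first model whose supported_generation_methods includes 'embedText'."""
    for m in models:
        if "embedText" in m.supported_generation_methods:
            return m
    raise LookupError("No available model supports 'embedText'.")


# Demonstration only: set a fake key and use stand-in model records.
os.environ["PALM_API_KEY"] = "example-key"
fake_models = [
    SimpleNamespace(name="models/chat-model", supported_generation_methods=["generateMessage"]),
    SimpleNamespace(name="models/embed-model", supported_generation_methods=["embedText"]),
]
api_key = get_api_key()
print(first_embedding_model(fake_models).name)  # models/embed-model
```

In the notebook itself you would call palm.configure(api_key=get_api_key()) and model = first_embedding_model(palm.list_models()); raising LookupError gives a clearer message than the IndexError that models[0] produces when the list is empty.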
Dataset
The 20 Newsgroups Text Dataset contains around 18,000 newsgroup posts on 20 topics, divided into training and test sets. The split between the training and test sets is based on whether a message was posted before or after a specific date. For this tutorial, you will use subsets of the training and test sets. You will preprocess the data and organize it into Pandas DataFrames.
newsgroups_train = fetch_20newsgroups(subset='train')
newsgroups_test = fetch_20newsgroups(subset='test')
# View list of class names for dataset
newsgroups_train.target_names
['alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc', 'comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware', 'comp.windows.x', 'misc.forsale', 'rec.autos', 'rec.motorcycles', 'rec.sport.baseball', 'rec.sport.hockey', 'sci.crypt', 'sci.electronics', 'sci.med', 'sci.space', 'soc.religion.christian', 'talk.politics.guns', 'talk.politics.mideast', 'talk.politics.misc', 'talk.religion.misc']
Here is an example of what a data point from the training set looks like.
idx = newsgroups_train.data[0].index('Lines')
print(newsgroups_train.data[0][idx:])
Lines: 15 I was wondering if anyone out there could enlighten me on this car I saw the other day. It was a 2-door sports car, looked to be from the late 60s/ early 70s. It was called a Bricklin. The doors were really small. In addition, the front bumper was separate from the rest of the body. This is all I know. If anyone can tellme a model name, engine specs, years of production, where this car is made, history, or whatever info you have on this funky looking car, please e-mail. Thanks, - IL ---- brought to you by your neighborhood Lerxst ----
Next, preprocess the data. Remove sensitive information such as names and email addresses, as well as redundant parts of the text such as "From: " and "\nSubject: ". Names are removed by stripping any text in parentheses, since post headers often list the poster's name that way. Then organize the information into a Pandas DataFrame so it is more readable.
def preprocess_newsgroup_data(newsgroup_dataset):
    # Apply functions to remove names, emails, and extraneous words from data points in newsgroups.data
    newsgroup_dataset.data = [re.sub(r'[\w\.-]+@[\w\.-]+', '', d) for d in newsgroup_dataset.data] # Remove email
    newsgroup_dataset.data = [re.sub(r"\([^()]*\)", "", d) for d in newsgroup_dataset.data] # Remove names (any text in parentheses)
    newsgroup_dataset.data = [d.replace("From: ", "") for d in newsgroup_dataset.data] # Remove "From: "
    newsgroup_dataset.data = [d.replace("\nSubject: ", "") for d in newsgroup_dataset.data] # Remove "\nSubject: "
    # Put data points into dataframe
    df_processed = pd.DataFrame(newsgroup_dataset.data, columns=['Text'])
    df_processed['Label'] = newsgroup_dataset.target
    # Match label to target name index
    df_processed['Class Name'] = ''
    for idx, row in df_processed.iterrows():
        df_processed.at[idx, 'Class Name'] = newsgroup_dataset.target_names[row['Label']]
    return df_processed
# Apply preprocessing function to training and test datasets
df_train = preprocess_newsgroup_data(newsgroups_train)
df_test = preprocess_newsgroup_data(newsgroups_test)
df_train.head()
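To see what the four cleaning steps in preprocess_newsgroup_data do, the sketch below applies the same regular expressions and replacements to a single made-up post (the address and name are invented, not taken from the dataset). The helper name clean_post is hypothetical. It also shows a side effect worth knowing: the parenthesis rule removes any parenthesized text, not only names.

```python
import re


def clean_post(text: str) -> str:
    """Apply the same four cleaning steps as preprocess_newsgroup_data to one post."""
    text = re.sub(r'[\w\.-]+@[\w\.-]+', '', text)  # remove email addresses
    text = re.sub(r"\([^()]*\)", "", text)         # remove parenthesized text (often the poster's name)
    text = text.replace("From: ", "")               # remove the "From: " header label
    text = text.replace("\nSubject: ", "")          # remove the "\nSubject: " header label
    return text


# Hypothetical post header.
sample = "From: jdoe@example.com (Jane Doe)\nSubject: Mystery sports car\nLines: 3\n"
print(repr(clean_post(sample)))  # ' Mystery sports car\nLines: 3\n'

# Parenthesized body text is removed as well, not just names.
print(repr(clean_post("It looked to be from the late 60s (maybe 1974).")))
```

The subject line ends up fused to the start of the text with a leading space, which is harmless for embedding but explains the slightly odd spacing you may see in df_train.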
Next, sample the data to keep this tutorial fast: take 100 posts from each category in the training set and 25 from each category in the test set, then keep only the science categories (those whose names contain "sci") for comparison.
def sample_data(df, num_samples, classes_to_keep):
    # Sample num_samples rows from each label
    df = df.groupby('Label', as_index = False).apply(lambda x: x.sample(num_samples)).reset_index(drop=True)
    # Keep only classes whose name contains classes_to_keep; copy to avoid pandas' SettingWithCopyWarning
    df = df[df['Class Name'].str.contains(classes_to_keep)].copy()
    # Reset the encoding of the labels after sampling and dropping certain categories
    df['Class Name'] = df['Class Name'].astype('category')
    df['Encoded Label'] = df['Class Name'].cat.codes
    return df
TRAIN_NUM_SAMPLES = 100
TEST_NUM_SAMPLES = 25
CLASSES_TO_KEEP = 'sci' # Class name should contain 'sci' in it to keep science categories
df_train = sample_data(df_train, TRAIN_NUM_SAMPLES, CLASSES_TO_KEEP)
df_test = sample_data(df_test, TEST_NUM_SAMPLES, CLASSES_TO_KEEP)
df_train.value_counts('Class Name')
Class Name
sci.crypt          100
sci.electronics    100
sci.med            100
sci.space          100
Name: count, dtype: int64
df_test.value_counts('Class Name')
Class Name
sci.crypt          25
sci.electronics    25
sci.med            25
sci.space          25
Name: count, dtype: int64
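The re-encoding step in sample_data matters: in the full dataset the four science groups have labels 11 to 14, while a classifier with four outputs expects labels 0 to 3. The plain-Python sketch below mirrors the idea on made-up rows: sample a fixed number of posts per class, keep the classes whose name contains a substring, and renumber the surviving class names in sorted order, which is the order pandas uses for categories inferred from string values. The function sample_and_reencode is hypothetical, and unlike sample_data it filters before sampling, so a dropped class never needs enough rows to sample from.

```python
import random
from collections import defaultdict


def sample_and_reencode(rows, num_samples, keep_substring, seed=0):
    """rows: (text, original_label, class_name) tuples. Returns (text, class_name, encoded_label)."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for row in rows:
        by_class[row[2]].append(row)
    kept = []
    for class_name in sorted(by_class):
        if keep_substring in class_name:
            # Sample without replacement, like DataFrame.sample(n) with its defaults.
            kept.extend(rng.sample(by_class[class_name], num_samples))
    # Renumber the surviving classes 0..k-1 in sorted order of class name.
    codes = {name: i for i, name in enumerate(sorted({r[2] for r in kept}))}
    return [(text, name, codes[name]) for text, _, name in kept]


# Made-up rows that reuse the dataset's label ids for three groups.
rows = [(f"post {i}", label, name)
        for label, name in [(7, "rec.autos"), (11, "sci.crypt"), (14, "sci.space")]
        for i in range(5)]
result = sample_and_reencode(rows, num_samples=2, keep_substring="sci")
print(sorted({(name, code) for _, name, code in result}))
# [('sci.crypt', 0), ('sci.space', 1)]
```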
Create the embeddings
Next, compute the text embeddings using the PaLM API. For a basic understanding of how embeddings are generated, it's recommended to go through the embeddings quickstart notebook first. The embedding function below is wrapped in the retry.Retry decorator from google.api_core, so calls that fail with transient errors are retried for up to 300 seconds.
from tqdm.auto import tqdm
tqdm.pandas()
from google.api_core import retry
def make_embed_text_fn(model):
    @retry.Retry(timeout=300.0)
    def embed_fn(text: str) -> list[float]:
        return palm.generate_embeddings(model=model, text=text)['embedding']
    return embed_fn
def create_embeddings(model, df):
    df['Embeddings'] = df['Text'].progress_apply(make_embed_text_fn(model))
    return df
df_train = create_embeddings(model, df_train)
df_test = create_embeddings(model, df_test)
df_train.head()
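The retry.Retry(timeout=300.0) decorator retries calls that fail with transient errors, waiting longer between attempts, until the overall timeout runs out. The sketch below is a simplified stand-in for that behavior, not google.api_core's implementation: it has no jitter, TransientError, flaky_embed, and retry_with_backoff are hypothetical names, and the initial delay, multiplier, and maximum delay (1 s, 2x, 60 s) are assumed values for illustration. The sleep and clock functions are injectable so the example runs instantly.

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")


class TransientError(Exception):
    """Hypothetical stand-in for a retryable API error such as rate limiting."""


def retry_with_backoff(fn: Callable[[], T], timeout: float = 300.0,
                       initial: float = 1.0, multiplier: float = 2.0, maximum: float = 60.0,
                       sleep: Callable[[float], None] = time.sleep,
                       clock: Callable[[], float] = time.monotonic) -> T:
    """Call fn, retrying TransientError with exponential backoff until the timeout would be exceeded."""
    deadline = clock() + timeout
    delay = initial
    while True:
        try:
            return fn()
        except TransientError:
            # Give up (re-raise) if waiting again would pass the overall deadline.
            if clock() + delay > deadline:
                raise
            sleep(delay)
            delay = min(delay * multiplier, maximum)


# Usage: a hypothetical embedding call that fails twice before succeeding.
attempts = {"count": 0}


def flaky_embed() -> list:
    attempts["count"] += 1
    if attempts["count"] < 3:
        raise TransientError("rate limited")
    return [0.1, 0.2, 0.3]


waits = []  # record the delays instead of actually sleeping
result = retry_with_backoff(flaky_embed, sleep=waits.append)
print(result, waits)  # [0.1, 0.2, 0.3] [1.0, 2.0]
```

Only errors that are likely to succeed on a later attempt should be retried; a bad API key or malformed request would fail the same way every time.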
Build a simple classification model
Here you will define a simple model with one hidden layer and an output layer that produces one probability per class. Each output corresponds to the probability that a piece of text belongs to that newsgroup class. You don't need to shuffle the data yourself: Model.fit shuffles the training data before each epoch by default.
def build_classification_model(input_size: int, num_classes: int) -> keras.Model:
    inputs = x = keras.Input(shape=(input_size,))
    x = layers.Dense(input_size, activation='relu')(x)
    # Softmax turns the final scores into a probability distribution over the classes.
    x = layers.Dense(num_classes, activation='softmax')(x)
    return keras.Model(inputs=[inputs], outputs=x)
# Derive the embedding size from the first training element.
embedding_size = len(df_train['Embeddings'].iloc[0])
# Give your model a different name, as you have already used the variable name 'model'
classifier = build_classification_model(embedding_size, len(df_train['Class Name'].unique()))
classifier.summary()
# The model already outputs probabilities, so the loss must not treat them as logits.
classifier.compile(loss = keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                   optimizer = keras.optimizers.Adam(learning_rate=0.001),
                   metrics=['accuracy'])
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 input_1 (InputLayer)        [(None, 768)]             0
 dense (Dense)               (None, 768)               590592
 dense_1 (Dense)             (None, 4)                 3076
=================================================================
Total params: 593668 (2.26 MB)
Trainable params: 593668 (2.26 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
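The parameter counts in this summary follow from the Dense layer formula: one weight per input-unit pair plus one bias per unit. With 768-dimensional embeddings and four classes, and assuming the default float32 weights (4 bytes each), the arithmetic reproduces both the totals and the reported memory size, which Keras expresses in units of 1024 x 1024 bytes:

```latex
\begin{aligned}
\text{dense}    &: 768 \times 768 + 768 = 590{,}592 \\
\text{dense\_1} &: 768 \times 4 + 4 = 3{,}076 \\
\text{total}    &: 590{,}592 + 3{,}076 = 593{,}668 \\
\text{memory}   &: 593{,}668 \times 4\ \text{bytes} = 2{,}374{,}672\ \text{bytes} \approx 2.26 \times 1024^2\ \text{bytes}
\end{aligned}
```

Almost all of the parameters sit in the hidden layer, whose size is tied to the embedding dimension.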
embedding_size
768
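SparseCategoricalCrossentropy has to know what the output layer emits: from_logits=True for raw, unnormalized scores and from_logits=False for probabilities such as softmax outputs. The plain-Python sketch below, using made-up scores, shows that both paths give the same loss when they match the model, while feeding probabilities into a loss that expects logits (so softmax is effectively applied twice) gives a different loss. Keras computes these losses with its own backend ops; the functions here are illustrative names only.

```python
import math


def softmax(logits):
    """Convert raw scores (logits) into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_ce_from_probs(probs, label):
    """Cross-entropy when the model already outputs probabilities (from_logits=False)."""
    return -math.log(probs[label])


def sparse_ce_from_logits(logits, label):
    """Cross-entropy computed directly from raw scores (from_logits=True), via log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


logits = [2.0, 0.5, -1.0, 0.1]  # made-up scores for four classes
label = 0                        # true class index

matched_probs = sparse_ce_from_probs(softmax(logits), label)
matched_logits = sparse_ce_from_logits(logits, label)
# Mismatch: passing probabilities to a loss that expects logits applies softmax twice.
mismatched = sparse_ce_from_logits(softmax(logits), label)
print(matched_probs, matched_logits, mismatched)
```

Because the second softmax flattens the distribution, the mismatched loss here is larger than the correct one even though the predicted class is right.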
Train the model to classify newsgroups
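Finally, you can train a simple model. Use a small number of epochs to avoid overfitting; the EarlyStopping callback also ends training once training accuracy has not improved for three consecutive epochs. The first epoch takes longer than the rest, most likely because Keras builds and compiles the training function on its first call; the embeddings were already computed in the previous section.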
NUM_EPOCHS = 20
BATCH_SIZE = 32
# Split the x and y components of the train and validation subsets.
y_train = df_train['Encoded Label']
x_train = np.stack(df_train['Embeddings'])
y_val = df_test['Encoded Label']
x_val = np.stack(df_test['Embeddings'])
# Train the model for the desired number of epochs.
callback = keras.callbacks.EarlyStopping(monitor='accuracy', patience=3)
history = classifier.fit(x=x_train,
                         y=y_train,
                         validation_data=(x_val, y_val),
                         callbacks=[callback],
                         batch_size=BATCH_SIZE,
                         epochs=NUM_EPOCHS,)
Epoch 1/20
13/13 [==============================] - 1s 21ms/step - loss: 1.2284 - accuracy: 0.6025 - val_loss: 1.0366 - val_accuracy: 0.9100
Epoch 2/20
13/13 [==============================] - 0s 8ms/step - loss: 0.8011 - accuracy: 0.9450 - val_loss: 0.6998 - val_accuracy: 0.8500
Epoch 3/20
13/13 [==============================] - 0s 7ms/step - loss: 0.4617 - accuracy: 0.9700 - val_loss: 0.4686 - val_accuracy: 0.8900
Epoch 4/20
13/13 [==============================] - 0s 7ms/step - loss: 0.2774 - accuracy: 0.9725 - val_loss: 0.3689 - val_accuracy: 0.8700
Epoch 5/20
13/13 [==============================] - 0s 7ms/step - loss: 0.1821 - accuracy: 0.9775 - val_loss: 0.3158 - val_accuracy: 0.8800
Epoch 6/20
13/13 [==============================] - 0s 8ms/step - loss: 0.1335 - accuracy: 0.9800 - val_loss: 0.2899 - val_accuracy: 0.8800
Epoch 7/20
13/13 [==============================] - 0s 7ms/step - loss: 0.1080 - accuracy: 0.9775 - val_loss: 0.2952 - val_accuracy: 0.8600
Epoch 8/20
13/13 [==============================] - 0s 7ms/step - loss: 0.0852 - accuracy: 0.9825 - val_loss: 0.2593 - val_accuracy: 0.8900
Epoch 9/20
13/13 [==============================] - 0s 7ms/step - loss: 0.0708 - accuracy: 0.9850 - val_loss: 0.2523 - val_accuracy: 0.9000
Epoch 10/20
13/13 [==============================] - 0s 7ms/step - loss: 0.0579 - accuracy: 0.9900 - val_loss: 0.2678 - val_accuracy: 0.9000
Epoch 11/20
13/13 [==============================] - 0s 7ms/step - loss: 0.0504 - accuracy: 0.9950 - val_loss: 0.2313 - val_accuracy: 0.9200
Epoch 12/20
13/13 [==============================] - 0s 7ms/step - loss: 0.0429 - accuracy: 0.9975 - val_loss: 0.2417 - val_accuracy: 0.9100
Epoch 13/20
13/13 [==============================] - 0s 6ms/step - loss: 0.0379 - accuracy: 0.9975 - val_loss: 0.2340 - val_accuracy: 0.9100
Epoch 14/20
13/13 [==============================] - 0s 7ms/step - loss: 0.0307 - accuracy: 0.9975 - val_loss: 0.2364 - val_accuracy: 0.9200
Epoch 15/20
13/13 [==============================] - 0s 8ms/step - loss: 0.0262 - accuracy: 1.0000 - val_loss: 0.2261 - val_accuracy: 0.9100
Epoch 16/20
13/13 [==============================] - 0s 6ms/step - loss: 0.0224 - accuracy: 1.0000 - val_loss: 0.2313 - val_accuracy: 0.9200
Epoch 17/20
13/13 [==============================] - 0s 6ms/step - loss: 0.0207 - accuracy: 1.0000 - val_loss: 0.2169 - val_accuracy: 0.9200
Epoch 18/20
13/13 [==============================] - 0s 6ms/step - loss: 0.0172 - accuracy: 1.0000 - val_loss: 0.2245 - val_accuracy: 0.9200
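To see why training stopped after epoch 18 rather than 20, the sketch below replays the patience rule behind EarlyStopping(monitor='accuracy', patience=3) on the per-epoch values from the log above: a counter resets whenever the monitored value strictly improves on its best so far and otherwise grows by one, and training stops once the counter reaches the patience. This is a simplified re-implementation (min_delta of 0, no baseline or weight restoring), not Keras' own code, and early_stop_epoch is a hypothetical name. Replaying val_loss through the same rule shows that monitoring validation loss instead would have stopped this run after epoch 14.

```python
def early_stop_epoch(values, patience, mode="max"):
    """Return the 1-based epoch at which a patience rule would stop training, or None."""
    best = float("-inf") if mode == "max" else float("inf")
    wait = 0
    for epoch, value in enumerate(values, start=1):
        improved = value > best if mode == "max" else value < best
        if improved:
            best, wait = value, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Per-epoch metrics copied from the training log above.
train_accuracy = [0.6025, 0.9450, 0.9700, 0.9725, 0.9775, 0.9800, 0.9775, 0.9825, 0.9850,
                  0.9900, 0.9950, 0.9975, 0.9975, 0.9975, 1.0000, 1.0000, 1.0000, 1.0000]
val_loss = [1.0366, 0.6998, 0.4686, 0.3689, 0.3158, 0.2899, 0.2952, 0.2593, 0.2523,
            0.2678, 0.2313, 0.2417, 0.2340, 0.2364, 0.2261, 0.2313, 0.2169, 0.2245]

print(early_stop_epoch(train_accuracy, patience=3, mode="max"))  # 18
print(early_stop_epoch(val_loss, patience=3, mode="min"))        # 14
```

Training accuracy reached 1.0000 at epoch 15 and could not improve further, so the rule fired three epochs later.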
Evaluate model performance
Use Keras Model.evaluate to get the loss and accuracy on the test dataset.
classifier.evaluate(x=x_val, y=y_val, return_dict=True)
4/4 [==============================] - 0s 4ms/step - loss: 0.2245 - accuracy: 0.9200
{'loss': 0.22447504103183746, 'accuracy': 0.9200000166893005}
One way to evaluate your model is to visualize its learning curves. Use the plot_history function below to see how the loss and accuracy change over the epochs.
def plot_history(history):
    """
    Plot training and validation learning curves.
    Args:
        history: model history with all the metric measures
    """
    fig, (ax1, ax2) = plt.subplots(1,2)
    fig.set_size_inches(20, 8)
    # Plot loss
    ax1.set_title('Loss')
    ax1.plot(history.history['loss'], label = 'Train')
    ax1.plot(history.history['val_loss'], label = 'Validation')
    ax1.set_ylabel('Loss')
    ax1.set_xlabel('Epoch')
    ax1.xaxis.set_major_locator(MaxNLocator(integer=True))  # whole-number epoch ticks
    ax1.legend()
    # Plot accuracy
    ax2.set_title('Accuracy')
    ax2.plot(history.history['accuracy'], label = 'Train')
    ax2.plot(history.history['val_accuracy'], label = 'Validation')
    ax2.set_ylabel('Accuracy')
    ax2.set_xlabel('Epoch')
    ax2.xaxis.set_major_locator(MaxNLocator(integer=True))  # whole-number epoch ticks
    ax2.legend()
    plt.show()
plot_history(history)
Another way to view model performance, beyond loss and accuracy, is a confusion matrix. It shows not only how many test examples were misclassified but also which classes they were misclassified as. To build the confusion matrix for this multi-class classification problem, you need the actual labels of the test set and the predicted labels.
Start by generating the predicted class for each example in the validation set using Model.predict().
y_hat = classifier.predict(x=x_val)
y_hat = np.argmax(y_hat, axis=1)
4/4 [==============================] - 0s 2ms/step
labels_dict = dict(zip(df_test['Class Name'], df_test['Encoded Label']))
labels_dict
{'sci.crypt': 0, 'sci.electronics': 1, 'sci.med': 2, 'sci.space': 3}
cm = skmetrics.confusion_matrix(y_val, y_hat)
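# Order the class names by encoded label so they line up with the matrix rows and columns.
disp = skmetrics.ConfusionMatrixDisplay(confusion_matrix=cm,
                                        display_labels=sorted(labels_dict, key=labels_dict.get))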
disp.plot(xticks_rotation='vertical')
plt.title('Confusion matrix for newsgroup test dataset');
plt.grid(False)
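The prediction and confusion-matrix steps above can be reproduced in plain Python to make the conventions explicit: np.argmax(..., axis=1) picks the highest-scoring class in each row, and in scikit-learn's confusion_matrix the rows are true labels and the columns are predicted labels, so the diagonal holds the correct predictions. The probability rows and labels below are made up, and argmax_rows and confusion_matrix are hypothetical helper names. Per-class recall (diagonal entry divided by its row sum) is one extra number you can read off the matrix.

```python
def argmax_rows(prob_rows):
    """Index of the largest value in each row, like np.argmax(prob_rows, axis=1)."""
    return [max(range(len(row)), key=row.__getitem__) for row in prob_rows]


def confusion_matrix(y_true, y_pred, num_classes):
    """cm[i][j] counts examples whose true class is i and predicted class is j."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm


# Made-up model outputs for 8 examples over 4 classes.
prob_rows = [
    [0.70, 0.10, 0.10, 0.10], [0.30, 0.40, 0.20, 0.10],
    [0.10, 0.80, 0.05, 0.05], [0.20, 0.60, 0.10, 0.10],
    [0.05, 0.05, 0.85, 0.05], [0.10, 0.10, 0.30, 0.50],
    [0.10, 0.10, 0.10, 0.70], [0.05, 0.05, 0.10, 0.80],
]
y_true = [0, 0, 1, 1, 2, 2, 3, 3]
y_pred = argmax_rows(prob_rows)
cm = confusion_matrix(y_true, y_pred, num_classes=4)
accuracy = sum(cm[i][i] for i in range(4)) / len(y_true)
recall = [cm[i][i] / sum(cm[i]) for i in range(4)]
print(cm)        # [[1, 1, 0, 0], [0, 2, 0, 0], [0, 0, 1, 1], [0, 0, 0, 2]]
print(accuracy)  # 0.75
print(recall)    # [0.5, 1.0, 0.5, 1.0]
```

Off-diagonal cells point to specific confusions: here one class-0 example was predicted as class 1 and one class-2 example as class 3.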
Next steps
You've now created your own text classifier using embeddings generated from the PaLM API! Try using your own textual data to train a model. One possible dataset could be the Jigsaw Toxic Comment Classification Challenge to create your own toxicity classifier.
To learn more about how you can use the embeddings, check out the examples available. To learn how to use other services in the PaLM API, visit the various quickstart guides: