Keras NLP 入门：IMDB 电影评论分类

it2022-05-05  151

# Keras NLP intro: binary sentiment classification of IMDB movie reviews.
#
# Loads the IMDB dataset restricted to the 10,000 most frequent words,
# multi-hot encodes each review, trains a small dense network, saves the
# architecture and weights under ./models/, and plots the training and
# validation loss/accuracy curves to ./loss.jpg and ./acc.jpg.
import os

import numpy as np


def vectorize_sequences(sequences, dimension=10000, dtype=np.float64):
    """Multi-hot encode integer sequences into a dense 2-D array.

    Args:
        sequences: a sized collection (``len()`` is called on it) whose
            items are sequences of integer word indices in
            ``[0, dimension)``.
        dimension: number of columns, i.e. the vocabulary size.
        dtype: dtype of the returned array. The default, float64, matches
            the previous behaviour. float32 halves the memory footprint.

    Returns:
        Array of shape ``(len(sequences), dimension)`` where
        ``results[i, j] == 1`` iff index ``j`` occurs in ``sequences[i]``.
        Repeated indices still produce a single 1, so this is a multi-hot
        encoding, not a one-hot or count encoding.
    """
    results = np.zeros((len(sequences), dimension), dtype=dtype)
    for i, sequence in enumerate(sequences):
        # Fancy indexing sets every listed column of row i in one step.
        results[i, sequence] = 1
    return results


def _save_model(model):
    """Write the model architecture and weights under ./models/."""
    # Both open() and save_weights() fail if the target directory is missing.
    os.makedirs('./models', exist_ok=True)
    try:
        architecture, arch_path = model.to_yaml(), './models/imdb.yaml'
    except (AttributeError, RuntimeError):
        # Keras >= 2.6 makes to_yaml() raise RuntimeError (YAML loading can
        # execute arbitrary code), and Keras 3 drops it. Fall back to JSON.
        architecture, arch_path = model.to_json(), './models/imdb.json'
    with open(arch_path, 'w') as outfile:
        outfile.write(architecture)
    # NOTE(review): Keras 3 requires weight files to end in '.weights.h5';
    # confirm the Keras version this script targets.
    model.save_weights('./models/imdb.h5')


def _plot_curves(train_values, val_values, train_label, val_label,
                 title, ylabel, path):
    """Plot per-epoch training vs. validation values, save to *path*, show."""
    import matplotlib.pyplot as plt

    epochs = range(1, len(train_values) + 1)
    plt.plot(epochs, train_values, 'bo', label=train_label)  # 'bo': blue dots
    plt.plot(epochs, val_values, 'b', label=val_label)  # 'b': solid blue line
    plt.title(title)
    plt.xlabel('Epochs')
    plt.ylabel(ylabel)
    plt.legend()
    plt.savefig(path)
    plt.show()
    # Clear the figure so the next plot starts from an empty canvas.
    plt.clf()


def main():
    """Train the IMDB review classifier, save it, and plot its history."""
    # Third-party imports are deferred so this module (and
    # vectorize_sequences) can be imported without Keras installed.
    from keras import layers, models
    from keras.datasets import imdb

    # The test split is never evaluated by this script, so it is not
    # vectorized. That avoids allocating a second large dense matrix.
    (train_data, train_labels), _ = imdb.load_data(num_words=10000)

    # 1. Data processing. float32 is what the model consumes and uses half
    # the memory of NumPy's float64 default for the (n_reviews, 10000) array.
    x_train = vectorize_sequences(train_data, dtype='float32')
    y_train = np.asarray(train_labels).astype('float32')

    # 2. Build the network.
    model = models.Sequential()
    model.add(layers.Dense(16, activation='relu', input_shape=(10000,)))
    model.add(layers.Dense(16, activation='relu'))
    model.add(layers.Dense(1, activation='sigmoid'))

    # 3. Compile.
    model.compile(optimizer='rmsprop',
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

    # Hold out the first 10,000 training samples as a validation set.
    x_val, partial_x_train = x_train[:10000], x_train[10000:]
    y_val, partial_y_train = y_train[:10000], y_train[10000:]

    history = model.fit(partial_x_train, partial_y_train,
                        epochs=20,
                        batch_size=512,
                        validation_data=(x_val, y_val))

    _save_model(model)

    history_dict = history.history
    _plot_curves(history_dict['loss'], history_dict['val_loss'],
                 'Training loss', 'Validation loss',
                 'Training and validation loss', 'Loss', './loss.jpg')

    # Keras >= 2.3 records metrics=['accuracy'] under 'accuracy';
    # older versions used 'acc'.
    acc_key = 'accuracy' if 'accuracy' in history_dict else 'acc'
    _plot_curves(history_dict[acc_key], history_dict['val_' + acc_key],
                 'Training acc', 'Validation acc',
                 'Training and validation accuracy', 'Accuracy', './acc.jpg')


if __name__ == '__main__':
    main()

 


最新回复(0)