When will you grow up?

Text Generation(using LSTM) 본문

02. Study/Keras

Text Generation(using LSTM)

미카이 2017. 11. 24. 02:38

이번시간에는 RNN model을 기반으로 generative models을 만들어 보겠습니다.

추가적으로 예측모델(Predictive models)을 만드는데 그럴듯한 스퀀스를 생성합니다.


이 예제에서는 원하는 large text를 이용하여 학습을 시켜 스퀀스 data를 생성할 수 있습니다. 


문제 발생시 : 삭제 하도록 하겠습니다


Input Text data : http://www.bioinf.jku.at/publications/older/2604.pdf

위 주소에서 크롤링을 하여 Text 파일로 "LSTM"이라는 이름으로 저장을 시켰다.






[Data]






[모델 정의]






[Epochs 시각화]







[학습된 모델로 Text Generation]









1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import numpy
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.layers import LSTM
from keras.callbacks import ModelCheckpoint
from keras.utils import np_utils
import sys
from matplotlib import pyplot as plt
 
# text load 및 load된 text 대문자들을 다 소문자로 변경
filename = "LSTM_1.txt"
raw_text = open(filename).read()
raw_text = raw_text.lower()
 
# 중복된 문자들을 지우고 정수에 맵핑시키고, 각 단어들을 만든다
chars = sorted(list(set(raw_text)))
char_to_int = dict((c, i) for i, c in enumerate(chars))
 
n_chars = len(raw_text)
n_vocab = len(chars)
print ("Total Characters: ", n_chars) #71469
print ("Total Vocab: ", n_vocab) #68
 
# 위에서 만든 dict된 쌍을 input data 및 output(Y) data를 생성한다
seq_length = 100
dataX = []
dataY = []
for i in range(0, n_chars - seq_length, 1):
    seq_in = raw_text[i:i + seq_length]
    seq_out = raw_text[i + seq_length]
    dataX.append([char_to_int[char] for char in seq_in])
    dataY.append(char_to_int[seq_out])
n_patterns = len(dataX)
print ("Total Patterns: ", n_patterns) #71369
 
# reshape X to be [samples, time steps, features]
= numpy.reshape(dataX, (n_patterns, seq_length, 1))
# normalize
= X / float(n_vocab)
# output data를 one-hot-encoding
= np_utils.to_categorical(dataY)
 
# LSTM model 정의
model = Sequential()
model.add(LSTM(256, input_shape=(X.shape[1], X.shape[2]), return_sequences=True))
model.add(Dropout(0.2))
model.add(LSTM(256))
model.add(Dropout(0.2))
model.add(Dense(y.shape[1], activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam')
 
 
# 학습단계마다 weights을 저장시키는데 이전보다 결과가 좋아질시 저장
filepath="weights-improvement-{epoch:02d}-{loss:.4f}.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1, save_best_only=True,
mode='min')
callbacks_list = [checkpoint]
 
hist = model.fit(X, y, nb_epoch=500, batch_size=128, callbacks=callbacks_list)
 
 
#모델 시각
fig, loss_ax = plt.subplots()
 
acc_ax = loss_ax.twinx()
 
loss_ax.plot(hist.history['loss'], 'y', label='train loss')
 
loss_ax.set_xlabel('epoch')
loss_ax.set_ylabel('loss')
 
loss_ax.legend(loc='upper left')
acc_ax.legend(loc='lower left')
 
plt.show()
 
 
# load the network weights
filename = "weights-improvement-374-0.0183.hdf5"
model.load_weights(filename)
model.compile(loss='categorical_crossentropy', optimizer='adam')
 
int_to_char = dict((i, c) for i, c in enumerate(chars))
 
# pick a random seed
start = numpy.random.randint(0len(dataX)-1)
pattern = dataX[start]
print ("Seed:")
print ("\""''.join([int_to_char[value] for value in pattern]), "\"")
# generate characters
for i in range(1000):
    x = numpy.reshape(pattern, (1len(pattern), 1))
    x = x / float(n_vocab)
    prediction = model.predict(x, verbose=0)
    index = numpy.argmax(prediction)
    result = int_to_char[index]
    seq_in = [int_to_char[value] for value in pattern]
    sys.stdout.write(result)
    pattern.append(index)
    pattern = pattern[1:len(pattern)]
print ("\nDone.")
 

cs

[전체 Code]






reference

Keras.io

Deep Learning for Python

'02. Study > Keras' 카테고리의 다른 글

VGG+ResNet(Fashion_MNIST)  (0) 2017.12.10
Sequence-to Sequence  (0) 2017.12.08
Long Short Term Memory(using IMDB dataset)  (0) 2017.11.12
Deep Neural Network(using pima dataset)  (2) 2017.11.05
Convolution Neural Network (using FASHION-MNIST data)  (0) 2017.11.05
Comments