Ví dụ 01: Train một phép cộng

Bài này mình sẽ giới thiệu với các bạn một ví dụ đơn giản để làm quen với tensorflow.

Input: hai số thực

Output: tổnghaisốthực

Mục tiêu: Train để model có thể thực hiện cộng hai số thực.

Mô hình của model:

capture-1

Source code

import tensorflow as tf
import numpy as np

batch_size = 1
input_layer = 2
hidden_layer = 5
output_layer = 1
learning_rate = 0.01

#Tạo place holder để dựng model, đây chính là nơi chứa dữ liệu cho model.
input=tf.placeholder(tf.float32, shape=(batch_size,input_layer))
target=tf.placeholder(tf.float32, shape=(batch_size, output_layer))

#Khai báo parameter của model
w1=tf.get_variable(“w1”, [input_layer, hidden_layer], dtype=tf.float32)
w2=tf.get_variable(“w2”, [hidden_layer, output_layer], dtype=tf.float32)

#Khai báo model
l1=tf.matmul(input, w1)
l2=tf.matmul(l1, w2)

#Tính loss và train
loss=tf.reduce_mean(0.5*tf.square(target – l2))
optimizer=tf.train.GradientDescentOptimizer(learning_rate)
train_opt=optimizer.minimize(loss)

#Tạo một session
with tf.Session() as sess:
#Khởi tạo giá trị của các parameters (w1, w2)
tf.initialize_all_variables().run()
for i in range(10000):
#Random giá trị cho input và target

x=np.random.random_sample((batch_size, input_layer))
y=np.array([[np.sum(x)]])

#train
#Lấy ra giá trị predict của model (l2) và train model(train_opt)
#Truyền vào giá trị input và target
_output, _ = sess.run([l2, train_opt], feed_dict={input: x , target:y })
if i % 100==0:
print(x, y, _output)

Kết quả

0 [[ 0.36704104 0.71434379]] [[ 1.08138483]] [[ 0.72435409]] 0.0637355
1000 [[ 0.63112473 0.76400415]] [[ 1.39512887]] [[ 1.39563251]] 1.26836e-07
2000 [[ 0.98767198 0.63501923]] [[ 1.62269121]] [[ 1.62257838]] 6.35876e-09
3000 [[ 0.38165326 0.04241553]] [[ 0.42406878]] [[ 0.42406371]] 1.28342e-11
4000 [[ 0.14268065 0.23031255]] [[ 0.37299319]] [[ 0.37299329]] 3.9968e-15
5000 [[ 0.90865944 0.54084765]] [[ 1.44950709]] [[ 1.44950533]] 1.59872e-11
6000 [[ 0.66731106 0.24386042]] [[ 0.91117149]] [[ 0.91116989]] 1.29496e-12
7000 [[ 0.82374591 0.74712057]] [[ 1.57086648]] [[ 1.57086515]] 8.59757e-13
8000 [[ 0.70094555 0.64503272]] [[ 1.34597827]] [[ 1.34597707]] 7.10543e-13
9000 [[ 0.06899652 0.39773195]] [[ 0.46672847]] [[ 0.46672854]] 1.77636e-15

Advertisements