티스토리 뷰

5. 파이썬

[텐서플로2] 회귀분석 예제 regression.py

패스트코드블로그 2020. 5. 14. 23:08
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
 
class SeqModel:
    def __init__(self):
        self.model = None
 
    @staticmethod
    @tf.function
    def simple_func():
        a = tf.constant(1)
        b = tf.constant(2)
        c = tf.constant(3)
        z = a + b + c
        return z
 
    def create_model(self):
        input = tf.keras.Input(shape=(1,))
        output = tf.keras.layers.Dense(1)(input)
        self.model = tf.keras.Model(input, output)
        """
        Total params: 2
        Trainable params: 2
        Non-trainable params: 0
        """
        print(self.model.summary())
 
    @staticmethod
    def make_random_data():
        x = np.random.uniform(low = -2, high= 2, size = 200)
        y = []
        for t in x:
            r = np.random.normal(loc = 0.0,
                                 scale=(0.5 + t*t/3),
                                 size = None)
            y.append(r)
        return x, 1.726*- 0.84 + np.array(y)
 
    def execute(self):
        (x, y) = self.make_random_data()
        x_train, y_train = x[:150], y[:150]
        x_test, y_test = x[:150], y[:150]
        self.model = tf.keras.Sequential()
        self.model.add(tf.keras.layers.Dense(units=1, input_dim=1))
        self.model.compile(optimizer='sgd', loss='mse')
        self.model.save('./data/simple_model.h5')
 
    def load_model(self):
        (x, y) = self.make_random_data()
        x_train, y_train = x[:150], y[:150]
        x_train, y_train = x[:150], y[:150]
        load_model = tf.keras.models.load_model('./data/simple_model.h5')
        history = load_model.fit(x_train, y_train, epochs = 300, validation_split = 0.3)
        epochs = np.arange(1300 +1)
        plt.plot(epochs, history.history['loss'], label = 'Training loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.show()
 
 
if __name__ == '__main__':
    m = SeqModel()
    m.create_model()
    m.execute()
    m.load_model()
cs
댓글
공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2025/01   »
1 2 3 4
5 6 7 8 9 10 11
12 13 14 15 16 17 18
19 20 21 22 23 24 25
26 27 28 29 30 31
글 보관함