使用cross-entropy分类手写数字

本节使用network2.py。 network2.py不仅包含cross-entropy,还有其他一些改进。 使用help(network2.Network.SGD)可以获取文档信息。

通过代码验证发现使用cross-entropy的识别率要高一些。 但就此认定cross-entropy要优于二次代价函数是不严谨的。 还需要结合超参数的选择。

代码

cross entropy代价函数

class CrossEntropyCost来实现cross entropy代价函数。

因为代价函数涉及到了两处不同的计算改变,因此将这两处的函数封装到一个类中。

  1. cost的计算

@staticmethod
def fn(a, y):
    return np.sum(np.nan_to_num(-y*np.log(a)-(1-y)*np.log(1-a)))
  1. error的计算

@staticmethod
def delta(z, a, y):
    return (a-y)

二次代价函数

相应的,二次代价函数也封装成了一个类

class QuadraticCost(object):

    @staticmethod
    def fn(a, y):
        return 0.5*np.linalg.norm(a-y)**2

    @staticmethod
    def delta(z, a, y):
        return (a-y) * sigmoid_prime(z)

Last updated