summaryrefslogtreecommitdiffstats
path: root/mlplib/backward.py
blob: 1868b5abb6ccb2cea069b123ea24460d379420d0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
"""
# -*- coding: utf-8 -*-
#
# Copyright 2021 Michael Büsch <m@bues.ch>
#
# Licensed under the Apache License version 2.0
# or the MIT license, at your option.
# SPDX-License-Identifier: Apache-2.0 OR MIT
#
"""

# Public API of this module: the per-layer gradient record, the container
# returned by backward_prop(), and backward_prop() itself.
__all__ = [
    "BackpropGrad",
    "BackpropGrads",
    "backward_prop",
]

from mlplib.forward import forward_prop
from mlplib.loss import Loss
from mlplib.parameters import Parameters
from mlplib.util import GenericIter
from collections import deque, namedtuple
from dataclasses import dataclass, field
from typing import Callable, Optional, Tuple
import numpy as np

# Gradient pair of a single layer: weight gradient (dw) and bias gradient (db).
BackpropGrad = namedtuple("BackpropGrad", "dw db")

@dataclass
class BackpropGrads:
    """
    Container for the gradients computed by backward_prop().

    dw and db hold one weight gradient and one bias gradient per layer.
    As filled by backward_prop(), index 0 belongs to the input-side layer.
    Iterating an instance yields one BackpropGrad(dw, db) per layer.
    """

    # Each instance gets its own fresh deque (no sharing between instances).
    dw: deque[np.ndarray] = field(default_factory=deque)
    db: deque[np.ndarray] = field(default_factory=deque)

    def __iter__(self):
        # Forward iterator over all stored layers; the count is taken from dw.
        count = len(self.dw)
        return BackpropGradsIter(self, count)

    def __reversed__(self):
        # Iterator for reversed(). The extra positional args presumably mean
        # "reverse" and "start index" in GenericIter -- TODO confirm there.
        count = len(self.dw)
        return BackpropGradsIter(self, count, True, count - 1)

@dataclass
class BackpropGradsIter(GenericIter):
    """
    Iterator over a BackpropGrads instance that yields one BackpropGrad
    (weight gradient, bias gradient) per layer.
    """

    def __next__(self):
        # GenericIter._next() returns the iterated object and the current
        # index. Presumably it raises StopIteration when exhausted -- see
        # GenericIter.
        grads, idx = self._next()
        layer_dw = grads.dw[idx]
        layer_db = grads.db[idx]
        return BackpropGrad(dw=layer_dw, db=layer_db)

def backward_prop(x: np.ndarray,
                  y: np.ndarray,
                  params: Parameters,
                  loss: Loss)\
                  -> Tuple[BackpropGrads, np.ndarray]:
    """
    Run a forward pass and backpropagate the loss derivative through
    every layer of the network.

    x: Network input. Its first axis enumerates the samples.
    y: Expected network output for the samples in x.
    params: Network parameters (per-layer weights, biases and activations).
    loss: Loss object. Its fn_d(yh, y) yields the derivative of the loss
          with respect to the network output.

    Returns (grads, yh). grads holds one weight gradient and one bias
    gradient per layer, with the input-side layer at index 0. yh is the
    network output from the forward pass.
    """
    nr_layers = len(params.weights)
    assert nr_layers >= 1
    assert nr_layers == len(params.biases)
    assert nr_layers == len(params.actvns)

    # The first axis of x enumerates the samples.
    nr_samples = x.shape[0]

    # Forward pass. Keep the per-layer state that the backward pass needs.
    yh, netstate = forward_prop(x, params, store_netstate=True)
    assert isinstance(netstate, list)
    assert nr_layers == len(netstate)
    assert nr_samples == yh.shape[0]

    # Derivative of the loss with respect to the network output.
    da = loss.fn_d(yh, y)
    assert da.shape == (nr_samples, y.shape[1])

    grads = BackpropGrads()
    # Reverse-order index of the input-side layer (the last one visited).
    input_layer_rev_idx = nr_layers - 1

    # Walk the layers from the output back towards the input.
    layers_backwards = zip(reversed(params), reversed(netstate))
    for rev_idx, ((weight, _bias, activation, *_rest), state) in \
            enumerate(layers_backwards):
        assert weight.ndim == 2 and weight.shape[0] == state.x.shape[1]
        assert da.ndim == 2 and da.shape == (nr_samples, weight.shape[1])

        # False only for the input-side layer. Presumably this tells the
        # activation whether da must be computed for a preceding layer --
        # confirm in the activation's backward_prop().
        propagate = rev_idx != input_layer_rev_idx
        dw, db, da = activation.backward_prop(weight, da, state.x, state.z,
                                              nr_samples, propagate)

        # The loop runs backwards, so prepend to keep layer 0 at index 0.
        grads.dw.appendleft(dw)
        grads.db.appendleft(db)

    assert len(grads.dw) == nr_layers
    assert len(grads.db) == nr_layers
    return grads, yh

# vim: ts=4 sw=4 expandtab
bues.ch cgit interface