Add constraint to Dense layer; add callback
This commit is contained in:
parent
fa37c40f0b
commit
fad8640bbe
@ -11,6 +11,8 @@ from tensorflow.python.ops.gen_math_ops import mat_mul, sigmoid
|
|||||||
from tensorflow.python.ops.gen_nn_ops import bias_add
|
from tensorflow.python.ops.gen_nn_ops import bias_add
|
||||||
from tensorflow.python.ops.resource_variable_ops import ResourceVariable
|
from tensorflow.python.ops.resource_variable_ops import ResourceVariable
|
||||||
from keras.engine.keras_tensor import KerasTensor
|
from keras.engine.keras_tensor import KerasTensor
|
||||||
|
from keras.api._v2.keras.callbacks import Callback
|
||||||
|
from keras.api._v2.keras.constraints import Constraint
|
||||||
from keras.api._v2.keras.layers import Dense, Input
|
from keras.api._v2.keras.layers import Dense, Input
|
||||||
from keras.api._v2.keras.models import Model
|
from keras.api._v2.keras.models import Model
|
||||||
|
|
||||||
@ -41,7 +43,16 @@ def simple_layer_test() -> None:
|
|||||||
b_data = np.array([-1.0, 2.0], dtype=F32)
|
b_data = np.array([-1.0, 2.0], dtype=F32)
|
||||||
w_init = partial(init_params, w_data.T)
|
w_init = partial(init_params, w_data.T)
|
||||||
b_init = partial(init_params, b_data.T)
|
b_init = partial(init_params, b_data.T)
|
||||||
layer = Dense(units=2, activation='sigmoid', kernel_initializer=w_init, bias_initializer=b_init)(inputs)
|
|
||||||
|
class Const(Constraint):
|
||||||
|
def __init__(self, zero_mask: np.ndarray) -> None:
|
||||||
|
self.mask = zero_mask
|
||||||
|
|
||||||
|
def __call__(self, weights: ResourceVariable) -> ResourceVariable:
|
||||||
|
weights.assign(weights - self.mask * weights)
|
||||||
|
return weights
|
||||||
|
|
||||||
|
layer = Dense(units=2, activation='sigmoid', kernel_initializer=w_init, bias_initializer=b_init, kernel_constraint=Const(w_data.T == 0))(inputs)
|
||||||
assert isinstance(layer, KerasTensor)
|
assert isinstance(layer, KerasTensor)
|
||||||
model = Model(inputs=inputs, outputs=layer)
|
model = Model(inputs=inputs, outputs=layer)
|
||||||
w_tensor = model.trainable_variables[0]
|
w_tensor = model.trainable_variables[0]
|
||||||
@ -50,13 +61,21 @@ def simple_layer_test() -> None:
|
|||||||
assert isinstance(b_tensor, ResourceVariable)
|
assert isinstance(b_tensor, ResourceVariable)
|
||||||
assert np.equal(w_tensor.numpy().T, w_data).all()
|
assert np.equal(w_tensor.numpy().T, w_data).all()
|
||||||
assert np.equal(b_tensor.numpy().T, b_data).all()
|
assert np.equal(b_tensor.numpy().T, b_data).all()
|
||||||
model.compile()
|
model.compile(optimizer='adam', loss='categorical_crossentropy')
|
||||||
x = np.array([[1.0, -2.0, 0.2]], dtype=F32)
|
x = np.array([[1.0, -2.0, 0.2]], dtype=F32)
|
||||||
print("input", x[0])
|
print("input", x[0])
|
||||||
y = model(x)
|
y = model(x)
|
||||||
assert isinstance(y, tf.Tensor)
|
assert isinstance(y, tf.Tensor)
|
||||||
print("output", np.array(y)[0])
|
print("output", np.array(y)[0])
|
||||||
assert y[0][1] == 0.5
|
assert y[0][1] == 0.5
|
||||||
|
samples = np.array([[1., 1., 1.], [2., 2., 2.], [3., 3., 3.]], dtype=F32)
|
||||||
|
labels = np.array([[0., 1.], [0., 2.], [3., 0.]], dtype=F32)
|
||||||
|
|
||||||
|
class CB(Callback):
|
||||||
|
def on_train_batch_begin(self, batch, logs=None):
|
||||||
|
print(f"...start of batch {batch}; model weights:")
|
||||||
|
print(self.model.trainable_variables[0].numpy())
|
||||||
|
model.fit(samples, labels, batch_size=1, callbacks=[CB()], verbose=0)
|
||||||
|
|
||||||
|
|
||||||
def build_model(input_shape: Sequence[int], *layers: tuple[np.ndarray, np.ndarray]) -> Model:
|
def build_model(input_shape: Sequence[int], *layers: tuple[np.ndarray, np.ndarray]) -> Model:
|
||||||
@ -127,5 +146,5 @@ def main(n: int) -> None:
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# simple_layer_test()
|
simple_layer_test()
|
||||||
main(int(sys.argv[1]))
|
# main(int(sys.argv[1]))
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "feed_forward"
|
name = "feed_forward_ndarray"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user