/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.recurrent;

import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Parameter;
import ai.djl.nn.recurrent.RecurrentBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import ai.djl.util.Preconditions;

/**
 * A recurrent block configured with a single gate ({@code gates = 1}) whose forward pass is
 * delegated to the engine's {@code rnn} operator, using a configurable {@link Activation}.
 */
public class RNN
extends RecurrentBlock {
    private Activation activation;

    /**
     * Creates an RNN block from the given builder.
     *
     * @param builder the builder holding the recurrent configuration and the activation
     */
    RNN(Builder builder) {
        super(builder);
        this.activation = builder.activation;
        this.gates = 1;
    }

    /**
     * Runs the engine's {@code rnn} operator over the input.
     *
     * <p>{@code inputs.head()} is the input passed to the operator. If {@code inputs} holds a
     * second array it is used as the initial state; otherwise a zero state is created (see
     * {@link #zeroState(NDArray)}). The caller's {@code inputs} list is never modified.
     *
     * @param parameterStore the store used to fetch this block's parameters on the input's device
     * @param inputs the input array, optionally followed by an initial state
     * @param training whether the forward pass is run for training
     * @param params additional parameters (unused here)
     * @return all outputs of the operator when {@code returnState} is set, otherwise only the
     *     first output
     */
    @Override
    protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {
        NDArray input = inputs.head();
        NDArrayEx ex = input.getNDArrayInternal();
        Device device = input.getDevice();
        NDList rnnParams = new NDList();
        for (Parameter parameter : this.parameters.values()) {
            rnnParams.add(parameterStore.getValue(parameter, device, training));
        }
        // Keep the (possibly synthesized) state local instead of appending it to the caller's list.
        NDArray state = inputs.size() > 1 ? inputs.get(1) : this.zeroState(input);
        NDList outputs = ex.rnn(input, state, rnnParams, this.hasBiases, this.numLayers, this.activation, this.dropRate, training, this.bidirectional, this.batchFirst);
        if (this.returnState) {
            return outputs;
        }
        // Only the first output is returned; release the remaining engine outputs
        // (presumably the final states — TODO confirm against the engine's rnn contract).
        outputs.stream().skip(1L).forEach(NDArray::close);
        return new NDList(outputs.get(0));
    }

    /**
     * Creates an all-zero initial state on {@code input}'s manager with shape
     * {@code (numLayers * numDirections, batchSize, stateSize)}, where the batch size is read from
     * axis 0 of {@code input} when {@code batchFirst} is set and from axis 1 otherwise.
     *
     * @param input the input array whose manager and batch size are used
     * @return a new zero state array
     */
    private NDArray zeroState(NDArray input) {
        int batchIndex = this.batchFirst ? 0 : 1;
        long numStates = (long) this.numLayers * (long) this.getNumDirections();
        return input.getManager().zeros(new Shape(numStates, input.size(batchIndex), this.stateSize));
    }

    /**
     * Creates a new builder for {@link RNN}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /** Builder for {@link RNN}; adds the activation setting to the common recurrent options. */
    public static final class Builder
    extends RecurrentBlock.BaseBuilder<Builder> {
        @Override
        protected Builder self() {
            return this;
        }

        /**
         * Sets the activation handed to the engine's {@code rnn} operator.
         *
         * @param activation the activation to use
         * @return this builder
         */
        public Builder setActivation(Activation activation) {
            this.activation = activation;
            return this.self();
        }

        /**
         * Builds the {@link RNN} block.
         *
         * <p>Fails the precondition check with "Must set stateSize and numLayers" unless both
         * {@code stateSize} and {@code numLayers} are positive.
         *
         * @return a new {@link RNN} configured from this builder
         */
        public RNN build() {
            Preconditions.checkArgument(this.stateSize > 0L && this.numLayers > 0, "Must set stateSize and numLayers");
            return new RNN(this);
        }
    }

    /** Activation functions selectable for an {@link RNN} block. */
    public static enum Activation {
        RELU,
        TANH;

    }
}

