package hivemall.optimizer;

import hivemall.utils.lang.Primitives;
import java.util.Map;
import javax.annotation.Nonnull;
import org.apache.lucene.util.packed.PackedInts;

/* loaded from: input_file:hivemall/optimizer/Regularization.class */
/**
 * Base class for regularization strategies that add a penalty term, scaled by
 * {@link #lambda}, to a value (see {@link #regularize(float, float)}).
 *
 * <p>Instances are normally obtained through {@link #get(Map)}, which picks an
 * implementation from the {@code "regularization"} option.
 */
public abstract class Regularization {
    /** Fallback handed to {@code Primitives.parseFloat} for the {@code "lambda"} option. */
    private static final float DEFAULT_LAMBDA = 1.0E-4f;

    /** Multiplier applied to the regularizer term in {@link #regularize(float, float)}. */
    protected final float lambda;

    /**
     * Reads the {@code "lambda"} option from the given option map.
     *
     * @param map option name to option value; must not be {@code null}
     */
    public Regularization(@Nonnull Map<String, String> map) {
        this.lambda = Primitives.parseFloat(map.get("lambda"), DEFAULT_LAMBDA);
    }

    /**
     * Returns {@code f2 + lambda * getRegularizer(f)}.
     *
     * @param f value the regularizer term is computed from (presumably the current weight)
     * @param f2 value the scaled regularizer term is added to (presumably the loss gradient)
     * @return {@code f2} plus the scaled regularizer term
     */
    public float regularize(float f, float f2) {
        return f2 + (this.lambda * getRegularizer(f));
    }

    /**
     * Computes the unscaled regularizer term for the given value.
     *
     * @param f value to compute the term for (presumably the current weight)
     * @return the term that {@link #regularize(float, float)} multiplies by {@link #lambda}
     */
    abstract float getRegularizer(float f);

    /**
     * Creates the regularization selected by the {@code "regularization"} option.
     *
     * <p>Names are matched case-insensitively: {@code "l1"}, {@code "l2"} and
     * {@code "elasticnet"} map to the corresponding nested classes; a missing option,
     * {@code "no"} and {@code "rda"} all yield {@link PassThrough}.
     *
     * @param map option name to option value; must not be {@code null}
     * @return a new regularization instance, never {@code null}
     * @throws IllegalArgumentException if the option holds any other name, or if the
     *         selected implementation rejects its own options
     */
    @Nonnull
    public static Regularization get(@Nonnull Map<String, String> map) throws IllegalArgumentException {
        final String name = map.get("regularization");
        if (name == null || "no".equalsIgnoreCase(name)) {
            return new PassThrough(map);
        }
        if ("l1".equalsIgnoreCase(name)) {
            return new L1(map);
        }
        if ("l2".equalsIgnoreCase(name)) {
            return new L2(map);
        }
        if ("elasticnet".equalsIgnoreCase(name)) {
            return new ElasticNet(map);
        }
        if ("rda".equalsIgnoreCase(name)) {
            // NOTE(review): "rda" deliberately gets no penalty here; presumably the RDA
            // optimizer applies its own regularization elsewhere -- confirm against callers.
            return new PassThrough(map);
        }
        throw new IllegalArgumentException("Unsupported regularization name: " + name);
    }

    /** Regularization that leaves its input untouched. */
    public static final class PassThrough extends Regularization {
        public PassThrough(Map<String, String> map) {
            super(map);
        }

        /** Always returns {@code 0}. */
        @Override
        public float getRegularizer(float f) {
            return 0.0f;
        }

        /** Returns {@code f2} unchanged, without evaluating the regularizer term at all. */
        @Override
        public float regularize(float f, float f2) {
            return f2;
        }
    }

    /** Regularizer term based on the sign of the input. */
    public static final class L1 extends Regularization {
        public L1(Map<String, String> map) {
            super(map);
        }

        /**
         * Returns {@code 1} for strictly positive input and {@code -1} otherwise, which
         * includes zero, negative values and NaN.
         */
        @Override
        public float getRegularizer(float f) {
            // NOTE(review): an input of exactly 0 yields -1 rather than 0; kept as-is because
            // changing it would alter results for existing callers.
            return f > 0.0f ? 1.0f : -1.0f;
        }
    }

    /** Regularizer term equal to the input itself. */
    public static final class L2 extends Regularization {
        public L2(Map<String, String> map) {
            super(map);
        }

        /** Returns {@code f} unchanged. */
        @Override
        public float getRegularizer(float f) {
            return f;
        }
    }

    /**
     * Blend of the {@link L1} and {@link L2} regularizer terms, weighted by the
     * {@code "l1_ratio"} option.
     */
    public static final class ElasticNet extends Regularization {
        /** Fallback handed to {@code Primitives.parseFloat} for the {@code "l1_ratio"} option. */
        private static final float DEFAULT_L1_RATIO = 0.5f;

        @Nonnull
        private final L1 l1;

        @Nonnull
        private final L2 l2;

        /** Weight of the L1 term; the L2 term is weighted by {@code 1 - l1Ratio}. */
        private final float l1Ratio;

        /**
         * Reads the {@code "l1_ratio"} option in addition to the options of the base class.
         *
         * @param map option name to option value; must not be {@code null}
         * @throws IllegalArgumentException if the ratio is NaN or lies outside {@code [0, 1]}
         */
        public ElasticNet(@Nonnull Map<String, String> map) {
            super(map);
            this.l1 = new L1(map);
            this.l2 = new L2(map);
            this.l1Ratio = Primitives.parseFloat(map.get("l1_ratio"), DEFAULT_L1_RATIO);
            // Written as a negated in-range test so that NaN, for which every ordered
            // comparison is false, is rejected instead of silently accepted.
            if (!(this.l1Ratio >= 0.0f && this.l1Ratio <= 1.0f)) {
                throw new IllegalArgumentException("L1 ratio should be in [0.0, 1.0], but got " + this.l1Ratio);
            }
        }

        /** Returns {@code l1Ratio * L1(f) + (1 - l1Ratio) * L2(f)}. */
        @Override
        public float getRegularizer(float f) {
            final float l1Term = this.l1Ratio * this.l1.getRegularizer(f);
            final float l2Term = (1.0f - this.l1Ratio) * this.l2.getRegularizer(f);
            return l1Term + l2Term;
        }
    }
}
