package com.alibaba.rsocket.loadbalance;

import com.alibaba.rsocket.AbstractRSocket;
import com.alibaba.rsocket.RSocketRequesterSupport;
import com.alibaba.rsocket.cloudevents.CloudEventImpl;
import com.alibaba.rsocket.cloudevents.CloudEventRSocket;
import com.alibaba.rsocket.cloudevents.EventReply;
import com.alibaba.rsocket.events.ServicesExposedEvent;
import com.alibaba.rsocket.health.RSocketServiceHealth;
import com.alibaba.rsocket.listen.RSocketResponderSupport;
import com.alibaba.rsocket.metadata.GSVRoutingMetadata;
import com.alibaba.rsocket.metadata.MessageMimeTypeMetadata;
import com.alibaba.rsocket.metadata.RSocketCompositeMetadata;
import com.alibaba.rsocket.metadata.RSocketMimeType;
import com.alibaba.rsocket.observability.RsocketErrorCode;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.util.ReferenceCountUtil;
import io.rsocket.Payload;
import io.rsocket.RSocket;
import io.rsocket.core.RSocketConnector;
import io.rsocket.exceptions.ConnectionErrorException;
import io.rsocket.plugins.RSocketInterceptor;
import io.rsocket.uri.UriTransportRegistry;
import io.rsocket.util.ByteBufPayload;
import java.net.ConnectException;
import java.net.URI;
import java.nio.channels.ClosedChannelException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import java.util.function.Predicate;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.propertyeditors.StringArrayPropertyEditor;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.util.function.Tuple2;
import reactor.util.function.Tuples;

/* loaded from: input_file:BOOT-INF/lib/alibaba-rsocket-core-1.1.6.jar:com/alibaba/rsocket/loadbalance/LoadBalancedRSocket.class */
public class LoadBalancedRSocket extends AbstractRSocket implements CloudEventRSocket {
    /** Picks the upstream connection for each request; replaced whenever the active socket set changes. */
    private volatile RandomSelector<RSocket> randomSelector;
    /** Service identifier used for the setup payload, the sourcing tag and log messages. */
    private final String serviceId;
    /** Source of upstream URI batches; each emitted batch is passed to {@code refreshRsockets}. */
    private final Flux<Collection<String>> urisFactory;
    /** First URI batch that was accepted by {@code refreshRsockets}; {@code null} until then. */
    private Collection<String> firstBatchUris;
    /** Connections currently in rotation, keyed by URI. Reassigned on refresh, hence volatile. */
    private volatile Map<String, RSocket> activeSockets;
    /** Health check period in seconds. */
    private static final int HEALTH_CHECK_INTERVAL_SECONDS = 15;
    /** Supplies setup payload, interceptors, acceptor and exposed-service information for new connections. */
    private final RSocketRequesterSupport requesterSupport;
    /** Composite metadata sent with every health check request; duplicated (retained) per request. */
    private final ByteBuf healthCheckCompositeByteBuf;
    /** {@code true} when the requester support reports at least one exposed service. */
    private boolean isServiceProvider;
    private static final Logger log = LoggerFactory.getLogger(LoadBalancedRSocket.class);
    /** Matches the error types that are treated as "the connection is gone"; {@code null} never matches. */
    private static final Predicate<? super Throwable> CONNECTION_ERROR_PREDICATE = th ->
            (th instanceof ClosedChannelException) || (th instanceof ConnectionErrorException) || (th instanceof ConnectException);
    /** Most recent URI batch that was accepted by {@code refreshRsockets}. */
    private volatile Collection<String> lastRSocketUris = new ArrayList<>();
    /**
     * URIs that failed to connect or whose connection was closed.
     * A concurrent set is used because timer callbacks and connection callbacks touch it from different threads.
     */
    private final Set<String> unHealthyUriSet = ConcurrentHashMap.newKeySet();
    /** Epoch milliseconds recorded by the health check timer. */
    private long lastHealthCheckTimeStamp = System.currentTimeMillis();
    /** Epoch milliseconds of the last accepted URI refresh. */
    private long lastRefreshTimeStamp = System.currentTimeMillis();
    /** Number of reconnect attempts after a connection error. */
    private final int retryCount = 12;

    /**
     * Exposes the set of URIs currently regarded as unhealthy.
     *
     * @return the internal set itself (not a copy); changes made through it affect this instance
     */
    public Set<String> getUnHealthyUriSet() {
        return unHealthyUriSet;
    }

    /**
     * Exposes the most recent URI batch accepted by a refresh.
     *
     * @return the stored collection reference (not a copy)
     */
    public Collection<String> getLastRSocketUris() {
        return lastRSocketUris;
    }

    /**
     * Reports the timestamp stored by the health check timer.
     *
     * @return epoch milliseconds
     */
    public long getLastHealthCheckTimeStamp() {
        return lastHealthCheckTimeStamp;
    }

    /**
     * Reports when the upstream URI list was last refreshed.
     *
     * @return epoch milliseconds
     */
    public long getLastRefreshTimeStamp() {
        return lastRefreshTimeStamp;
    }

    /**
     * Creates a load-balanced RSocket for one service and starts its background activity.
     * <p>
     * All state, including the health check metadata buffer, is initialised before the URI stream is
     * subscribed: a URI batch emitted during subscription leads to connect + health check, and the health check
     * reads {@code healthCheckCompositeByteBuf}.
     *
     * @param str                     service identifier
     * @param flux                    stream of upstream URI batches
     * @param rSocketRequesterSupport source of setup payload, interceptors, acceptor and exposed services
     */
    public LoadBalancedRSocket(String str, Flux<Collection<String>> flux, RSocketRequesterSupport rSocketRequesterSupport) {
        this.serviceId = str;
        this.urisFactory = flux;
        this.requesterSupport = rSocketRequesterSupport;
        this.isServiceProvider = !rSocketRequesterSupport.exposedServices().get().isEmpty();
        this.randomSelector = new RandomSelector<>(this.serviceId, new ArrayList<RSocket>());
        this.activeSockets = new HashMap<>();
        // Build the health check routing metadata once and keep a private copy of its bytes.
        ByteBuf content = RSocketCompositeMetadata.from(new GSVRoutingMetadata(null, RSocketServiceHealth.class.getCanonicalName(), "check", null), new MessageMimeTypeMetadata(RSocketMimeType.Hessian)).getContent();
        this.healthCheckCompositeByteBuf = Unpooled.copiedBuffer(content);
        ReferenceCountUtil.safeRelease(content);
        // Subscribe only now that every field used by refreshRsockets/healthCheck is assigned.
        this.urisFactory.subscribe(this::refreshRsockets);
        startHealthCheckTimer();
        checkUnhealthyUris();
    }

    /**
     * Applies a new batch of upstream URIs.
     * <p>
     * Does nothing when the batch equals the previous one. Otherwise it records the batch, clears the unhealthy
     * set, connects to every URI that has no active connection yet and, once all attempts have finished,
     * swaps in the new set of active connections.
     *
     * @param collection upstream URIs to use from now on
     */
    private void refreshRsockets(Collection<String> collection) {
        if (isSameWithLastUris(collection)) {
            return;
        }
        if (this.firstBatchUris == null) {
            this.firstBatchUris = collection;
        }
        log.info(RsocketErrorCode.message("RST-300207", this.serviceId, String.join(StringArrayPropertyEditor.DEFAULT_SEPARATOR, collection)));
        this.lastRefreshTimeStamp = System.currentTimeMillis();
        this.lastRSocketUris = collection;
        this.unHealthyUriSet.clear();
        Flux.fromIterable(collection)
                .flatMap(this::connectOrReuse)
                .collectList()
                .subscribe(this::activateSockets);
    }

    /**
     * Returns the already active connection for a URI, or connects and health-checks a new one.
     * A failed attempt is logged, marks the URI unhealthy, schedules reconnection and then completes empty so
     * that one bad URI does not fail the whole batch.
     */
    private Mono<Tuple2<String, RSocket>> connectOrReuse(String uri) {
        // Single lookup: avoids a containsKey/get race that could yield a null socket.
        RSocket existing = this.activeSockets.get(uri);
        if (existing != null) {
            return Mono.just(Tuples.of(uri, existing));
        }
        return connect(uri)
                .flatMap(rSocket -> healthCheck(rSocket, uri).map(ok -> Tuples.of(uri, rSocket)))
                .doOnError(failure -> {
                    log.error(RsocketErrorCode.message("RST-400500", uri), failure);
                    this.unHealthyUriSet.add(uri);
                    tryToReconnect(uri, failure);
                })
                .onErrorResume(failure -> Mono.empty());
    }

    /**
     * Makes the given connections the active set, schedules disposal of connections that dropped out and
     * subscribes to the close signal of newly added ones. An empty result leaves the current set untouched.
     */
    private void activateSockets(Collection<Tuple2<String, RSocket>> connected) {
        if (connected.isEmpty()) {
            return;
        }
        Map<String, RSocket> previous = this.activeSockets;
        Map<String, RSocket> current = new HashMap<>();
        for (Tuple2<String, RSocket> pair : connected) {
            current.put(pair.getT1(), pair.getT2());
        }
        Map<String, RSocket> stale = new HashMap<>();
        for (Map.Entry<String, RSocket> oldEntry : previous.entrySet()) {
            if (!current.containsKey(oldEntry.getKey())) {
                stale.put(oldEntry.getKey(), oldEntry.getValue());
            }
        }
        Map<String, RSocket> added = new HashMap<>();
        for (Map.Entry<String, RSocket> newEntry : current.entrySet()) {
            if (!previous.containsKey(newEntry.getKey())) {
                added.put(newEntry.getKey(), newEntry.getValue());
            }
        }
        this.activeSockets = current;
        this.randomSelector = new RandomSelector<>(this.serviceId, new ArrayList<>(current.values()));
        if (!stale.isEmpty()) {
            // Stale connections are disposed after a delay (45 s for service providers, 15 s otherwise).
            Flux.fromIterable(stale.entrySet()).delaySubscription(Duration.ofSeconds(this.isServiceProvider ? 45 : 15)).subscribe(staleEntry -> {
                log.info(RsocketErrorCode.message("RST-200011", staleEntry.getKey()));
                staleEntry.getValue().dispose();
            });
        }
        for (Map.Entry<String, RSocket> addedEntry : added.entrySet()) {
            // NOTE(review): onClose() is passed an onNext consumer only; if it is a Mono<Void> (as in the RSocket
            // API) it completes without emitting, so this callback may never run — confirm and consider
            // reacting to completion instead.
            addedEntry.getValue().onClose().subscribe(ignored -> {
                onRSocketClosed(addedEntry.getKey(), addedEntry.getValue(), null);
            });
        }
    }

    /**
     * Sends a request/response interaction through one upstream connection chosen by the selector.
     * <p>
     * If that connection fails with an error matched by {@code CONNECTION_ERROR_PREDICATE}, the connection is
     * reported via {@link #onRSocketClosed(RSocket, Throwable)} and the request is issued again through this method.
     * NOTE(review): the retry reuses the same {@code payload} instance; whether it is still usable after the
     * failed attempt depends on the upstream implementation — confirm.
     *
     * @param payload request payload; released here when no connection is available
     * @return the response, or a {@link NoAvailableConnectionException} error when the selector yields nothing
     */
    @Override
    @NotNull
    public Mono<Payload> requestResponse(@NotNull Payload payload) {
        final RSocket selected = this.randomSelector.next();
        if (selected == null) {
            ReferenceCountUtil.safeRelease(payload);
            return Mono.error(new NoAvailableConnectionException(RsocketErrorCode.message("RST-200404", this.serviceId)));
        }
        Mono<Payload> response = selected.requestResponse(payload);
        return response.onErrorResume(CONNECTION_ERROR_PREDICATE, failure -> {
            onRSocketClosed(selected, failure);
            return requestResponse(payload);
        });
    }

    /**
     * Sends a fire-and-forget interaction through one upstream connection chosen by the selector.
     * <p>
     * On an error matched by {@code CONNECTION_ERROR_PREDICATE} the connection is reported via
     * {@link #onRSocketClosed(RSocket, Throwable)} and the call is repeated through this method.
     *
     * @param payload payload to send; released here when no connection is available
     * @return completion signal, or a {@link NoAvailableConnectionException} error when the selector yields nothing
     */
    @Override
    @NotNull
    public Mono<Void> fireAndForget(@NotNull Payload payload) {
        final RSocket selected = this.randomSelector.next();
        if (selected == null) {
            ReferenceCountUtil.safeRelease(payload);
            return Mono.error(new NoAvailableConnectionException(RsocketErrorCode.message("RST-200404", this.serviceId)));
        }
        Mono<Void> sent = selected.fireAndForget(payload);
        return sent.onErrorResume(CONNECTION_ERROR_PREDICATE, failure -> {
            onRSocketClosed(selected, failure);
            return fireAndForget(payload);
        });
    }

    /**
     * Converts a cloud event into a metadata-push payload and pushes it via {@link #metadataPush(Payload)}.
     *
     * @param cloudEventImpl event to publish
     * @return completion of the push; any exception thrown while building the pipeline is returned as an error signal
     */
    @Override
    public Mono<Void> fireCloudEvent(CloudEventImpl<?> cloudEventImpl) {
        Mono<Void> outcome;
        try {
            Payload eventPayload = cloudEventToMetadataPushPayload(cloudEventImpl);
            outcome = metadataPush(eventPayload);
        } catch (Exception failure) {
            outcome = Mono.error(failure);
        }
        return outcome;
    }

    /**
     * Builds the reply payload for the given target and sends it with {@link #fireAndForget(Payload)}.
     *
     * @param uri        reply target
     * @param eventReply reply content
     * @return completion of the fire-and-forget call
     */
    @Override
    public Mono<Void> fireEventReply(URI uri, EventReply eventReply) {
        Payload replyPayload = constructEventReplyPayload(uri, eventReply);
        return fireAndForget(replyPayload);
    }

    /**
     * Pushes a cloud event to every active upstream connection, building a separate payload per connection.
     * Errors emitted by the pipeline are logged and then propagated to the subscriber.
     *
     * @param cloudEventImpl event to publish
     * @return completion once all pushes finished; exceptions thrown while assembling the pipeline are returned
     *         as an error signal
     */
    public Mono<Void> fireCloudEventToUpstreamAll(CloudEventImpl<?> cloudEventImpl) {
        try {
            Collection<RSocket> upstreams = getActiveSockets().values();
            Flux<Void> pushes = Flux.fromIterable(upstreams)
                    .flatMap(upstream -> upstream.metadataPush(cloudEventToMetadataPushPayload(cloudEventImpl)));
            return pushes
                    .doOnError(failure -> log.error(RsocketErrorCode.message("RST-610502", new Object[0]), failure))
                    .then();
        } catch (Exception assemblyFailure) {
            return Mono.error(assemblyFailure);
        }
    }

    /**
     * Opens a request/stream interaction through one upstream connection chosen by the selector.
     * <p>
     * On an error matched by {@code CONNECTION_ERROR_PREDICATE} the connection is reported via
     * {@link #onRSocketClosed(RSocket, Throwable)} and the stream is requested again through this method.
     *
     * @param payload request payload; released here when no connection is available
     * @return the response stream, or a {@link NoAvailableConnectionException} error when the selector yields nothing
     */
    @Override
    @NotNull
    public Flux<Payload> requestStream(@NotNull Payload payload) {
        final RSocket selected = this.randomSelector.next();
        if (selected == null) {
            ReferenceCountUtil.safeRelease(payload);
            return Flux.error(new NoAvailableConnectionException(RsocketErrorCode.message("RST-200404", this.serviceId)));
        }
        Flux<Payload> stream = selected.requestStream(payload);
        return stream.onErrorResume(CONNECTION_ERROR_PREDICATE, failure -> {
            onRSocketClosed(selected, failure);
            return requestStream(payload);
        });
    }

    /**
     * Opens a request/channel interaction through one upstream connection chosen by the selector.
     * <p>
     * On an error matched by {@code CONNECTION_ERROR_PREDICATE} the connection is reported via
     * {@link #onRSocketClosed(RSocket, Throwable)} and the channel is requested again with the same publisher.
     *
     * @param publisher outbound payload stream
     * @return the inbound stream, or a {@link NoAvailableConnectionException} error when the selector yields nothing
     */
    @Override
    @NotNull
    public Flux<Payload> requestChannel(@NotNull Publisher<Payload> publisher) {
        final RSocket selected = this.randomSelector.next();
        if (selected == null) {
            return Flux.error(new NoAvailableConnectionException(RsocketErrorCode.message("RST-200404", this.serviceId)));
        }
        Flux<Payload> channel = selected.requestChannel(publisher);
        return channel.onErrorResume(CONNECTION_ERROR_PREDICATE, failure -> {
            onRSocketClosed(selected, failure);
            return requestChannel(publisher);
        });
    }

    /**
     * Pushes the given metadata payload to every active upstream connection.
     * <p>
     * Each connection receives its own payload created from {@code payload}, so a connection that releases what
     * it was given cannot invalidate the payload for the others. The caller's payload is released exactly once
     * when the returned {@code Mono} terminates or is cancelled, also when there is no active connection.
     *
     * @param payload metadata payload; ownership passes to this method
     * @return completion once every push has completed, or the first error
     */
    @Override
    @NotNull
    public Mono<Void> metadataPush(@NotNull Payload payload) {
        return Flux.fromIterable(this.activeSockets.values())
                .flatMap(upstream -> upstream.metadataPush(ByteBufPayload.create(payload)))
                .doFinally(signal -> ReferenceCountUtil.safeRelease(payload))
                .then();
    }

    /**
     * Disposes this instance and every active upstream connection, then empties the active map.
     * <p>
     * The connections are disposed from a snapshot because disposing one can lead to callbacks that remove
     * entries from {@code activeSockets}. Disposal is best effort: a failure on one connection is logged and
     * the remaining ones are still disposed.
     * NOTE(review): the subscriptions created by the constructor, {@code startHealthCheckTimer()} and
     * {@code checkUnhealthyUris()} are not retained anywhere, so they are not cancelled here.
     */
    @Override
    public void dispose() {
        super.dispose();
        for (RSocket rSocket : new ArrayList<>(this.activeSockets.values())) {
            try {
                if (!rSocket.isDisposed()) {
                    rSocket.dispose();
                }
            } catch (Exception e) {
                log.debug("Failed to dispose an upstream connection of service {}", this.serviceId, e);
            }
        }
        this.activeSockets.clear();
    }

    /**
     * Exposes the connections currently in rotation, keyed by URI.
     *
     * @return the internal map itself (not a copy)
     */
    public Map<String, RSocket> getActiveSockets() {
        return activeSockets;
    }

    /**
     * Renders the URIs of all active connections as one string.
     *
     * @return the URIs joined with {@code StringArrayPropertyEditor.DEFAULT_SEPARATOR}
     */
    public String getActiveUris() {
        Set<String> uris = this.activeSockets.keySet();
        return String.join(StringArrayPropertyEditor.DEFAULT_SEPARATOR, uris);
    }

    /**
     * Makes one immediate reconnect attempt for every URI in the unhealthy set that has no active connection.
     * <p>
     * The previous implementation delegated to {@code tryToReconnect(uri, null)}, which returns without doing
     * anything because {@code CONNECTION_ERROR_PREDICATE} never matches {@code null}. The attempt is therefore
     * made here directly: connect, health-check, then register the connection via
     * {@link #onRSocketReconnected(String, RSocket)}. Failures are logged and the URI stays unhealthy.
     */
    public void refreshUnHealthyUris() {
        // Iterate over a snapshot: a successful reconnect removes the URI from the set.
        for (String uri : new ArrayList<>(this.unHealthyUriSet)) {
            if (this.activeSockets.containsKey(uri)) {
                continue;
            }
            connect(uri)
                    .flatMap(rSocket -> healthCheck(rSocket, uri).map(ok -> rSocket))
                    .subscribe(rSocket -> {
                        onRSocketReconnected(uri, rSocket);
                        log.info(RsocketErrorCode.message("RST-500203", uri));
                    }, failure -> log.error(RsocketErrorCode.message("RST-400500", uri), failure));
        }
    }

    /**
     * Handles the loss of a connection identified by instance.
     * <p>
     * Every URI whose active connection is this instance is forwarded to
     * {@link #onRSocketClosed(String, RSocket, Throwable)} together with the cause, then the connection is
     * disposed if it is not already. Matching URIs are collected first because the forwarded call removes
     * entries from {@code activeSockets}, which must not happen while that map is being iterated.
     *
     * @param rSocket connection that failed or closed
     * @param th      cause, or {@code null} when unknown
     */
    public void onRSocketClosed(RSocket rSocket, @Nullable Throwable th) {
        Collection<String> affectedUris = new ArrayList<>();
        for (Map.Entry<String, RSocket> entry : this.activeSockets.entrySet()) {
            if (entry.getValue() == rSocket) {
                affectedUris.add(entry.getKey());
            }
        }
        for (String uri : affectedUris) {
            // Pass the cause on so that connection errors trigger the reconnect logic.
            onRSocketClosed(uri, rSocket, th);
        }
        if (rSocket.isDisposed()) {
            return;
        }
        try {
            rSocket.dispose();
        } catch (Exception e) {
            log.debug("Failed to dispose a closed upstream connection of service {}", this.serviceId, e);
        }
    }

    /**
     * Handles the loss of the active connection registered for a URI; does nothing when there is none.
     * The map is read once so that a concurrent removal cannot turn the lookup into a {@code null} socket.
     *
     * @param str URI of the connection
     * @param th  cause, or {@code null} when unknown
     */
    public void onRSocketClosed(String str, @Nullable Throwable th) {
        RSocket registered = this.activeSockets.get(str);
        if (registered != null) {
            onRSocketClosed(str, registered, th);
        }
    }

    /**
     * Handles the loss of the connection for a URI.
     * <p>
     * If the URI belongs to the latest URI batch it is marked unhealthy; if it was active it is removed from
     * rotation, the selector is rebuilt and {@link #tryToReconnect(String, Throwable)} is invoked; the
     * connection is disposed if needed. Afterwards, when no active connection is left and the latest batch
     * does not cover the first batch, the first batch is applied again.
     *
     * @param str     URI of the connection
     * @param rSocket the connection
     * @param th      cause, or {@code null} when unknown
     */
    public void onRSocketClosed(String str, RSocket rSocket, @Nullable Throwable th) {
        if (this.lastRSocketUris.contains(str)) {
            this.unHealthyUriSet.add(str);
            if (this.activeSockets.remove(str) != null) {
                this.randomSelector = new RandomSelector<>(this.serviceId, new ArrayList<>(this.activeSockets.values()));
                log.error(RsocketErrorCode.message("RST-500407", str));
                tryToReconnect(str, th);
            }
            if (rSocket != null && !rSocket.isDisposed()) {
                try {
                    rSocket.dispose();
                } catch (Exception e) {
                    log.debug("Failed to dispose the upstream connection {}", str, e);
                }
            }
        }
        // firstBatchUris stays null until the first refresh was accepted; nothing to fall back to then.
        Collection<String> fallbackUris = this.firstBatchUris;
        if (fallbackUris == null || !this.activeSockets.isEmpty() || this.lastRSocketUris.containsAll(fallbackUris)) {
            return;
        }
        refreshRsockets(fallbackUris);
    }

    /**
     * Puts a (re)connected upstream connection into rotation.
     * <p>
     * The connection is stored under its URI, the URI is removed from the unhealthy set, the selector is
     * rebuilt, and, if the requester support supplies a services-exposed event, that event is pushed to the
     * new connection. Failures of that push are logged and do not undo the registration.
     *
     * @param str     URI of the connection
     * @param rSocket the established connection
     */
    public void onRSocketReconnected(String str, RSocket rSocket) {
        this.activeSockets.put(str, rSocket);
        this.unHealthyUriSet.remove(str);
        this.randomSelector = new RandomSelector<>(this.serviceId, new ArrayList<>(this.activeSockets.values()));
        // NOTE(review): only an onNext consumer is registered; if onClose() is a Mono<Void> (as in the RSocket
        // API) it completes without emitting, so this callback may never run — confirm.
        rSocket.onClose().subscribe(ignored -> {
            onRSocketClosed(str, rSocket, null);
        });
        CloudEventImpl<ServicesExposedEvent> cloudEventImpl = this.requesterSupport.servicesExposedEvent().get();
        if (cloudEventImpl != null) {
            try {
                rSocket.metadataPush(cloudEventToMetadataPushPayload(cloudEventImpl)).subscribe(
                        ignored -> { },
                        failure -> log.warn("Failed to push the services-exposed event to {}", str, failure));
            } catch (Exception e) {
                log.warn("Failed to build the services-exposed event for {}", str, e);
            }
        }
    }

    /**
     * Schedules reconnect attempts for a URI after a connection error.
     * <p>
     * Nothing happens unless {@code th} matches {@code CONNECTION_ERROR_PREDICATE} (so a {@code null} cause
     * is ignored). Otherwise up to {@code retryCount} attempts are made, one every 5 seconds; an attempt is
     * skipped when the URI already has an active connection. A successful attempt registers the connection via
     * {@link #onRSocketReconnected(String, RSocket)}; a failed one marks the URI unhealthy and is logged. The
     * failure is handled in the subscriber's error callback so that it is not reported as an unhandled error.
     *
     * @param str URI to reconnect to
     * @param th  error that caused the connection loss, may be {@code null}
     */
    public void tryToReconnect(String str, @Nullable Throwable th) {
        if (!CONNECTION_ERROR_PREDICATE.test(th)) {
            return;
        }
        Flux.range(1, this.retryCount)
                .delayElements(Duration.ofSeconds(5L))
                .filter(attempt -> !this.activeSockets.containsKey(str))
                .subscribe(attempt -> connect(str)
                        .flatMap(rSocket -> healthCheck(rSocket, str).map(ok -> rSocket))
                        .subscribe(rSocket -> {
                            onRSocketReconnected(str, rSocket);
                            log.info(RsocketErrorCode.message("RST-500203", str));
                        }, failure -> {
                            getUnHealthyUriSet().add(str);
                            log.error(RsocketErrorCode.message("RST-500408", attempt, str), failure);
                        }));
    }

    /**
     * Builds a connector from the requester support and connects to the given URI.
     * <p>
     * Requester and responder interceptors are registered, the setup payload and MIME types are set, and the
     * responder produced by the socket acceptor is tagged with a sourcing label when it is a
     * {@link RSocketResponderSupport}.
     *
     * @param str upstream URI
     * @return the connection, or a {@link ConnectionErrorException} error (carrying the original cause) when the
     *         connector could not be assembled
     */
    Mono<RSocket> connect(String str) {
        try {
            RSocketConnector connector = RSocketConnector.create();
            for (RSocketInterceptor requesterInterceptor : this.requesterSupport.requestInterceptors()) {
                connector.interceptors(registry -> {
                    registry.forRequester(requesterInterceptor);
                });
            }
            for (RSocketInterceptor responderInterceptor : this.requesterSupport.responderInterceptors()) {
                connector.interceptors(registry -> {
                    registry.forResponder(responderInterceptor);
                });
            }
            return connector.setupPayload(this.requesterSupport.setupPayload(this.serviceId).get()).metadataMimeType(RSocketMimeType.CompositeMetadata.getType()).dataMimeType(RSocketMimeType.Hessian.getType()).acceptor((connectionSetupPayload, requester) -> {
                // The inner parameter needs its own name: a lambda parameter must not shadow "requester".
                return this.requesterSupport.socketAcceptor().accept(connectionSetupPayload, requester).doOnNext(responder -> {
                    if (responder instanceof RSocketResponderSupport) {
                        ((RSocketResponderSupport) responder).setSourcing(this.serviceId.equals("*") ? "upstream:broker:*" : "upstream::" + this.serviceId);
                    }
                });
            }).connect(UriTransportRegistry.clientForUri(str));
        } catch (Exception e) {
            log.error(RsocketErrorCode.message("RST-400500", str), e);
            return Mono.error(new ConnectionErrorException(str, e));
        }
    }

    /**
     * Starts the periodic health check of all active connections.
     * <p>
     * Every {@code HEALTH_CHECK_INTERVAL_SECONDS} seconds the timestamp is updated and each active connection
     * is health-checked. A check that fails with a connection error removes the connection via
     * {@link #onRSocketClosed(String, RSocket, Throwable)}; other failures are only logged. Each tick works on
     * a snapshot of the active entries because failed checks modify the map.
     */
    public void startHealthCheckTimer() {
        this.lastHealthCheckTimeStamp = System.currentTimeMillis();
        Flux.interval(Duration.ofSeconds(HEALTH_CHECK_INTERVAL_SECONDS))
                .flatMap(tick -> {
                    this.lastHealthCheckTimeStamp = System.currentTimeMillis();
                    Collection<Map.Entry<String, RSocket>> snapshot = new ArrayList<>(this.activeSockets.entrySet());
                    return Flux.fromIterable(snapshot);
                })
                .subscribe(entry -> healthCheck(entry.getValue(), entry.getKey()).subscribe(ok -> { }, failure -> {
                    if (CONNECTION_ERROR_PREDICATE.test(failure)) {
                        onRSocketClosed(entry.getKey(), entry.getValue(), failure);
                    } else {
                        log.warn("Health check of {} failed", entry.getKey(), failure);
                    }
                }), failure -> log.error("Health check timer of service {} terminated", this.serviceId, failure));
    }

    /**
     * Starts the periodic retry of unhealthy URIs.
     * <p>
     * Every 5 minutes, when the unhealthy set is not empty, each unhealthy URI without an active connection
     * gets one connect + health check attempt; on success the connection is registered via
     * {@link #onRSocketReconnected(String, RSocket)}. The set is iterated through a snapshot because a
     * successful reconnect removes its URI from the set. Failed attempts are logged.
     */
    public void checkUnhealthyUris() {
        Flux.interval(Duration.ofMinutes(5L))
                .filter(tick -> !this.unHealthyUriSet.isEmpty())
                .subscribe(tick -> {
                    for (String uri : new ArrayList<>(this.unHealthyUriSet)) {
                        if (this.activeSockets.containsKey(uri)) {
                            continue;
                        }
                        connect(uri)
                                .flatMap(rSocket -> healthCheck(rSocket, uri).map(ok -> rSocket))
                                .subscribe(rSocket -> {
                                    onRSocketReconnected(uri, rSocket);
                                    log.info(RsocketErrorCode.message("RST-500203", uri));
                                }, failure -> log.error(RsocketErrorCode.message("RST-400500", uri), failure));
                    }
                });
    }

    /**
     * Sends a health check request over the given connection and validates the answer.
     * <p>
     * The request has empty data and the pre-built health check metadata. The check passes when the first
     * data byte of the response equals {@code -111} (0x91) — presumably the Hessian encoding of the "serving"
     * status; confirm against the health service. An empty or different answer fails the check, and no answer
     * within 15 seconds results in a timeout error. The response payload is always released.
     *
     * @param rSocket connection to check
     * @param str     URI of the connection, used in the failure message
     * @return {@code true} on success, otherwise an error signal
     */
    private Mono<Boolean> healthCheck(RSocket rSocket, String str) {
        return rSocket.requestResponse(ByteBufPayload.create(Unpooled.EMPTY_BUFFER, this.healthCheckCompositeByteBuf.retainedDuplicate())).timeout(Duration.ofSeconds(15L)).handle((payload, synchronousSink) -> {
            try {
                ByteBuf data = payload.data();
                if (data.isReadable() && data.readByte() == -111) {
                    synchronousSink.next(true);
                } else {
                    synchronousSink.error(new Exception("Health check failed :" + str));
                }
            } finally {
                // The response is not passed on, so it has to be released here.
                ReferenceCountUtil.safeRelease(payload);
            }
        });
    }

    /**
     * Tells whether the given URIs are the same as the last accepted batch: equal size and mutual containment.
     *
     * @param collection candidate URI batch
     * @return {@code true} when both collections have the same size and each contains all elements of the other
     */
    public boolean isSameWithLastUris(Collection<String> collection) {
        if (this.lastRSocketUris.size() != collection.size()) {
            return false;
        }
        if (!this.lastRSocketUris.containsAll(collection)) {
            return false;
        }
        return collection.containsAll(this.lastRSocketUris);
    }
}
