diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelEndpoint.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelEndpoint.java
index fc82c530fc6f..43954b81c7f4 100644
--- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelEndpoint.java
+++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelEndpoint.java
@@ -71,4 +71,22 @@ public interface ChannelEndpoint {
* @return the managed channel for this server
*/
ManagedChannel getChannel();
+
+ /**
+ * Records that an application RPC started on this endpoint.
+ *
+   * <p>This is used for request-load-aware routing decisions. Implementations must keep the count
+ * scoped to this endpoint instance so evicted or recreated endpoints do not share inflight state.
+ */
+ void incrementActiveRequests();
+
+ /**
+ * Records that an application RPC finished on this endpoint.
+ *
+   * <p>Implementations must not allow the count to go negative.
+ */
+ void decrementActiveRequests();
+
+ /** Returns the number of currently active application RPCs on this endpoint. */
+ int getActiveRequestCount();
}
diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelFinder.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelFinder.java
index 6e77ebd2692d..f4fdb41fe3c5 100644
--- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelFinder.java
+++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelFinder.java
@@ -82,7 +82,8 @@ public ChannelFinder(
ChannelEndpointCache endpointCache,
@Nullable EndpointLifecycleManager lifecycleManager,
@Nullable String finderKey) {
- this.rangeCache = new KeyRangeCache(Objects.requireNonNull(endpointCache), lifecycleManager);
+ this.rangeCache =
+ new KeyRangeCache(Objects.requireNonNull(endpointCache), lifecycleManager, finderKey);
this.lifecycleManager = lifecycleManager;
this.finderKey = finderKey;
}
@@ -91,6 +92,11 @@ void useDeterministicRandom() {
rangeCache.useDeterministicRandom();
}
+ @Nullable
+ String finderKey() {
+ return finderKey;
+ }
+
private static ExecutorService createCacheUpdatePool() {
ThreadPoolExecutor executor =
new ThreadPoolExecutor(
diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EndpointLatencyRegistry.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EndpointLatencyRegistry.java
new file mode 100644
index 000000000000..29a4027955f5
--- /dev/null
+++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EndpointLatencyRegistry.java
@@ -0,0 +1,224 @@
+/*
+ * Copyright 2026 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.cloud.spanner.spi.v1;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Ticker;
+import com.google.common.cache.Cache;
+import com.google.common.cache.CacheBuilder;
+import java.time.Duration;
+import java.util.Objects;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+
+/** Shared process-local latency scores for routed Spanner endpoints. */
+final class EndpointLatencyRegistry {
+ private static final String GLOBAL_SCOPE = "__global__";
+
+ static final Duration DEFAULT_ERROR_PENALTY = Duration.ofSeconds(10);
+ static final Duration DEFAULT_RTT = Duration.ofMillis(10);
+ static final double DEFAULT_PENALTY_VALUE = 1_000_000.0;
+ @VisibleForTesting static final Duration TRACKER_EXPIRE_AFTER_ACCESS = Duration.ofMinutes(10);
+ @VisibleForTesting static final long MAX_TRACKERS = 100_000L;
+
+  private static volatile Cache<TrackerKey, LatencyTracker> TRACKERS =
+      newTrackerCache(Ticker.systemTicker());
+
+ private EndpointLatencyRegistry() {}
+
+ static boolean hasScore(
+ @javax.annotation.Nullable String databaseScope,
+ long operationUid,
+ boolean preferLeader,
+ String endpointLabelOrAddress) {
+ TrackerKey trackerKey =
+ trackerKey(databaseScope, operationUid, preferLeader, endpointLabelOrAddress);
+ return trackerKey != null && TRACKERS.getIfPresent(trackerKey) != null;
+ }
+
+ static double getSelectionCost(
+ @javax.annotation.Nullable String databaseScope,
+ long operationUid,
+ boolean preferLeader,
+ String endpointLabelOrAddress) {
+ return getSelectionCost(
+ databaseScope, operationUid, preferLeader, null, endpointLabelOrAddress);
+ }
+
+ static double getSelectionCost(
+ @javax.annotation.Nullable String databaseScope,
+ long operationUid,
+ boolean preferLeader,
+ @javax.annotation.Nullable ChannelEndpoint endpoint,
+ String endpointLabelOrAddress) {
+ TrackerKey trackerKey =
+ trackerKey(databaseScope, operationUid, preferLeader, endpointLabelOrAddress);
+ if (trackerKey == null) {
+ return Double.MAX_VALUE;
+ }
+ double activeRequests = endpoint == null ? 0.0 : endpoint.getActiveRequestCount();
+ LatencyTracker tracker = TRACKERS.getIfPresent(trackerKey);
+ if (tracker != null) {
+ return tracker.getScore() * (activeRequests + 1.0);
+ }
+ if (activeRequests > 0.0) {
+ return DEFAULT_PENALTY_VALUE + activeRequests;
+ }
+ return defaultRttMicros() * (activeRequests + 1.0);
+ }
+
+ static void recordLatency(
+ @javax.annotation.Nullable String databaseScope,
+ long operationUid,
+ boolean preferLeader,
+ String endpointLabelOrAddress,
+ Duration latency) {
+ TrackerKey trackerKey =
+ trackerKey(databaseScope, operationUid, preferLeader, endpointLabelOrAddress);
+ if (trackerKey == null || latency == null) {
+ return;
+ }
+ getOrCreateTracker(trackerKey).update(latency);
+ }
+
+ static void recordError(
+ @javax.annotation.Nullable String databaseScope,
+ long operationUid,
+ boolean preferLeader,
+ String endpointLabelOrAddress) {
+ recordError(
+ databaseScope, operationUid, preferLeader, endpointLabelOrAddress, DEFAULT_ERROR_PENALTY);
+ }
+
+ static void recordError(
+ @javax.annotation.Nullable String databaseScope,
+ long operationUid,
+ boolean preferLeader,
+ String endpointLabelOrAddress,
+ Duration penalty) {
+ TrackerKey trackerKey =
+ trackerKey(databaseScope, operationUid, preferLeader, endpointLabelOrAddress);
+ if (trackerKey == null || penalty == null) {
+ return;
+ }
+ getOrCreateTracker(trackerKey).recordError(penalty);
+ }
+
+ @VisibleForTesting
+ static void clear() {
+ TRACKERS.invalidateAll();
+ }
+
+ @VisibleForTesting
+ static void useTrackerTicker(Ticker ticker) {
+ TRACKERS = newTrackerCache(ticker);
+ }
+
+ @VisibleForTesting
+ static String normalizeAddress(String endpointLabelOrAddress) {
+ if (endpointLabelOrAddress == null || endpointLabelOrAddress.isEmpty()) {
+ return null;
+ }
+ return endpointLabelOrAddress;
+ }
+
+ @VisibleForTesting
+ static TrackerKey trackerKey(
+ @javax.annotation.Nullable String databaseScope,
+ long operationUid,
+ String endpointLabelOrAddress) {
+ return trackerKey(databaseScope, operationUid, false, endpointLabelOrAddress);
+ }
+
+ @VisibleForTesting
+ static TrackerKey trackerKey(
+ @javax.annotation.Nullable String databaseScope,
+ long operationUid,
+ boolean preferLeader,
+ String endpointLabelOrAddress) {
+ String address = normalizeAddress(endpointLabelOrAddress);
+ if (operationUid <= 0 || address == null) {
+ return null;
+ }
+ return new TrackerKey(normalizeScope(databaseScope), operationUid, preferLeader, address);
+ }
+
+ private static long defaultRttMicros() {
+ return DEFAULT_RTT.toNanos() / 1_000L;
+ }
+
+ private static String normalizeScope(@javax.annotation.Nullable String databaseScope) {
+ return (databaseScope == null || databaseScope.isEmpty()) ? GLOBAL_SCOPE : databaseScope;
+ }
+
+ private static LatencyTracker getOrCreateTracker(TrackerKey trackerKey) {
+ try {
+ return TRACKERS.get(trackerKey, EwmaLatencyTracker::new);
+ } catch (ExecutionException e) {
+ throw new IllegalStateException("Failed to create latency tracker", e);
+ }
+ }
+
+  private static Cache<TrackerKey, LatencyTracker> newTrackerCache(Ticker ticker) {
+ return CacheBuilder.newBuilder()
+ .maximumSize(MAX_TRACKERS)
+ .expireAfterAccess(TRACKER_EXPIRE_AFTER_ACCESS.toNanos(), TimeUnit.NANOSECONDS)
+ .ticker(ticker)
+ .build();
+ }
+
+ @VisibleForTesting
+ static final class TrackerKey {
+ private final String databaseScope;
+ private final long operationUid;
+ private final boolean preferLeader;
+ private final String address;
+
+ private TrackerKey(
+ String databaseScope, long operationUid, boolean preferLeader, String address) {
+ this.databaseScope = databaseScope;
+ this.operationUid = operationUid;
+ this.preferLeader = preferLeader;
+ this.address = address;
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (this == other) {
+ return true;
+ }
+ if (!(other instanceof TrackerKey)) {
+ return false;
+ }
+ TrackerKey that = (TrackerKey) other;
+ return operationUid == that.operationUid
+ && preferLeader == that.preferLeader
+ && Objects.equals(databaseScope, that.databaseScope)
+ && Objects.equals(address, that.address);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(databaseScope, operationUid, preferLeader, address);
+ }
+
+ @Override
+ public String toString() {
+ return databaseScope + ":" + operationUid + ":" + preferLeader + "@" + address;
+ }
+ }
+}
diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EndpointLifecycleManager.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EndpointLifecycleManager.java
index ae78f07b14a3..91407d46a3a0 100644
--- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EndpointLifecycleManager.java
+++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EndpointLifecycleManager.java
@@ -70,8 +70,9 @@ class EndpointLifecycleManager {
private static final long EVICTION_CHECK_INTERVAL_SECONDS = 300;
/**
- * Maximum consecutive TRANSIENT_FAILURE probes before evicting an endpoint. Gives the channel
- * time to recover from transient network issues before we tear it down and recreate.
+ * Maximum observed TRANSIENT_FAILURE probes before evicting an endpoint. The counter resets only
+ * after the channel reaches READY, so CONNECTING/IDLE oscillation does not hide a persistently
+ * unhealthy endpoint.
*/
private static final int MAX_TRANSIENT_FAILURE_COUNT = 3;
@@ -493,7 +494,8 @@ private void stopProbing(String address) {
* All exceptions are caught to prevent {@link ScheduledExecutorService} from cancelling future
* runs of this task.
*/
- private void probe(String address) {
+ @VisibleForTesting
+ void probe(String address) {
try {
if (isShutdown.get()) {
return;
@@ -530,25 +532,24 @@ private void probe(String address) {
logger.log(
Level.FINE, "Probe for {0}: channel IDLE, requesting connection (warmup)", address);
channel.getState(true);
- state.consecutiveTransientFailures = 0;
break;
case CONNECTING:
- state.consecutiveTransientFailures = 0;
break;
case TRANSIENT_FAILURE:
state.consecutiveTransientFailures++;
logger.log(
Level.FINE,
- "Probe for {0}: channel in TRANSIENT_FAILURE ({1}/{2})",
+ "Probe for {0}: channel in TRANSIENT_FAILURE ({1}/{2} observed failures since last"
+ + " READY)",
new Object[] {
address, state.consecutiveTransientFailures, MAX_TRANSIENT_FAILURE_COUNT
});
if (state.consecutiveTransientFailures >= MAX_TRANSIENT_FAILURE_COUNT) {
logger.log(
Level.FINE,
- "Evicting endpoint {0}: {1} consecutive TRANSIENT_FAILURE probes",
+ "Evicting endpoint {0}: {1} TRANSIENT_FAILURE probes without reaching READY",
new Object[] {address, state.consecutiveTransientFailures});
evictEndpoint(address, EvictionReason.TRANSIENT_FAILURE);
}
diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EndpointOverloadCooldownTracker.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EndpointOverloadCooldownTracker.java
new file mode 100644
index 000000000000..de385899d38d
--- /dev/null
+++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EndpointOverloadCooldownTracker.java
@@ -0,0 +1,154 @@
+/*
+ * Copyright 2026 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.cloud.spanner.spi.v1;
+
+import com.google.common.annotations.VisibleForTesting;
+import java.time.Clock;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.function.LongUnaryOperator;
+
+/**
+ * Tracks short-lived endpoint cooldowns after routed {@code RESOURCE_EXHAUSTED} failures.
+ *
+ * <p>This allows later requests to try a different replica instead of immediately routing back to
+ * the same overloaded endpoint.
+ */
+final class EndpointOverloadCooldownTracker {
+
+ @VisibleForTesting static final Duration DEFAULT_INITIAL_COOLDOWN = Duration.ofSeconds(10);
+ @VisibleForTesting static final Duration DEFAULT_MAX_COOLDOWN = Duration.ofMinutes(1);
+ @VisibleForTesting static final Duration DEFAULT_RESET_AFTER = Duration.ofMinutes(10);
+
+ @VisibleForTesting
+ static final class CooldownState {
+ private final int consecutiveFailures;
+ private final Instant cooldownUntil;
+ private final Instant lastFailureAt;
+
+ private CooldownState(int consecutiveFailures, Instant cooldownUntil, Instant lastFailureAt) {
+ this.consecutiveFailures = consecutiveFailures;
+ this.cooldownUntil = cooldownUntil;
+ this.lastFailureAt = lastFailureAt;
+ }
+ }
+
+  private final ConcurrentHashMap<String, CooldownState> entries = new ConcurrentHashMap<>();
+ private final Duration initialCooldown;
+ private final Duration maxCooldown;
+ private final Duration resetAfter;
+ private final Clock clock;
+ private final LongUnaryOperator randomLong;
+
+ EndpointOverloadCooldownTracker() {
+ this(
+ DEFAULT_INITIAL_COOLDOWN,
+ DEFAULT_MAX_COOLDOWN,
+ DEFAULT_RESET_AFTER,
+ Clock.systemUTC(),
+ bound -> ThreadLocalRandom.current().nextLong(bound));
+ }
+
+ @VisibleForTesting
+ EndpointOverloadCooldownTracker(
+ Duration initialCooldown,
+ Duration maxCooldown,
+ Duration resetAfter,
+ Clock clock,
+ LongUnaryOperator randomLong) {
+ Duration resolvedInitial =
+ (initialCooldown == null || initialCooldown.isZero() || initialCooldown.isNegative())
+ ? DEFAULT_INITIAL_COOLDOWN
+ : initialCooldown;
+ Duration resolvedMax =
+ (maxCooldown == null || maxCooldown.isZero() || maxCooldown.isNegative())
+ ? DEFAULT_MAX_COOLDOWN
+ : maxCooldown;
+ if (resolvedMax.compareTo(resolvedInitial) < 0) {
+ resolvedMax = resolvedInitial;
+ }
+ this.initialCooldown = resolvedInitial;
+ this.maxCooldown = resolvedMax;
+ this.resetAfter =
+ (resetAfter == null || resetAfter.isZero() || resetAfter.isNegative())
+ ? DEFAULT_RESET_AFTER
+ : resetAfter;
+ this.clock = clock == null ? Clock.systemUTC() : clock;
+ this.randomLong =
+ randomLong == null ? bound -> ThreadLocalRandom.current().nextLong(bound) : randomLong;
+ }
+
+ boolean isCoolingDown(String address) {
+ if (address == null || address.isEmpty()) {
+ return false;
+ }
+ Instant now = clock.instant();
+ CooldownState state = entries.get(address);
+ if (state == null) {
+ return false;
+ }
+ if (state.cooldownUntil.isAfter(now)) {
+ return true;
+ }
+ if (Duration.between(state.lastFailureAt, now).compareTo(resetAfter) < 0) {
+ return false;
+ }
+ entries.remove(address, state);
+ CooldownState current = entries.get(address);
+ return current != null && current.cooldownUntil.isAfter(now);
+ }
+
+ void recordFailure(String address) {
+ if (address == null || address.isEmpty()) {
+ return;
+ }
+ Instant now = clock.instant();
+ entries.compute(
+ address,
+ (ignored, state) -> {
+ int consecutiveFailures = 1;
+ if (state != null
+ && Duration.between(state.lastFailureAt, now).compareTo(resetAfter) < 0) {
+ consecutiveFailures = state.consecutiveFailures + 1;
+ }
+ Duration cooldown = cooldownForFailures(consecutiveFailures);
+ return new CooldownState(consecutiveFailures, now.plus(cooldown), now);
+ });
+ }
+
+ private Duration cooldownForFailures(int failures) {
+ Duration cooldown = initialCooldown;
+ for (int i = 1; i < failures; i++) {
+ if (cooldown.compareTo(maxCooldown.dividedBy(2)) > 0) {
+ cooldown = maxCooldown;
+ break;
+ }
+ cooldown = cooldown.multipliedBy(2);
+ }
+ long cooldownMillis = Math.max(1L, cooldown.toMillis());
+ long floorMillis = Math.max(1L, cooldownMillis / 2L);
+ long rangeSize = Math.max(1L, cooldownMillis - floorMillis + 1L);
+ return Duration.ofMillis(floorMillis + randomLong.applyAsLong(rangeSize));
+ }
+
+ @VisibleForTesting
+ CooldownState getState(String address) {
+ return entries.get(address);
+ }
+}
diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EwmaLatencyTracker.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EwmaLatencyTracker.java
index 0cb2331660f9..764a784cbf53 100644
--- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EwmaLatencyTracker.java
+++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EwmaLatencyTracker.java
@@ -18,25 +18,32 @@
import com.google.api.core.BetaApi;
import com.google.api.core.InternalApi;
+import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import java.time.Duration;
import java.util.concurrent.TimeUnit;
+import java.util.function.LongSupplier;
+import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;
/**
* Implementation of {@link LatencyTracker} using Exponentially Weighted Moving Average (EWMA).
*
- * Formula: $S_{i+1} = \alpha * new\_latency + (1 - \alpha) * S_i$
+ * <p>By default, this tracker uses a time-decayed EWMA: $S_{i+1} = \alpha(\Delta t) * new\_latency
+ * + (1 - \alpha(\Delta t)) * S_i$, where $\alpha(\Delta t) = 1 - e^{-\Delta t / \tau}$.
*
- * <p>This class is thread-safe.
+ * <p>A fixed-alpha constructor is retained for focused tests.
*/
@InternalApi
@BetaApi
public class EwmaLatencyTracker implements LatencyTracker {
public static final double DEFAULT_ALPHA = 0.05;
+ public static final Duration DEFAULT_DECAY_TIME = Duration.ofSeconds(10);
- private final double alpha;
+ @Nullable private final Double fixedAlpha;
+ private final long tauNanos;
+ private final LongSupplier nanoTimeSupplier;
private final Object lock = new Object();
@GuardedBy("lock")
@@ -45,9 +52,12 @@ public class EwmaLatencyTracker implements LatencyTracker {
@GuardedBy("lock")
private boolean initialized = false;
- /** Creates a new tracker with the default alpha value of 0.05. */
+ @GuardedBy("lock")
+ private long lastUpdatedAtNanos;
+
+ /** Creates a new tracker with Envoy-style time-based decay and a 10-second decay window. */
public EwmaLatencyTracker() {
- this(DEFAULT_ALPHA);
+ this(DEFAULT_DECAY_TIME, System::nanoTime);
}
/**
@@ -56,8 +66,25 @@ public EwmaLatencyTracker() {
* @param alpha the smoothing factor, must be in the range (0, 1]
*/
public EwmaLatencyTracker(double alpha) {
+ this(alpha, System::nanoTime);
+ }
+
+ @VisibleForTesting
+ EwmaLatencyTracker(Duration decayTime, LongSupplier nanoTimeSupplier) {
+ Preconditions.checkArgument(
+ decayTime != null && !decayTime.isZero() && !decayTime.isNegative(),
+ "decayTime must be > 0");
+ this.fixedAlpha = null;
+ this.tauNanos = decayTime.toNanos();
+ this.nanoTimeSupplier = nanoTimeSupplier;
+ }
+
+ @VisibleForTesting
+ EwmaLatencyTracker(double alpha, LongSupplier nanoTimeSupplier) {
Preconditions.checkArgument(alpha > 0.0 && alpha <= 1.0, "alpha must be in (0, 1]");
- this.alpha = alpha;
+ this.fixedAlpha = alpha;
+ this.tauNanos = 0L;
+ this.nanoTimeSupplier = nanoTimeSupplier;
}
@Override
@@ -77,12 +104,16 @@ public void update(Duration latency) {
// Use Long.MAX_VALUE to give it the lowest possible priority.
latencyMicros = Long.MAX_VALUE;
}
+ long nowNanos = nanoTimeSupplier.getAsLong();
synchronized (lock) {
if (!initialized) {
score = latencyMicros;
initialized = true;
+ lastUpdatedAtNanos = nowNanos;
} else {
+ double alpha = fixedAlpha != null ? fixedAlpha : calculateTimeBasedAlpha(nowNanos);
score = alpha * latencyMicros + (1 - alpha) * score;
+ lastUpdatedAtNanos = nowNanos;
}
}
}
@@ -92,4 +123,14 @@ public void recordError(Duration penalty) {
// Treat the error as a sample with high latency (penalty)
update(penalty);
}
+
+ private double calculateTimeBasedAlpha(long nowNanos) {
+ long deltaNanos = nowNanos - lastUpdatedAtNanos;
+ if (deltaNanos <= 0L) {
+ // Concurrent or future samples get full weight
+ return 1.0;
+ }
+ double alpha = 1.0 - Math.exp(-(double) deltaNanos / (double) tauNanos);
+ return Math.min(1.0, Math.max(0.0, alpha));
+ }
}
diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java
index 7c1b6be1c1bd..b4725f2e8429 100644
--- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java
+++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java
@@ -369,7 +369,11 @@ public GapicSpannerRpc(final SpannerOptions options) {
GrpcTransportOptions.setUpCredentialsProvider(options);
InstantiatingGrpcChannelProvider.Builder defaultChannelProviderBuilder =
- createChannelProviderBuilder(options, headerProviderWithUserAgent, isEnableDirectAccess);
+ createBaseChannelProviderBuilder(
+ options, headerProviderWithUserAgent, isEnableDirectAccess);
+ GrpcGcpEndpointChannelConfigurator endpointChannelConfigurator =
+ createGrpcGcpEndpointChannelConfigurator(defaultChannelProviderBuilder, options);
+ maybeEnableGrpcGcpExtension(defaultChannelProviderBuilder, options);
if (options.getChannelProvider() == null
&& isEnableDirectAccess
@@ -391,7 +395,8 @@ public GapicSpannerRpc(final SpannerOptions options) {
enableLocationApi && baseChannelProvider instanceof InstantiatingGrpcChannelProvider
? new KeyAwareTransportChannelProvider(
(InstantiatingGrpcChannelProvider) baseChannelProvider,
- options.getChannelEndpointCacheFactory())
+ options.getChannelEndpointCacheFactory(),
+ endpointChannelConfigurator)
: baseChannelProvider;
spannerWatchdog =
@@ -431,12 +436,26 @@ public GapicSpannerRpc(final SpannerOptions options) {
&& isEnableDirectAccess;
this.readRetrySettings =
options.getSpannerStubSettings().streamingReadSettings().getRetrySettings();
- this.readRetryableCodes =
+    Set<Code> streamingReadRetryableCodes =
options.getSpannerStubSettings().streamingReadSettings().getRetryableCodes();
+ this.readRetryableCodes =
+ enableLocationApi
+            ? ImmutableSet.<Code>builder()
+ .addAll(streamingReadRetryableCodes)
+ .add(Code.RESOURCE_EXHAUSTED)
+ .build()
+ : streamingReadRetryableCodes;
this.executeQueryRetrySettings =
options.getSpannerStubSettings().executeStreamingSqlSettings().getRetrySettings();
- this.executeQueryRetryableCodes =
+    Set<Code> executeStreamingSqlRetryableCodes =
options.getSpannerStubSettings().executeStreamingSqlSettings().getRetryableCodes();
+ this.executeQueryRetryableCodes =
+ enableLocationApi
+            ? ImmutableSet.<Code>builder()
+ .addAll(executeStreamingSqlRetryableCodes)
+ .add(Code.RESOURCE_EXHAUSTED)
+ .build()
+ : executeStreamingSqlRetryableCodes;
this.commitRetrySettings =
options.getSpannerStubSettings().commitSettings().getRetrySettings();
partitionedDmlRetrySettings =
@@ -725,6 +744,17 @@ private InstantiatingGrpcChannelProvider.Builder createChannelProviderBuilder(
final SpannerOptions options,
final HeaderProvider headerProviderWithUserAgent,
boolean isEnableDirectAccess) {
+ InstantiatingGrpcChannelProvider.Builder defaultChannelProviderBuilder =
+ createBaseChannelProviderBuilder(
+ options, headerProviderWithUserAgent, isEnableDirectAccess);
+ maybeEnableGrpcGcpExtension(defaultChannelProviderBuilder, options);
+ return defaultChannelProviderBuilder;
+ }
+
+ private InstantiatingGrpcChannelProvider.Builder createBaseChannelProviderBuilder(
+ final SpannerOptions options,
+ final HeaderProvider headerProviderWithUserAgent,
+ boolean isEnableDirectAccess) {
InstantiatingGrpcChannelProvider.Builder defaultChannelProviderBuilder =
InstantiatingGrpcChannelProvider.newBuilder()
.setChannelConfigurator(options.getChannelConfigurator())
@@ -770,8 +800,6 @@ private InstantiatingGrpcChannelProvider.Builder createChannelProviderBuilder(
defaultChannelProviderBuilder.setExecutor(executor);
}
}
- // If it is enabled in options uses the channel pool provided by the gRPC-GCP extension.
- maybeEnableGrpcGcpExtension(defaultChannelProviderBuilder, options);
return defaultChannelProviderBuilder;
}
@@ -827,6 +855,36 @@ static GcpChannelPoolOptions getGrpcGcpChannelPoolOptions(SpannerOptions options
.build();
}
+ @VisibleForTesting
+ static GcpChannelPoolOptions getGrpcGcpEndpointChannelPoolOptions(SpannerOptions options) {
+ GcpChannelPoolOptions channelPoolOptions = options.getGcpChannelPoolOptions();
+ return GcpChannelPoolOptions.newBuilder()
+ .setMaxSize(1)
+ .setMinSize(1)
+ .setInitSize(1)
+ .disableDynamicScaling()
+ .setAffinityKeyLifetime(channelPoolOptions.getAffinityKeyLifetime())
+ .setCleanupInterval(channelPoolOptions.getCleanupInterval())
+ .build();
+ }
+
+ @Nullable
+ private static GrpcGcpEndpointChannelConfigurator createGrpcGcpEndpointChannelConfigurator(
+ InstantiatingGrpcChannelProvider.Builder channelProviderBuilder, SpannerOptions options) {
+ if (!options.isGrpcGcpExtensionEnabled()) {
+ return null;
+ }
+
+ GcpManagedChannelOptions endpointGrpcGcpOptions =
+ GcpManagedChannelOptions.newBuilder(grpcGcpOptionsWithMetricsAndDcp(options))
+ .withChannelPoolOptions(getGrpcGcpEndpointChannelPoolOptions(options))
+ .build();
+ return new GrpcGcpEndpointChannelConfigurator(
+ channelProviderBuilder.getChannelConfigurator(),
+ parseGrpcGcpApiConfig(),
+ endpointGrpcGcpOptions);
+ }
+
@SuppressWarnings("rawtypes")
private static void maybeEnableGrpcGcpExtension(
InstantiatingGrpcChannelProvider.Builder defaultChannelProviderBuilder,
diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GrpcChannelEndpointCache.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GrpcChannelEndpointCache.java
index 98e7f83b094f..59415d94c37d 100644
--- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GrpcChannelEndpointCache.java
+++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GrpcChannelEndpointCache.java
@@ -17,6 +17,7 @@
package com.google.cloud.spanner.spi.v1;
import com.google.api.core.InternalApi;
+import com.google.api.gax.grpc.ChannelPoolSettings;
import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider;
import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider.Builder;
import com.google.cloud.spanner.ErrorCode;
@@ -26,11 +27,13 @@
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import java.io.IOException;
+import java.time.Duration;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.Nullable;
@@ -50,10 +53,14 @@ class GrpcChannelEndpointCache implements ChannelEndpointCache {
/** Timeout for graceful channel shutdown. */
private static final long SHUTDOWN_TIMEOUT_SECONDS = 5;
+ @VisibleForTesting static final Duration ROUTED_KEEPALIVE_TIME = Duration.ofSeconds(2);
+ @VisibleForTesting static final Duration ROUTED_KEEPALIVE_TIMEOUT = Duration.ofSeconds(20);
+
private final InstantiatingGrpcChannelProvider baseProvider;
   private final Map<String, GrpcChannelEndpoint> servers = new ConcurrentHashMap<>();
private final GrpcChannelEndpoint defaultEndpoint;
private final String defaultAuthority;
+ @Nullable private final GrpcGcpEndpointChannelConfigurator endpointChannelConfigurator;
private final AtomicBoolean isShutdown = new AtomicBoolean(false);
/**
@@ -65,7 +72,15 @@ class GrpcChannelEndpointCache implements ChannelEndpointCache {
*/
public GrpcChannelEndpointCache(InstantiatingGrpcChannelProvider channelProvider)
throws IOException {
+ this(channelProvider, null);
+ }
+
+ public GrpcChannelEndpointCache(
+ InstantiatingGrpcChannelProvider channelProvider,
+ @Nullable GrpcGcpEndpointChannelConfigurator endpointChannelConfigurator)
+ throws IOException {
this.baseProvider = channelProvider;
+ this.endpointChannelConfigurator = endpointChannelConfigurator;
String defaultEndpoint = channelProvider.getEndpoint();
this.defaultEndpoint = new GrpcChannelEndpoint(defaultEndpoint, channelProvider);
this.defaultAuthority = this.defaultEndpoint.getChannel().authority();
@@ -110,19 +125,27 @@ public ChannelEndpoint getIfPresent(String address) {
return servers.get(address);
}
- private InstantiatingGrpcChannelProvider createProviderWithAuthorityOverride(String address) {
+ @VisibleForTesting
+ InstantiatingGrpcChannelProvider createProviderWithAuthorityOverride(String address) {
InstantiatingGrpcChannelProvider endpointProvider =
(InstantiatingGrpcChannelProvider) baseProvider.withEndpoint(address);
if (Objects.equals(defaultAuthority, address)) {
return endpointProvider;
}
Builder builder = endpointProvider.toBuilder();
+ builder.setChannelPoolSettings(ChannelPoolSettings.staticallySized(1));
+ builder.setKeepAliveTimeDuration(ROUTED_KEEPALIVE_TIME);
+ builder.setKeepAliveTimeoutDuration(ROUTED_KEEPALIVE_TIMEOUT);
+ builder.setKeepAliveWithoutCalls(Boolean.TRUE);
     final com.google.api.core.ApiFunction<ManagedChannelBuilder, ManagedChannelBuilder>
- baseConfigurator = builder.getChannelConfigurator();
+ baseConfigurator =
+ endpointChannelConfigurator == null ? builder.getChannelConfigurator() : null;
builder.setChannelConfigurator(
channelBuilder -> {
ManagedChannelBuilder effectiveBuilder = channelBuilder;
- if (baseConfigurator != null) {
+ if (endpointChannelConfigurator != null) {
+ effectiveBuilder = endpointChannelConfigurator.configure(effectiveBuilder);
+ } else if (baseConfigurator != null) {
effectiveBuilder = baseConfigurator.apply(effectiveBuilder);
}
return effectiveBuilder.overrideAuthority(defaultAuthority);
@@ -182,6 +205,7 @@ private void shutdownChannel(GrpcChannelEndpoint server, boolean awaitTerminatio
static class GrpcChannelEndpoint implements ChannelEndpoint {
private final String address;
private final ManagedChannel channel;
+ private final AtomicInteger activeRequests = new AtomicInteger();
/**
* Creates a server from a channel provider.
@@ -267,5 +291,20 @@ public boolean isTransientFailure() {
public ManagedChannel getChannel() {
return channel;
}
+
+ @Override
+ public void incrementActiveRequests() {
+ activeRequests.incrementAndGet();
+ }
+
+ @Override
+ public void decrementActiveRequests() {
+ activeRequests.updateAndGet(current -> current > 0 ? current - 1 : 0);
+ }
+
+ @Override
+ public int getActiveRequestCount() {
+ return Math.max(0, activeRequests.get());
+ }
}
}
diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GrpcGcpEndpointChannelConfigurator.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GrpcGcpEndpointChannelConfigurator.java
new file mode 100644
index 000000000000..67b9cb3f0140
--- /dev/null
+++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GrpcGcpEndpointChannelConfigurator.java
@@ -0,0 +1,54 @@
+/*
+ * Copyright 2026 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.cloud.spanner.spi.v1;
+
+import com.google.api.core.ApiFunction;
+import com.google.cloud.grpc.GcpManagedChannelBuilder;
+import com.google.cloud.grpc.GcpManagedChannelOptions;
+import io.grpc.ManagedChannelBuilder;
+import javax.annotation.Nullable;
+
+/**
+ * Rebuilds the grpc-gcp wrapper for routed endpoint channels while preserving the base channel
+ * configuration.
+ */
+final class GrpcGcpEndpointChannelConfigurator {
+ @Nullable
+ private final ApiFunction<ManagedChannelBuilder, ManagedChannelBuilder> baseConfigurator;
+
+ private final String apiConfigJson;
+ private final GcpManagedChannelOptions grpcGcpOptions;
+
+ GrpcGcpEndpointChannelConfigurator(
+ @Nullable ApiFunction<ManagedChannelBuilder, ManagedChannelBuilder> baseConfigurator,
+ String apiConfigJson,
+ GcpManagedChannelOptions grpcGcpOptions) {
+ this.baseConfigurator = baseConfigurator;
+ this.apiConfigJson = apiConfigJson;
+ this.grpcGcpOptions = grpcGcpOptions;
+ }
+
+ ManagedChannelBuilder configure(ManagedChannelBuilder channelBuilder) {
+ ManagedChannelBuilder effectiveBuilder = channelBuilder;
+ if (baseConfigurator != null) {
+ effectiveBuilder = baseConfigurator.apply(effectiveBuilder);
+ }
+ return GcpManagedChannelBuilder.forDelegateBuilder(effectiveBuilder)
+ .withApiConfigJsonString(apiConfigJson)
+ .withOptions(grpcGcpOptions);
+ }
+}
diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/HeaderInterceptor.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/HeaderInterceptor.java
index 861e839a0366..638804b1633a 100644
--- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/HeaderInterceptor.java
+++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/HeaderInterceptor.java
@@ -42,9 +42,11 @@
import io.opentelemetry.api.common.Attributes;
import io.opentelemetry.api.common.AttributesBuilder;
import io.opentelemetry.api.trace.Span;
+import java.time.Duration;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ExecutionException;
+import java.util.concurrent.atomic.AtomicBoolean;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.regex.Matcher;
@@ -104,6 +106,8 @@ public ClientCall interceptCall(
public void start(Listener<RespT> responseListener, Metadata headers) {
try {
Span span = Span.current();
+ long startedAtNanos = System.nanoTime();
+ AtomicBoolean firstResponseRecorded = new AtomicBoolean(false);
DatabaseName databaseName = extractDatabaseName(headers);
String key = extractKey(databaseName, method.getFullMethodName());
String requestId = extractRequestId(headers);
@@ -115,6 +119,7 @@ public void start(Listener responseListener, Metadata headers) {
new SimpleForwardingClientCallListener<RespT>(responseListener) {
@Override
public void onHeaders(Metadata metadata) {
+ recordFirstResponseLatency(requestId, startedAtNanos, firstResponseRecorded);
String serverTiming = metadata.get(SERVER_TIMING_HEADER_KEY);
try {
// Get gfe and afe Latency value
@@ -137,17 +142,22 @@ public void onClose(Status status, Metadata trailers) {
recordCustomMetrics(tagContext, attributes, isDirectPathUsed);
Map<String, String> builtInMetricsAttributes = new HashMap<>();
try {
- builtInMetricsAttributes = getBuiltInMetricAttributes(key, databaseName);
+ builtInMetricsAttributes =
+ new HashMap<>(getBuiltInMetricAttributes(key, databaseName));
} catch (ExecutionException e) {
LOGGER.log(
LEVEL, "Unable to get built-in metric attributes {}", e.getMessage());
}
+ if (status.isOk()) {
+ recordFirstResponseLatency(requestId, startedAtNanos, firstResponseRecorded);
+ }
recordBuiltInMetrics(
compositeTracer,
builtInMetricsAttributes,
requestId,
isDirectPathUsed,
isAfeEnabled);
+ RequestIdTargetTracker.remove(requestId);
super.onClose(status, trailers);
}
},
@@ -208,6 +218,27 @@ private void recordBuiltInMetrics(
}
}
+ private void recordFirstResponseLatency(
+ String requestId, long startedAtNanos, AtomicBoolean firstResponseRecorded) {
+ if (!firstResponseRecorded.compareAndSet(false, true)) {
+ return;
+ }
+ RequestIdTargetTracker.RoutingTarget routingTarget = RequestIdTargetTracker.get(requestId);
+ if (routingTarget == null
+ || routingTarget.operationUid <= 0
+ || routingTarget.targetEndpoint == null
+ || routingTarget.targetEndpoint.isEmpty()) {
+ return;
+ }
+ long latencyNanos = Math.max(0L, System.nanoTime() - startedAtNanos);
+ EndpointLatencyRegistry.recordLatency(
+ routingTarget.databaseScope,
+ routingTarget.operationUid,
+ routingTarget.preferLeader,
+ routingTarget.targetEndpoint,
+ Duration.ofNanos(latencyNanos));
+ }
+
private Map parseServerTimingHeader(String serverTiming) {
Map serverTimingMetrics = new HashMap<>();
if (serverTiming != null) {
diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java
index d7b32f72bcd6..34e43ebca674 100644
--- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java
+++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java
@@ -21,6 +21,8 @@
import com.google.api.core.InternalApi;
import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider;
import com.google.cloud.spanner.XGoogSpannerRequestId;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Ticker;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.protobuf.ByteString;
@@ -46,11 +48,10 @@
import java.io.IOException;
import java.lang.ref.ReferenceQueue;
import java.lang.ref.SoftReference;
-import java.util.HashSet;
import java.util.Map;
-import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Predicate;
import java.util.logging.Level;
import java.util.logging.Logger;
@@ -69,9 +70,10 @@ final class KeyAwareChannel extends ManagedChannel {
private static final Logger logger = Logger.getLogger(KeyAwareChannel.class.getName());
+ private static final long MAX_TRACKED_TRANSACTION_AFFINITIES = 100_000L;
+ private static final long TRANSACTION_AFFINITY_TTL_MINUTES = 10L;
private static final long MAX_TRACKED_READ_ONLY_TRANSACTIONS = 100_000L;
- private static final long MAX_TRACKED_EXCLUDED_LOGICAL_REQUESTS = 100_000L;
- private static final long EXCLUDED_LOGICAL_REQUEST_TTL_MINUTES = 10L;
+ private static final int CHANNEL_FINDER_CLEANUP_INTERVAL = 1024;
private static final String STREAMING_READ_METHOD = "google.spanner.v1.Spanner/StreamingRead";
private static final String STREAMING_SQL_METHOD =
"google.spanner.v1.Spanner/ExecuteStreamingSql";
@@ -88,27 +90,41 @@ final class KeyAwareChannel extends ManagedChannel {
private final String defaultEndpointAddress;
private final ReferenceQueue<ChannelFinder> channelFinderReferenceQueue = new ReferenceQueue<>();
private final Map<String, ChannelFinderReference> channelFinders = new ConcurrentHashMap<>();
- private final Map transactionAffinities = new ConcurrentHashMap<>();
+ private final AtomicInteger channelFinderCleanupCounter = new AtomicInteger();
+ // Maps read-write transaction IDs to their last routed endpoint.
+ // Bound and age out entries in case application code abandons a transaction
+ // without sending Commit/Rollback or otherwise clearing affinity.
+ private final Cache<ByteString, String> transactionAffinities;
// Maps read-only transaction IDs to their preferLeader value.
// Strong reads → true (prefer leader), Stale reads → false (any replica).
// Bounded to prevent unbounded growth if application code does not close read-only transactions.
private final Cache<ByteString, Boolean> readOnlyTxPreferLeader =
CacheBuilder.newBuilder().maximumSize(MAX_TRACKED_READ_ONLY_TRANSACTIONS).build();
- // If a routed endpoint returns RESOURCE_EXHAUSTED, the next retry attempt of that same logical
- // request should avoid that endpoint once so other requests are unaffected. Bound and age out
- // entries in case a caller gives up and never issues a retry.
- private final Cache<String, Set<String>> excludedEndpointsForLogicalRequest =
- CacheBuilder.newBuilder()
- .maximumSize(MAX_TRACKED_EXCLUDED_LOGICAL_REQUESTS)
- .expireAfterWrite(EXCLUDED_LOGICAL_REQUEST_TTL_MINUTES, TimeUnit.MINUTES)
- .build();
+ private final EndpointOverloadCooldownTracker endpointOverloadCooldowns;
private KeyAwareChannel(
InstantiatingGrpcChannelProvider channelProvider,
- @Nullable ChannelEndpointCacheFactory endpointCacheFactory)
+ @Nullable ChannelEndpointCacheFactory endpointCacheFactory,
+ @Nullable GrpcGcpEndpointChannelConfigurator endpointChannelConfigurator)
+ throws IOException {
+ this(
+ channelProvider,
+ endpointCacheFactory,
+ endpointChannelConfigurator,
+ new EndpointOverloadCooldownTracker(),
+ Ticker.systemTicker());
+ }
+
+ private KeyAwareChannel(
+ InstantiatingGrpcChannelProvider channelProvider,
+ @Nullable ChannelEndpointCacheFactory endpointCacheFactory,
+ @Nullable GrpcGcpEndpointChannelConfigurator endpointChannelConfigurator,
+ EndpointOverloadCooldownTracker endpointOverloadCooldowns,
+ Ticker transactionAffinityTicker)
throws IOException {
if (endpointCacheFactory == null) {
- this.endpointCache = new GrpcChannelEndpointCache(channelProvider);
+ this.endpointCache =
+ new GrpcChannelEndpointCache(channelProvider, endpointChannelConfigurator);
} else {
this.endpointCache = endpointCacheFactory.create(channelProvider);
}
@@ -120,13 +136,56 @@ private KeyAwareChannel(
// would interfere with test assertions.
this.lifecycleManager =
(endpointCacheFactory == null) ? new EndpointLifecycleManager(endpointCache) : null;
+ this.endpointOverloadCooldowns = endpointOverloadCooldowns;
+ this.transactionAffinities = newTransactionAffinities(transactionAffinityTicker);
}
static KeyAwareChannel create(
InstantiatingGrpcChannelProvider channelProvider,
@Nullable ChannelEndpointCacheFactory endpointCacheFactory)
throws IOException {
- return new KeyAwareChannel(channelProvider, endpointCacheFactory);
+ return new KeyAwareChannel(channelProvider, endpointCacheFactory, null);
+ }
+
+ static KeyAwareChannel create(
+ InstantiatingGrpcChannelProvider channelProvider,
+ @Nullable ChannelEndpointCacheFactory endpointCacheFactory,
+ @Nullable GrpcGcpEndpointChannelConfigurator endpointChannelConfigurator)
+ throws IOException {
+ return new KeyAwareChannel(channelProvider, endpointCacheFactory, endpointChannelConfigurator);
+ }
+
+ @VisibleForTesting
+ static KeyAwareChannel create(
+ InstantiatingGrpcChannelProvider channelProvider,
+ @Nullable ChannelEndpointCacheFactory endpointCacheFactory,
+ EndpointOverloadCooldownTracker endpointOverloadCooldowns)
+ throws IOException {
+ return create(
+ channelProvider, endpointCacheFactory, endpointOverloadCooldowns, Ticker.systemTicker());
+ }
+
+ @VisibleForTesting
+ static KeyAwareChannel create(
+ InstantiatingGrpcChannelProvider channelProvider,
+ @Nullable ChannelEndpointCacheFactory endpointCacheFactory,
+ EndpointOverloadCooldownTracker endpointOverloadCooldowns,
+ Ticker transactionAffinityTicker)
+ throws IOException {
+ return new KeyAwareChannel(
+ channelProvider,
+ endpointCacheFactory,
+ null,
+ endpointOverloadCooldowns,
+ transactionAffinityTicker);
+ }
+
+ private static Cache<ByteString, String> newTransactionAffinities(Ticker ticker) {
+ return CacheBuilder.newBuilder()
+ .maximumSize(MAX_TRACKED_TRANSACTION_AFFINITIES)
+ .expireAfterAccess(TRANSACTION_AFFINITY_TTL_MINUTES, TimeUnit.MINUTES)
+ .ticker(ticker)
+ .build();
}
private static final class ChannelFinderReference extends SoftReference<ChannelFinder> {
@@ -161,20 +220,23 @@ private void cleanupStaleChannelFinders() {
}
}
+ private void maybeCleanupStaleChannelFinders() {
+ if ((channelFinderCleanupCounter.incrementAndGet() & (CHANNEL_FINDER_CLEANUP_INTERVAL - 1))
+ == 0) {
+ cleanupStaleChannelFinders();
+ }
+ }
+
private ChannelFinder getOrCreateChannelFinder(String databaseId) {
- cleanupStaleChannelFinders();
+ maybeCleanupStaleChannelFinders();
ChannelFinderReference ref = channelFinders.get(databaseId);
ChannelFinder finder = (ref != null) ? ref.get() : null;
if (finder == null) {
synchronized (channelFinders) {
- cleanupStaleChannelFinders();
ref = channelFinders.get(databaseId);
finder = (ref != null) ? ref.get() : null;
if (finder == null) {
- finder =
- lifecycleManager != null
- ? new ChannelFinder(endpointCache, lifecycleManager, databaseId)
- : new ChannelFinder(endpointCache);
+ finder = new ChannelFinder(endpointCache, lifecycleManager, databaseId);
channelFinders.put(
databaseId,
new ChannelFinderReference(databaseId, finder, channelFinderReferenceQueue));
@@ -269,7 +331,7 @@ private ChannelEndpoint affinityEndpoint(
if (transactionId == null || transactionId.isEmpty()) {
return null;
}
- String address = transactionAffinities.get(transactionId);
+ String address = transactionAffinities.getIfPresent(transactionId);
if (address == null || excludedEndpoints.test(address)) {
return null;
}
@@ -296,7 +358,7 @@ private void clearAffinity(ByteString transactionId) {
if (transactionId == null || transactionId.isEmpty()) {
return;
}
- transactionAffinities.remove(transactionId);
+ transactionAffinities.invalidate(transactionId);
readOnlyTxPreferLeader.invalidate(transactionId);
}
@@ -305,7 +367,7 @@ void clearTransactionAffinity(ByteString transactionId) {
}
void clearTransactionAndChannelAffinity(ByteString transactionId, @Nullable Long channelHint) {
- String address = transactionAffinities.remove(transactionId);
+ String address = transactionAffinities.asMap().remove(transactionId);
readOnlyTxPreferLeader.invalidate(transactionId);
if (channelHint != null) {
ManagedChannel channel = defaultChannel;
@@ -319,38 +381,41 @@ void clearTransactionAndChannelAffinity(ByteString transactionId, @Nullable Long
}
}
- private void maybeExcludeEndpointOnNextCall(
- @Nullable ChannelEndpoint endpoint, @Nullable String logicalRequestKey) {
- if (endpoint == null || logicalRequestKey == null) {
+ private void recordEndpointCooldown(@Nullable ChannelEndpoint endpoint) {
+ if (endpoint == null) {
return;
}
String address = endpoint.getAddress();
- if (!defaultEndpointAddress.equals(address)) {
- excludedEndpointsForLogicalRequest
- .asMap()
- .compute(
- logicalRequestKey,
- (ignored, excludedEndpoints) -> {
- Set updated =
- excludedEndpoints == null ? ConcurrentHashMap.newKeySet() : excludedEndpoints;
- updated.add(address);
- return updated;
- });
+ if (defaultEndpointAddress.equals(address)) {
+ return;
}
+ endpointOverloadCooldowns.recordFailure(address);
}
- private Predicate consumeExcludedEndpointsForCurrentCall(
- @Nullable String logicalRequestKey) {
- if (logicalRequestKey == null) {
- return address -> false;
+ private void maybeRecordErrorPenalty(
+ @Nullable String databaseScope,
+ @Nullable ChannelEndpoint endpoint,
+ io.grpc.Status.Code statusCode,
+ long operationUid,
+ boolean preferLeader) {
+ if (!shouldExcludeEndpointOnRetry(statusCode) || endpoint == null || operationUid <= 0L) {
+ return;
}
- Set excludedEndpoints =
- excludedEndpointsForLogicalRequest.asMap().remove(logicalRequestKey);
- if (excludedEndpoints == null || excludedEndpoints.isEmpty()) {
- return address -> false;
+ String address = endpoint.getAddress();
+ if (defaultEndpointAddress.equals(address)) {
+ return;
}
- excludedEndpoints = new HashSet<>(excludedEndpoints);
- return excludedEndpoints::contains;
+ EndpointLatencyRegistry.recordError(databaseScope, operationUid, preferLeader, address);
+ }
+
+ private static boolean shouldExcludeEndpointOnRetry(io.grpc.Status.Code statusCode) {
+ return statusCode == io.grpc.Status.Code.RESOURCE_EXHAUSTED
+ || statusCode == io.grpc.Status.Code.UNAVAILABLE;
+ }
+
+ @VisibleForTesting
+ boolean isCoolingDown(String address) {
+ return endpointOverloadCooldowns.isCoolingDown(address);
}
private boolean isReadOnlyTransaction(ByteString transactionId) {
@@ -452,6 +517,10 @@ static final class KeyAwareClientCall
private ChannelFinder channelFinder;
@Nullable private Predicate<String> excludedEndpoints;
@Nullable private ChannelEndpoint selectedEndpoint;
+ @Nullable private String selectedTargetEndpoint;
+ @Nullable private String selectedDatabaseScope;
+ private long selectedOperationUid;
+ private boolean selectedPreferLeader;
@Nullable private ByteString transactionIdToClear;
private boolean allowDefaultAffinity;
private long pendingRequests;
@@ -518,6 +587,9 @@ public void sendMessage(RequestT message) {
Predicate<String> excludedEndpoints = excludedEndpoints();
ChannelEndpoint endpoint = null;
ChannelFinder finder = null;
+ String databaseScope = null;
+ long operationUid = 0L;
+ boolean preferLeader = false;
if (message instanceof ReadRequest) {
ReadRequest.Builder reqBuilder = ((ReadRequest) message).toBuilder();
@@ -525,6 +597,9 @@ public void sendMessage(RequestT message) {
RoutingDecision routing = routeFromRequest(reqBuilder);
finder = routing.finder;
endpoint = routing.endpoint;
+ databaseScope = routing.databaseScope;
+ operationUid = routing.operationUid;
+ preferLeader = routing.preferLeader;
message = (RequestT) reqBuilder.build();
} else if (message instanceof ExecuteSqlRequest) {
ExecuteSqlRequest.Builder reqBuilder = ((ExecuteSqlRequest) message).toBuilder();
@@ -532,6 +607,9 @@ public void sendMessage(RequestT message) {
RoutingDecision routing = routeFromRequest(reqBuilder);
finder = routing.finder;
endpoint = routing.endpoint;
+ databaseScope = routing.databaseScope;
+ operationUid = routing.operationUid;
+ preferLeader = routing.preferLeader;
message = (RequestT) reqBuilder.build();
} else if (message instanceof BeginTransactionRequest) {
BeginTransactionRequest.Builder reqBuilder =
@@ -539,6 +617,7 @@ public void sendMessage(RequestT message) {
String databaseId = parentChannel.extractDatabaseIdFromSession(reqBuilder.getSession());
if (databaseId != null) {
finder = parentChannel.getOrCreateChannelFinder(databaseId);
+ databaseScope = databaseId;
}
if (finder != null && reqBuilder.hasMutationKey()) {
endpoint = finder.findServer(reqBuilder, excludedEndpoints);
@@ -555,6 +634,7 @@ public void sendMessage(RequestT message) {
String databaseId = parentChannel.extractDatabaseIdFromSession(request.getSession());
if (databaseId != null) {
finder = parentChannel.getOrCreateChannelFinder(databaseId);
+ databaseScope = databaseId;
}
CommitRequest.Builder reqBuilder = null;
if (finder != null && request.getMutationsCount() > 0) {
@@ -593,7 +673,21 @@ public void sendMessage(RequestT message) {
throw new IllegalStateException("No default endpoint available for key-aware call");
}
selectedEndpoint = endpoint;
+ selectedTargetEndpoint = endpoint.getAddress();
+ selectedDatabaseScope = databaseScope != null ? databaseScope : routingScope(finder);
+ selectedOperationUid = operationUid;
+ selectedPreferLeader = preferLeader;
this.channelFinder = finder;
+ selectedEndpoint.incrementActiveRequests();
+ XGoogSpannerRequestId requestId = callOptions.getOption(REQUEST_ID_CALL_OPTIONS_KEY);
+ if (requestId != null) {
+ RequestIdTargetTracker.record(
+ requestId.getHeaderValue(),
+ selectedDatabaseScope,
+ selectedTargetEndpoint,
+ operationUid,
+ selectedPreferLeader);
+ }
// Record real traffic for idle eviction tracking.
parentChannel.onRequestRouted(endpoint);
@@ -745,7 +839,7 @@ private void maybeTrackReadOnlyBegin(TransactionSelector selector) {
private Predicate<String> excludedEndpoints() {
if (excludedEndpoints == null) {
- excludedEndpoints = parentChannel.consumeExcludedEndpointsForCurrentCall(logicalRequestKey);
+ excludedEndpoints = parentChannel.endpointOverloadCooldowns::isCoolingDown;
}
return excludedEndpoints;
}
@@ -762,15 +856,18 @@ private RoutingDecision routeFromRequest(ReadRequest.Builder reqBuilder) {
if (databaseId != null) {
finder = parentChannel.getOrCreateChannelFinder(databaseId);
}
+ boolean preferLeader = preferLeader(reqBuilder.getTransaction());
if (databaseId != null && endpoint == null) {
Boolean preferLeaderOverride = parentChannel.readOnlyPreferLeader(transactionId);
+ preferLeader = preferLeaderOverride != null ? preferLeaderOverride : preferLeader;
ChannelEndpoint routed =
preferLeaderOverride != null
? finder.findServer(reqBuilder, preferLeaderOverride, excludedEndpoints)
: finder.findServer(reqBuilder, excludedEndpoints);
endpoint = routed;
}
- return new RoutingDecision(finder, endpoint);
+ return new RoutingDecision(
+ finder, endpoint, databaseId, operationUid(reqBuilder.getRoutingHint()), preferLeader);
}
private RoutingDecision routeFromRequest(ExecuteSqlRequest.Builder reqBuilder) {
@@ -785,25 +882,64 @@ private RoutingDecision routeFromRequest(ExecuteSqlRequest.Builder reqBuilder) {
if (databaseId != null) {
finder = parentChannel.getOrCreateChannelFinder(databaseId);
}
+ boolean preferLeader = preferLeader(reqBuilder.getTransaction());
if (databaseId != null && endpoint == null) {
Boolean preferLeaderOverride = parentChannel.readOnlyPreferLeader(transactionId);
+ preferLeader = preferLeaderOverride != null ? preferLeaderOverride : preferLeader;
ChannelEndpoint routed =
preferLeaderOverride != null
? finder.findServer(reqBuilder, preferLeaderOverride, excludedEndpoints)
: finder.findServer(reqBuilder, excludedEndpoints);
endpoint = routed;
}
- return new RoutingDecision(finder, endpoint);
+ return new RoutingDecision(
+ finder, endpoint, databaseId, operationUid(reqBuilder.getRoutingHint()), preferLeader);
}
}
private static final class RoutingDecision {
@Nullable private final ChannelFinder finder;
@Nullable private final ChannelEndpoint endpoint;
-
- private RoutingDecision(@Nullable ChannelFinder finder, @Nullable ChannelEndpoint endpoint) {
+ @Nullable private final String databaseScope;
+ private final long operationUid;
+ private final boolean preferLeader;
+
+ private RoutingDecision(
+ @Nullable ChannelFinder finder,
+ @Nullable ChannelEndpoint endpoint,
+ @Nullable String databaseScope,
+ long operationUid,
+ boolean preferLeader) {
this.finder = finder;
this.endpoint = endpoint;
+ this.databaseScope = databaseScope;
+ this.operationUid = operationUid;
+ this.preferLeader = preferLeader;
+ }
+ }
+
+ @Nullable
+ private static String routingScope(@Nullable ChannelFinder finder) {
+ return finder == null ? null : finder.finderKey();
+ }
+
+ private static long operationUid(com.google.spanner.v1.RoutingHint routingHint) {
+ return routingHint == null ? 0L : routingHint.getOperationUid();
+ }
+
+ private static boolean preferLeader(TransactionSelector selector) {
+ switch (selector.getSelectorCase()) {
+ case BEGIN:
+ return !selector.getBegin().hasReadOnly() || selector.getBegin().getReadOnly().getStrong();
+ case SINGLE_USE:
+ if (!selector.getSingleUse().hasReadOnly()) {
+ return true;
+ }
+ return selector.getSingleUse().getReadOnly().getStrong();
+ case ID:
+ case SELECTOR_NOT_SET:
+ default:
+ return true;
}
}
@@ -858,10 +994,19 @@ public void onMessage(ResponseT message) {
@Override
public void onClose(io.grpc.Status status, Metadata trailers) {
- if (status.getCode() == io.grpc.Status.Code.RESOURCE_EXHAUSTED) {
- call.parentChannel.maybeExcludeEndpointOnNextCall(
- call.selectedEndpoint, call.logicalRequestKey);
+ if (shouldExcludeEndpointOnRetry(status.getCode())) {
+ call.parentChannel.maybeRecordErrorPenalty(
+ call.selectedDatabaseScope,
+ call.selectedEndpoint,
+ status.getCode(),
+ call.selectedOperationUid,
+ call.selectedPreferLeader);
+ call.parentChannel.recordEndpointCooldown(call.selectedEndpoint);
+ }
+ if (call.selectedEndpoint != null) {
+ call.selectedEndpoint.decrementActiveRequests();
}
+ RequestIdTargetTracker.remove(call.logicalRequestKey);
call.maybeClearAffinity();
super.onClose(status, trailers);
}
diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareTransportChannelProvider.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareTransportChannelProvider.java
index 438717c3c98f..a772af4c3567 100644
--- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareTransportChannelProvider.java
+++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareTransportChannelProvider.java
@@ -29,25 +29,22 @@
final class KeyAwareTransportChannelProvider implements TransportChannelProvider {
private final InstantiatingGrpcChannelProvider baseProvider;
@Nullable private final ChannelEndpointCacheFactory endpointCacheFactory;
-
- KeyAwareTransportChannelProvider(
- InstantiatingGrpcChannelProvider.Builder builder,
- @Nullable ChannelEndpointCacheFactory endpointCacheFactory) {
- this.baseProvider = builder.build();
- this.endpointCacheFactory = endpointCacheFactory;
- }
+ @Nullable private final GrpcGcpEndpointChannelConfigurator endpointChannelConfigurator;
KeyAwareTransportChannelProvider(
InstantiatingGrpcChannelProvider baseProvider,
- @Nullable ChannelEndpointCacheFactory endpointCacheFactory) {
+ @Nullable ChannelEndpointCacheFactory endpointCacheFactory,
+ @Nullable GrpcGcpEndpointChannelConfigurator endpointChannelConfigurator) {
this.baseProvider = baseProvider;
this.endpointCacheFactory = endpointCacheFactory;
+ this.endpointChannelConfigurator = endpointChannelConfigurator;
}
@Override
public GrpcTransportChannel getTransportChannel() throws IOException {
return GrpcTransportChannel.newBuilder()
- .setManagedChannel(KeyAwareChannel.create(baseProvider, endpointCacheFactory))
+ .setManagedChannel(
+ KeyAwareChannel.create(baseProvider, endpointCacheFactory, endpointChannelConfigurator))
.build();
}
@@ -85,41 +82,48 @@ public boolean shouldAutoClose() {
public TransportChannelProvider withEndpoint(String endpoint) {
return new KeyAwareTransportChannelProvider(
(InstantiatingGrpcChannelProvider) baseProvider.withEndpoint(endpoint),
- endpointCacheFactory);
+ endpointCacheFactory,
+ endpointChannelConfigurator);
}
@Override
public TransportChannelProvider withCredentials(Credentials credentials) {
return new KeyAwareTransportChannelProvider(
(InstantiatingGrpcChannelProvider) baseProvider.withCredentials(credentials),
- endpointCacheFactory);
+ endpointCacheFactory,
+ endpointChannelConfigurator);
}
@Override
public TransportChannelProvider withHeaders(Map<String, String> headers) {
return new KeyAwareTransportChannelProvider(
- (InstantiatingGrpcChannelProvider) baseProvider.withHeaders(headers), endpointCacheFactory);
+ (InstantiatingGrpcChannelProvider) baseProvider.withHeaders(headers),
+ endpointCacheFactory,
+ endpointChannelConfigurator);
}
@Override
public TransportChannelProvider withPoolSize(int poolSize) {
return new KeyAwareTransportChannelProvider(
(InstantiatingGrpcChannelProvider) baseProvider.withPoolSize(poolSize),
- endpointCacheFactory);
+ endpointCacheFactory,
+ endpointChannelConfigurator);
}
@Override
public TransportChannelProvider withExecutor(ScheduledExecutorService executor) {
return new KeyAwareTransportChannelProvider(
(InstantiatingGrpcChannelProvider) baseProvider.withExecutor(executor),
- endpointCacheFactory);
+ endpointCacheFactory,
+ endpointChannelConfigurator);
}
@Override
public TransportChannelProvider withExecutor(Executor executor) {
return new KeyAwareTransportChannelProvider(
(InstantiatingGrpcChannelProvider) baseProvider.withExecutor(executor),
- endpointCacheFactory);
+ endpointCacheFactory,
+ endpointChannelConfigurator);
}
@Override
diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java
index 59955ccb4bd2..98c26381285d 100644
--- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java
+++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java
@@ -18,6 +18,7 @@
import com.google.api.core.InternalApi;
import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.Lists;
import com.google.common.hash.Hashing;
import com.google.protobuf.ByteString;
import com.google.spanner.v1.CacheUpdate;
@@ -26,8 +27,10 @@
import com.google.spanner.v1.Range;
import com.google.spanner.v1.RoutingHint;
import com.google.spanner.v1.Tablet;
+import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
@@ -55,6 +58,7 @@ public final class KeyRangeCache {
private static final int MAX_LOCAL_REPLICA_DISTANCE = 5;
private static final int DEFAULT_MIN_ENTRIES_FOR_RANDOM_PICK = 1000;
+ private static final double LOCAL_LEADER_SELECTION_COST_MULTIPLIER = 0.5D;
/** Determines how to handle ranges that span multiple splits. */
public enum RangeMode {
@@ -64,8 +68,87 @@ public enum RangeMode {
PICK_RANDOM
}
+ enum RouteFailureReason {
+ NONE,
+ MISSING_ROUTING_KEY,
+ CACHE_MISS,
+ ALL_EXCLUDED_OR_COOLDOWN,
+ NO_READY_REPLICA,
+ NO_MATCHING_REPLICA,
+ NO_ROUTABLE_REPLICA
+ }
+
+ static final class RouteLookupResult {
+ @javax.annotation.Nullable final ChannelEndpoint endpoint;
+ @javax.annotation.Nullable final String targetEndpointLabel;
+ final List<SkippedTabletDetail> skippedTabletDetails;
+ final RouteFailureReason failureReason;
+
+ private RouteLookupResult(
+ @javax.annotation.Nullable ChannelEndpoint endpoint,
+ @javax.annotation.Nullable String targetEndpointLabel,
+ List skippedTabletDetails,
+ RouteFailureReason failureReason) {
+ this.endpoint = endpoint;
+ this.targetEndpointLabel = targetEndpointLabel;
+ this.skippedTabletDetails = skippedTabletDetails;
+ this.failureReason = failureReason;
+ }
+
+ static RouteLookupResult routed(
+ ChannelEndpoint endpoint,
+ String targetEndpointLabel,
+ List<SkippedTabletDetail> skippedTabletDetails) {
+ return new RouteLookupResult(
+ endpoint,
+ targetEndpointLabel,
+ Collections.unmodifiableList(new ArrayList<>(skippedTabletDetails)),
+ RouteFailureReason.NONE);
+ }
+
+ static RouteLookupResult failed(
+ RouteFailureReason failureReason, List<SkippedTabletDetail> skippedTabletDetails) {
+ return new RouteLookupResult(
+ null,
+ null,
+ Collections.unmodifiableList(new ArrayList<>(skippedTabletDetails)),
+ failureReason);
+ }
+ }
+
+ static final class SkippedTabletDetail {
+ @javax.annotation.Nullable final String targetEndpointLabel;
+ final String reason;
+
+ private SkippedTabletDetail(
+ @javax.annotation.Nullable String targetEndpointLabel, String reason) {
+ this.targetEndpointLabel = targetEndpointLabel;
+ this.reason = reason;
+ }
+ }
+
+ private static final class EligibleReplica {
+ final TabletSnapshot tablet;
+ final ChannelEndpoint endpoint;
+ final double selectionCost;
+
+ private EligibleReplica(TabletSnapshot tablet, ChannelEndpoint endpoint, double selectionCost) {
+ this.tablet = tablet;
+ this.endpoint = endpoint;
+ this.selectionCost = selectionCost;
+ }
+ }
+
+ static String formatTargetEndpointLabel(String address, boolean isLeader) {
+ if (address == null || address.isEmpty() || !isLeader) {
+ return address;
+ }
+ return address + "-LEADER";
+ }
+
private final ChannelEndpointCache endpointCache;
@javax.annotation.Nullable private final EndpointLifecycleManager lifecycleManager;
+ @javax.annotation.Nullable private final String databaseScope;
private final NavigableMap ranges =
new TreeMap<>(ByteString.unsignedLexicographicalComparator());
private final Map groups = new HashMap<>();
@@ -73,19 +156,28 @@ public enum RangeMode {
private final Lock readLock = cacheLock.readLock();
private final Lock writeLock = cacheLock.writeLock();
private final AtomicLong accessCounter = new AtomicLong();
+ private final ReplicaSelector replicaSelector = new PowerOfTwoReplicaSelector();
private volatile boolean deterministicRandom = false;
private volatile int minCacheEntriesForRandomPick = DEFAULT_MIN_ENTRIES_FOR_RANDOM_PICK;
  /** Creates a cache with no lifecycle manager and no per-database latency scope. */
  public KeyRangeCache(ChannelEndpointCache endpointCache) {
    this(endpointCache, null, null);
  }
  /** Creates a cache with an optional lifecycle manager and no per-database latency scope. */
  public KeyRangeCache(
      ChannelEndpointCache endpointCache,
      @javax.annotation.Nullable EndpointLifecycleManager lifecycleManager) {
    this(endpointCache, lifecycleManager, null);
  }
+
  /**
   * Canonical constructor.
   *
   * @param endpointCache source of ChannelEndpoints by server address; must not be null
   * @param lifecycleManager optional hook for endpoint recreation and eviction state
   * @param databaseScope optional key passed to EndpointLatencyRegistry calls to scope
   *     latency/error records and selection costs to this cache's database
   */
  KeyRangeCache(
      ChannelEndpointCache endpointCache,
      @javax.annotation.Nullable EndpointLifecycleManager lifecycleManager,
      @javax.annotation.Nullable String databaseScope) {
    this.endpointCache = Objects.requireNonNull(endpointCache);
    this.lifecycleManager = lifecycleManager;
    this.databaseScope = databaseScope;
  }
@VisibleForTesting
@@ -98,6 +190,16 @@ void setMinCacheEntriesForRandomPick(int value) {
minCacheEntriesForRandomPick = value;
}
  /**
   * Test hook: forwards a replica RPC latency sample to the shared latency registry.
   *
   * <p>NOTE(review): the hard-coded {@code false} sits in the position that
   * getSelectionCost() calls pass {@code preferLeader} in — presumably "not a
   * leader-preferring operation"; confirm against EndpointLatencyRegistry.
   */
  @VisibleForTesting
  void recordReplicaLatency(long operationUid, String address, Duration latency) {
    EndpointLatencyRegistry.recordLatency(databaseScope, operationUid, false, address, latency);
  }
+
  /**
   * Test hook: forwards a replica RPC error to the shared latency registry.
   *
   * <p>NOTE(review): the hard-coded {@code false} mirrors the preferLeader flag position
   * used by getSelectionCost(); confirm against EndpointLatencyRegistry.
   */
  @VisibleForTesting
  void recordReplicaError(long operationUid, String address) {
    EndpointLatencyRegistry.recordError(databaseScope, operationUid, false, address);
  }
+
/** Applies cache updates. Tablets are processed inside group updates. */
public void addRanges(CacheUpdate cacheUpdate) {
List touchedGroups = new ArrayList<>();
@@ -153,9 +255,21 @@ public ChannelEndpoint fillRoutingHint(
DirectedReadOptions directedReadOptions,
RoutingHint.Builder hintBuilder,
Predicate excludedEndpoints) {
+ return lookupRoutingHint(
+ preferLeader, rangeMode, directedReadOptions, hintBuilder, excludedEndpoints)
+ .endpoint;
+ }
+
+ RouteLookupResult lookupRoutingHint(
+ boolean preferLeader,
+ RangeMode rangeMode,
+ DirectedReadOptions directedReadOptions,
+ RoutingHint.Builder hintBuilder,
+ Predicate excludedEndpoints) {
+ List skippedTabletDetails = new ArrayList<>();
ByteString key = hintBuilder.getKey();
if (key.isEmpty()) {
- return null;
+ return RouteLookupResult.failed(RouteFailureReason.MISSING_ROUTING_KEY, skippedTabletDetails);
}
CachedRange targetRange;
@@ -167,7 +281,7 @@ public ChannelEndpoint fillRoutingHint(
}
if (targetRange == null || targetRange.group == null) {
- return null;
+ return RouteLookupResult.failed(RouteFailureReason.CACHE_MISS, skippedTabletDetails);
}
hintBuilder.setGroupUid(targetRange.group.groupUid);
@@ -175,8 +289,8 @@ public ChannelEndpoint fillRoutingHint(
hintBuilder.setKey(targetRange.startKey);
hintBuilder.setLimitKey(targetRange.limitKey);
- return targetRange.group.fillRoutingHint(
- preferLeader, directedReadOptions, hintBuilder, excludedEndpoints);
+ return targetRange.group.lookupRoutingHint(
+ preferLeader, directedReadOptions, hintBuilder, excludedEndpoints, skippedTabletDetails);
}
/** Returns all server addresses currently referenced by cached tablets. */
@@ -185,11 +299,10 @@ Set getActiveAddresses() {
readLock.lock();
try {
for (CachedGroup group : groups.values()) {
- synchronized (group) {
- for (CachedTablet tablet : group.tablets) {
- if (!tablet.serverAddress.isEmpty()) {
- addresses.add(tablet.serverAddress);
- }
+ GroupSnapshot snapshot = group.snapshot;
+ for (TabletSnapshot tablet : snapshot.tablets) {
+ if (!tablet.serverAddress.isEmpty()) {
+ addresses.add(tablet.serverAddress);
}
}
}
@@ -487,34 +600,27 @@ private int compare(ByteString left, ByteString right) {
return ByteString.unsignedLexicographicalComparator().compare(left, right);
}
- /** Represents a single tablet within a group. */
- private class CachedTablet {
- long tabletUid = 0;
- ByteString incarnation = ByteString.EMPTY;
- String serverAddress = "";
- int distance = 0;
- boolean skip = false;
- Tablet.Role role = Tablet.Role.ROLE_UNSPECIFIED;
- String location = "";
-
- ChannelEndpoint endpoint = null;
-
- void update(Tablet tabletIn) {
- if (tabletUid > 0 && compare(incarnation, tabletIn.getIncarnation()) > 0) {
- return;
- }
-
- tabletUid = tabletIn.getTabletUid();
- incarnation = tabletIn.getIncarnation();
- distance = tabletIn.getDistance();
- skip = tabletIn.getSkip();
- role = tabletIn.getRole();
- location = tabletIn.getLocation();
-
- if (!serverAddress.equals(tabletIn.getServerAddress())) {
- serverAddress = tabletIn.getServerAddress();
- endpoint = null;
- }
  // Shared sentinel for groups that have not yet received an update. leaderIndex -1 makes
  // hasLeader() false, and the empty tablet list makes every lookup fall through.
  private static final GroupSnapshot EMPTY_GROUP_SNAPSHOT =
      new GroupSnapshot(ByteString.EMPTY, -1, Collections.emptyList());
+
+ /** Immutable tablet metadata used by the read path without per-group locking. */
+ private static final class TabletSnapshot {
+ final long tabletUid;
+ final ByteString incarnation;
+ final String serverAddress;
+ final int distance;
+ final boolean skip;
+ final Tablet.Role role;
+ final String location;
+
    /** Captures an immutable copy of the routing-relevant fields of a wire {@code Tablet}. */
    private TabletSnapshot(Tablet tabletIn) {
      this.tabletUid = tabletIn.getTabletUid();
      this.incarnation = tabletIn.getIncarnation();
      this.serverAddress = tabletIn.getServerAddress();
      this.distance = tabletIn.getDistance();
      this.skip = tabletIn.getSkip();
      this.role = tabletIn.getRole();
      this.location = tabletIn.getLocation();
    }
boolean matches(DirectedReadOptions directedReadOptions) {
@@ -555,132 +661,6 @@ private boolean matches(DirectedReadOptions.ReplicaSelection selection) {
}
}
- /**
- * Evaluates whether this tablet should be skipped for location-aware routing.
- *
- * State-aware skip logic:
- *
- *
- * - Server-marked skip, empty address, or excluded endpoint: skip and report in
- * skipped_tablets.
- *
- Endpoint exists and READY: usable, do not skip.
- *
- Endpoint exists and TRANSIENT_FAILURE: skip and report in skipped_tablets.
- *
- Endpoint absent, IDLE, CONNECTING, SHUTDOWN, or unsupported: skip silently unless the
- * lifecycle manager recently evicted the address for repeated TRANSIENT_FAILURE, in which
- * case report it in skipped_tablets.
- *
- */
- boolean shouldSkip(
- RoutingHint.Builder hintBuilder,
- Predicate excludedEndpoints,
- Set skippedTabletUids) {
- // Server-marked skip, no address, or excluded endpoint: always report.
- if (skip || serverAddress.isEmpty() || excludedEndpoints.test(serverAddress)) {
- addSkippedTablet(hintBuilder, skippedTabletUids);
- return true;
- }
-
- // If the cached endpoint's channel has been shut down (e.g. after idle eviction),
- // discard the stale reference so we re-lookup from the cache below.
- if (endpoint != null && endpoint.getChannel().isShutdown()) {
- logger.log(
- Level.FINE,
- "Tablet {0} at {1}: cached endpoint is shutdown, clearing stale reference",
- new Object[] {tabletUid, serverAddress});
- endpoint = null;
- }
-
- // Lookup without creating: location-aware routing should not trigger foreground endpoint
- // creation.
- if (endpoint == null) {
- endpoint = endpointCache.getIfPresent(serverAddress);
- }
-
- // No endpoint exists yet - skip silently, request background recreation so the
- // endpoint becomes available for future requests.
- if (endpoint == null) {
- logger.log(
- Level.FINE,
- "Tablet {0} at {1}: no endpoint present, skipping silently",
- new Object[] {tabletUid, serverAddress});
- maybeAddRecentTransientFailureSkip(hintBuilder, skippedTabletUids);
- if (lifecycleManager != null) {
- lifecycleManager.requestEndpointRecreation(serverAddress);
- }
- return true;
- }
-
- // READY - usable for location-aware routing.
- if (endpoint.isHealthy()) {
- return false;
- }
-
- // TRANSIENT_FAILURE - skip and report so server can refresh client cache.
- if (endpoint.isTransientFailure()) {
- logger.log(
- Level.FINE,
- "Tablet {0} at {1}: endpoint in TRANSIENT_FAILURE, adding to skipped_tablets",
- new Object[] {tabletUid, serverAddress});
- addSkippedTablet(hintBuilder, skippedTabletUids);
- return true;
- }
-
- // IDLE, CONNECTING, SHUTDOWN, or unsupported - skip silently.
- logger.log(
- Level.FINE,
- "Tablet {0} at {1}: endpoint not ready, skipping silently",
- new Object[] {tabletUid, serverAddress});
- maybeAddRecentTransientFailureSkip(hintBuilder, skippedTabletUids);
- return true;
- }
-
- private void addSkippedTablet(RoutingHint.Builder hintBuilder, Set skippedTabletUids) {
- if (!skippedTabletUids.add(tabletUid)) {
- return;
- }
- RoutingHint.SkippedTablet.Builder skipped = hintBuilder.addSkippedTabletUidBuilder();
- skipped.setTabletUid(tabletUid);
- skipped.setIncarnation(incarnation);
- }
-
- private void recordKnownTransientFailure(
- RoutingHint.Builder hintBuilder,
- Predicate excludedEndpoints,
- Set skippedTabletUids) {
- if (skip || serverAddress.isEmpty() || excludedEndpoints.test(serverAddress)) {
- return;
- }
-
- if (endpoint != null && endpoint.getChannel().isShutdown()) {
- endpoint = null;
- }
-
- if (endpoint == null) {
- endpoint = endpointCache.getIfPresent(serverAddress);
- }
-
- if (endpoint != null && endpoint.isTransientFailure()) {
- addSkippedTablet(hintBuilder, skippedTabletUids);
- return;
- }
-
- maybeAddRecentTransientFailureSkip(hintBuilder, skippedTabletUids);
- }
-
- private void maybeAddRecentTransientFailureSkip(
- RoutingHint.Builder hintBuilder, Set skippedTabletUids) {
- if (lifecycleManager != null
- && lifecycleManager.wasRecentlyEvictedTransientFailure(serverAddress)) {
- addSkippedTablet(hintBuilder, skippedTabletUids);
- }
- }
-
- ChannelEndpoint pick(RoutingHint.Builder hintBuilder) {
- hintBuilder.setTabletUid(tabletUid);
- // Endpoint must already exist and be READY if shouldSkip returned false.
- return endpoint;
- }
-
String debugString() {
return tabletUid
+ ":"
@@ -698,19 +678,40 @@ String debugString() {
}
}
+ private static final class GroupSnapshot {
+ final ByteString generation;
+ final int leaderIndex;
+ final List tablets;
+
+ private GroupSnapshot(ByteString generation, int leaderIndex, List tablets) {
+ this.generation = generation;
+ this.leaderIndex = leaderIndex;
+ this.tablets = Collections.unmodifiableList(tablets);
+ }
+
+ boolean hasLeader() {
+ return leaderIndex >= 0 && leaderIndex < tablets.size();
+ }
+
+ TabletSnapshot leader() {
+ return tablets.get(leaderIndex);
+ }
+ }
+
/** Represents a paxos group with its tablets. */
private class CachedGroup {
final long groupUid;
- ByteString generation = ByteString.EMPTY;
- List tablets = new ArrayList<>();
- int leaderIndex = -1;
+ volatile GroupSnapshot snapshot = EMPTY_GROUP_SNAPSHOT;
int refs = 1;
CachedGroup(long groupUid) {
this.groupUid = groupUid;
}
- synchronized void update(Group groupIn) {
+ void update(Group groupIn) {
+ GroupSnapshot current = snapshot;
+ ByteString generation = current.generation;
+ int leaderIndex = current.leaderIndex;
if (compare(groupIn.getGeneration(), generation) > 0) {
generation = groupIn.getGeneration();
if (groupIn.getLeaderIndex() >= 0 && groupIn.getLeaderIndex() < groupIn.getTabletsCount()) {
@@ -720,97 +721,151 @@ synchronized void update(Group groupIn) {
}
}
- if (tablets.size() == groupIn.getTabletsCount()) {
- boolean mismatch = false;
- for (int t = 0; t < groupIn.getTabletsCount(); t++) {
- if (tablets.get(t).tabletUid != groupIn.getTablets(t).getTabletUid()) {
- mismatch = true;
- break;
- }
- }
- if (!mismatch) {
- for (int t = 0; t < groupIn.getTabletsCount(); t++) {
- tablets.get(t).update(groupIn.getTablets(t));
- }
- return;
- }
- }
-
- Map tabletsByUid = new HashMap<>(tablets.size());
- for (CachedTablet tablet : tablets) {
- tabletsByUid.put(tablet.tabletUid, tablet);
- }
- List newTablets = new ArrayList<>(groupIn.getTabletsCount());
+ List tablets = new ArrayList<>(groupIn.getTabletsCount());
for (int t = 0; t < groupIn.getTabletsCount(); t++) {
- Tablet tabletIn = groupIn.getTablets(t);
- CachedTablet tablet = tabletsByUid.get(tabletIn.getTabletUid());
- if (tablet == null) {
- tablet = new CachedTablet();
- }
- tablet.update(tabletIn);
- newTablets.add(tablet);
+ tablets.add(new TabletSnapshot(groupIn.getTablets(t)));
}
- tablets = newTablets;
+ snapshot = new GroupSnapshot(generation, leaderIndex, tablets);
}
- ChannelEndpoint fillRoutingHint(
    /**
     * Selects a replica for the key already resolved into this group and fills the routing
     * hint, returning both the chosen endpoint and diagnostic skip details.
     *
     * <p>Reads a single volatile snapshot up front so the whole lookup sees one consistent
     * replica set without holding the group lock. If normal selection fails only because
     * every candidate was excluded or cooling down, falls back to a score-aware pick among
     * excluded/cooling-down-but-healthy replicas rather than failing outright.
     *
     * <p>NOTE(review): generic type parameters appear stripped in this extraction
     * (Predicate&lt;String&gt;, List&lt;SkippedTabletDetail&gt;, Map&lt;String,
     * ChannelEndpoint&gt;, Set&lt;Long&gt; presumably) — confirm against the original.
     */
    RouteLookupResult lookupRoutingHint(
        boolean preferLeader,
        DirectedReadOptions directedReadOptions,
        RoutingHint.Builder hintBuilder,
        Predicate excludedEndpoints,
        List skippedTabletDetails) {
      // One consistent, lock-free view of the replica set for the entire lookup.
      GroupSnapshot snapshot = this.snapshot;
      Set skippedTabletUids = skippedTabletUids(hintBuilder);
      boolean hasDirectedReadOptions =
          directedReadOptions.getReplicasCase()
              != DirectedReadOptions.ReplicasCase.REPLICAS_NOT_SET;
      // Per-lookup memo of address -> endpoint so each address hits the endpoint cache once.
      Map resolvedEndpoints = new HashMap<>();
      SelectionState selectionStats = new SelectionState();

      TabletSnapshot selected =
          selectTablet(
              snapshot,
              preferLeader,
              hasDirectedReadOptions,
              hintBuilder,
              directedReadOptions,
              excludedEndpoints,
              skippedTabletUids,
              skippedTabletDetails,
              resolvedEndpoints,
              selectionStats);
      if (selected == null) {
        RouteFailureReason failureReason = selectionStats.toFailureReason();
        if (failureReason == RouteFailureReason.ALL_EXCLUDED_OR_COOLDOWN) {
          // Everything routable was excluded/cooling down: retry ignoring the exclusion
          // list, still biasing toward a nearby leader when requested.
          TabletSnapshot preferredLeader =
              preferLeader ? localLeaderForScoreBias(snapshot, hasDirectedReadOptions) : null;
          selected =
              selectScoreAwareExcludedOrCoolingDownTablet(
                  snapshot,
                  preferLeader,
                  directedReadOptions,
                  hintBuilder,
                  resolvedEndpoints,
                  preferredLeader);
          if (selected != null) {
            // Still report known-bad siblings so the server can refresh the client cache.
            recordKnownTransientFailures(
                snapshot,
                selected,
                directedReadOptions,
                hintBuilder,
                excludedEndpoints,
                skippedTabletUids,
                skippedTabletDetails,
                resolvedEndpoints);
            hintBuilder.setTabletUid(selected.tabletUid);
            return RouteLookupResult.routed(
                resolveEndpoint(selected, resolvedEndpoints),
                endpointLabel(snapshot, selected),
                skippedTabletDetails);
          }
        }
        return RouteLookupResult.failed(failureReason, skippedTabletDetails);
      }
      // Normal success path: report known transient failures among the non-selected
      // replicas, then route to the selected tablet's endpoint.
      recordKnownTransientFailures(
          snapshot,
          selected,
          directedReadOptions,
          hintBuilder,
          excludedEndpoints,
          skippedTabletUids,
          skippedTabletDetails,
          resolvedEndpoints);
      hintBuilder.setTabletUid(selected.tabletUid);
      return RouteLookupResult.routed(
          resolveEndpoint(selected, resolvedEndpoints),
          endpointLabel(snapshot, selected),
          skippedTabletDetails);
    }
- private CachedTablet selectTabletLocked(
+ private TabletSnapshot selectTablet(
+ GroupSnapshot snapshot,
boolean preferLeader,
boolean hasDirectedReadOptions,
RoutingHint.Builder hintBuilder,
DirectedReadOptions directedReadOptions,
Predicate excludedEndpoints,
- Set skippedTabletUids) {
+ Set skippedTabletUids,
+ List skippedTabletDetails,
+ Map resolvedEndpoints,
+ SelectionState selectionStats) {
+ if (!preferLeader || hintBuilder.getOperationUid() > 0L) {
+ TabletSnapshot preferredLeader =
+ preferLeader ? localLeaderForScoreBias(snapshot, hasDirectedReadOptions) : null;
+ return selectScoreAwareTablet(
+ snapshot,
+ preferLeader,
+ directedReadOptions,
+ hintBuilder,
+ excludedEndpoints,
+ skippedTabletUids,
+ skippedTabletDetails,
+ resolvedEndpoints,
+ selectionStats,
+ preferredLeader);
+ }
+
boolean checkedLeader = false;
if (preferLeader
&& !hasDirectedReadOptions
- && hasLeader()
- && leader().distance <= MAX_LOCAL_REPLICA_DISTANCE) {
+ && snapshot.hasLeader()
+ && snapshot.leader().distance <= MAX_LOCAL_REPLICA_DISTANCE) {
checkedLeader = true;
- if (!leader().shouldSkip(hintBuilder, excludedEndpoints, skippedTabletUids)) {
- return leader();
+ selectionStats.sawMatchingReplica = true;
+ if (!shouldSkip(
+ snapshot,
+ snapshot.leader(),
+ hintBuilder,
+ excludedEndpoints,
+ skippedTabletUids,
+ skippedTabletDetails,
+ resolvedEndpoints,
+ selectionStats)) {
+ return snapshot.leader();
}
}
- for (int index = 0; index < tablets.size(); index++) {
- if (checkedLeader && index == leaderIndex) {
+ for (int index = 0; index < snapshot.tablets.size(); index++) {
+ if (checkedLeader && index == snapshot.leaderIndex) {
continue;
}
- CachedTablet tablet = tablets.get(index);
+ TabletSnapshot tablet = snapshot.tablets.get(index);
if (!tablet.matches(directedReadOptions)) {
continue;
}
- if (tablet.shouldSkip(hintBuilder, excludedEndpoints, skippedTabletUids)) {
+ selectionStats.sawMatchingReplica = true;
+ if (shouldSkip(
+ snapshot,
+ tablet,
+ hintBuilder,
+ excludedEndpoints,
+ skippedTabletUids,
+ skippedTabletDetails,
+ resolvedEndpoints,
+ selectionStats)) {
continue;
}
return tablet;
@@ -818,17 +873,213 @@ && leader().distance <= MAX_LOCAL_REPLICA_DISTANCE) {
return null;
}
- private void recordKnownTransientFailuresLocked(
- CachedTablet selected,
+ @javax.annotation.Nullable
+ private TabletSnapshot localLeaderForScoreBias(
+ GroupSnapshot snapshot, boolean hasDirectedReadOptions) {
+ if (!hasDirectedReadOptions
+ && snapshot.hasLeader()
+ && snapshot.leader().distance <= MAX_LOCAL_REPLICA_DISTANCE) {
+ return snapshot.leader();
+ }
+ return null;
+ }
+
    /**
     * Picks a replica by selection cost: gathers all eligible replicas, then delegates to
     * {@link #selectEligibleReplica}. Returns null when no replica is eligible;
     * {@code selectionStats} then explains why.
     */
    private TabletSnapshot selectScoreAwareTablet(
        GroupSnapshot snapshot,
        boolean preferLeader,
        DirectedReadOptions directedReadOptions,
        RoutingHint.Builder hintBuilder,
        Predicate excludedEndpoints,
        Set skippedTabletUids,
        List skippedTabletDetails,
        Map resolvedEndpoints,
        SelectionState selectionStats,
        @javax.annotation.Nullable TabletSnapshot preferredLeader) {
      long operationUid = hintBuilder.getOperationUid();
      List eligibleReplicas =
          collectEligibleReplicas(
              snapshot,
              directedReadOptions,
              hintBuilder,
              excludedEndpoints,
              skippedTabletUids,
              skippedTabletDetails,
              resolvedEndpoints,
              selectionStats,
              operationUid,
              preferLeader,
              preferredLeader);
      if (eligibleReplicas.isEmpty()) {
        return null;
      }
      EligibleReplica selected = selectEligibleReplica(eligibleReplicas);
      return selected.tablet;
    }
+
    /**
     * Collects every replica that matches the directed-read options, passes the
     * state-aware skip checks, and has a resolvable endpoint, computing each candidate's
     * selection cost as it goes.
     *
     * <p>Side effects: updates {@code selectionStats}, the skipped-tablet proto entries on
     * {@code hintBuilder}, and the {@code resolvedEndpoints} memo via shouldSkip().
     */
    private List collectEligibleReplicas(
        GroupSnapshot snapshot,
        DirectedReadOptions directedReadOptions,
        RoutingHint.Builder hintBuilder,
        Predicate excludedEndpoints,
        Set skippedTabletUids,
        List skippedTabletDetails,
        Map resolvedEndpoints,
        SelectionState selectionStats,
        long operationUid,
        boolean preferLeader,
        @javax.annotation.Nullable TabletSnapshot preferredLeader) {
      List eligibleReplicas = new ArrayList<>();
      for (TabletSnapshot tablet : snapshot.tablets) {
        if (!tablet.matches(directedReadOptions)) {
          continue;
        }
        selectionStats.sawMatchingReplica = true;
        if (shouldSkip(
            snapshot,
            tablet,
            hintBuilder,
            excludedEndpoints,
            skippedTabletUids,
            skippedTabletDetails,
            resolvedEndpoints,
            selectionStats)) {
          continue;
        }

        // shouldSkip() already resolved this address, so this is a memo hit, not a second
        // cache lookup.
        ChannelEndpoint endpoint = resolveEndpoint(tablet, resolvedEndpoints);
        if (endpoint == null) {
          continue;
        }
        eligibleReplicas.add(
            new EligibleReplica(
                tablet,
                endpoint,
                selectionCost(operationUid, preferLeader, endpoint, tablet, preferredLeader)));
      }
      return eligibleReplicas;
    }
+
    /**
     * Chooses one candidate from a non-empty list.
     *
     * <p>Single candidate: returned directly. Deterministic mode (tests): strictly lowest
     * cost. Otherwise the pluggable replica selector (power-of-two-choices) picks an
     * endpoint; the first candidate is the fallback if the selector returns null or an
     * endpoint not in the candidate list.
     */
    private EligibleReplica selectEligibleReplica(List eligibleReplicas) {
      if (eligibleReplicas.size() == 1) {
        return eligibleReplicas.get(0);
      }
      if (deterministicRandom) {
        return lowestCostReplica(eligibleReplicas);
      }

      ChannelEndpoint selectedEndpoint =
          replicaSelector.select(
              endpointView(eligibleReplicas),
              endpoint -> selectionCostForEndpoint(eligibleReplicas, endpoint));
      if (selectedEndpoint == null) {
        return eligibleReplicas.get(0);
      }

      EligibleReplica selected = candidateForEndpoint(eligibleReplicas, selectedEndpoint);
      return selected == null ? eligibleReplicas.get(0) : selected;
    }
+
+ private EligibleReplica lowestCostReplica(List eligibleReplicas) {
+ EligibleReplica lowestCost = eligibleReplicas.get(0);
+ for (int i = 1; i < eligibleReplicas.size(); i++) {
+ EligibleReplica candidate = eligibleReplicas.get(i);
+ if (candidate.selectionCost < lowestCost.selectionCost) {
+ lowestCost = candidate;
+ }
+ }
+ return lowestCost;
+ }
+
    /**
     * Lazy endpoint view of the candidate list for the replica selector.
     *
     * <p>Guava's Lists.transform returns a view, not a copy, so this allocates no backing
     * list and stays in sync with {@code eligibleReplicas}.
     */
    private List endpointView(List eligibleReplicas) {
      return Lists.transform(eligibleReplicas, candidate -> candidate.endpoint);
    }
+
+ private double selectionCostForEndpoint(
+ List eligibleReplicas, ChannelEndpoint endpoint) {
+ EligibleReplica candidate = candidateForEndpoint(eligibleReplicas, endpoint);
+ return candidate == null ? Double.MAX_VALUE : candidate.selectionCost;
+ }
+
+ @javax.annotation.Nullable
+ private EligibleReplica candidateForEndpoint(
+ List eligibleReplicas, ChannelEndpoint endpoint) {
+ for (EligibleReplica candidate : eligibleReplicas) {
+ if (candidate.endpoint == endpoint) {
+ return candidate;
+ }
+ }
+ return null;
+ }
+
    /**
     * Computes the selection cost for one candidate.
     *
     * <p>A null tablet is maximally expensive. Otherwise the cost comes from the latency
     * registry, then a nearby leader (when this lookup prefers leaders) is biased by
     * LOCAL_LEADER_SELECTION_COST_MULTIPLIER. Leader identity uses reference equality
     * against {@code preferredLeader}, which is valid because both come from the same
     * snapshot.
     */
    private double selectionCost(
        long operationUid,
        boolean preferLeader,
        @javax.annotation.Nullable ChannelEndpoint endpoint,
        @javax.annotation.Nullable TabletSnapshot tablet,
        @javax.annotation.Nullable TabletSnapshot preferredLeader) {
      if (tablet == null) {
        return Double.MAX_VALUE;
      }
      double cost =
          EndpointLatencyRegistry.getSelectionCost(
              databaseScope, operationUid, preferLeader, endpoint, tablet.serverAddress);
      if (preferredLeader != null && tablet == preferredLeader) {
        return cost * LOCAL_LEADER_SELECTION_COST_MULTIPLIER;
      }
      return cost;
    }
+
    /**
     * Fallback pick when every candidate was excluded or cooling down: deliberately
     * ignores the exclusion predicate and considers any matching, non-skip tablet whose
     * endpoint exists and is healthy, choosing by selection cost.
     *
     * <p>Returns null when not even an excluded-but-healthy replica exists.
     */
    @javax.annotation.Nullable
    private TabletSnapshot selectScoreAwareExcludedOrCoolingDownTablet(
        GroupSnapshot snapshot,
        boolean preferLeader,
        DirectedReadOptions directedReadOptions,
        RoutingHint.Builder hintBuilder,
        Map resolvedEndpoints,
        @javax.annotation.Nullable TabletSnapshot preferredLeader) {
      long operationUid = hintBuilder.getOperationUid();
      List candidates = new ArrayList<>();
      for (TabletSnapshot tablet : snapshot.tablets) {
        if (!tablet.matches(directedReadOptions) || tablet.skip || tablet.serverAddress.isEmpty()) {
          continue;
        }
        ChannelEndpoint endpoint = resolveEndpoint(tablet, resolvedEndpoints);
        // Only READY endpoints qualify even in the fallback path.
        if (endpoint == null || !endpoint.isHealthy()) {
          continue;
        }
        candidates.add(
            new EligibleReplica(
                tablet,
                endpoint,
                selectionCost(operationUid, preferLeader, endpoint, tablet, preferredLeader)));
      }
      if (candidates.isEmpty()) {
        return null;
      }
      return selectEligibleReplica(candidates).tablet;
    }
+
    /**
     * After a replica has been selected, reports every other matching replica that is in a
     * known TRANSIENT_FAILURE state (or was recently evicted for one) so the server can
     * refresh the client's cache.
     */
    private void recordKnownTransientFailures(
        GroupSnapshot snapshot,
        TabletSnapshot selected,
        DirectedReadOptions directedReadOptions,
        RoutingHint.Builder hintBuilder,
        Predicate excludedEndpoints,
        Set skippedTabletUids,
        List skippedTabletDetails,
        Map resolvedEndpoints) {
      for (TabletSnapshot tablet : snapshot.tablets) {
        // Reference equality is safe: both come from the same snapshot.
        if (tablet == selected || !tablet.matches(directedReadOptions)) {
          continue;
        }
        recordKnownTransientFailure(
            snapshot,
            tablet,
            hintBuilder,
            excludedEndpoints,
            skippedTabletUids,
            skippedTabletDetails,
            resolvedEndpoints);
      }
    }
@@ -840,27 +1091,215 @@ private Set skippedTabletUids(RoutingHint.Builder hintBuilder) {
return skippedTabletUids;
}
- boolean hasLeader() {
- return leaderIndex >= 0 && leaderIndex < tablets.size();
    /**
     * State-aware skip check for one tablet. Returns true when the tablet must not be
     * routed to, recording why in {@code selectionStats} and — for reportable reasons —
     * in the hint's skipped_tablets plus {@code skippedTabletDetails}.
     *
     * <p>Reported: server-marked skip, missing address, TRANSIENT_FAILURE, or a recent
     * transient-failure eviction. Silent: exclusion by the caller's predicate, and
     * endpoints that are absent or merely not-READY (a recreation is requested for absent
     * endpoints so future requests can route).
     */
    private boolean shouldSkip(
        GroupSnapshot snapshot,
        TabletSnapshot tablet,
        RoutingHint.Builder hintBuilder,
        Predicate excludedEndpoints,
        Set skippedTabletUids,
        List skippedTabletDetails,
        Map resolvedEndpoints,
        SelectionState selectionStats) {
      String targetEndpointLabel = endpointLabel(snapshot, tablet);
      if (tablet.skip) {
        selectionStats.sawNonExcludedReplica = true;
        selectionStats.hasUnroutableReplica = true;
        addSkippedTablet(
            tablet,
            hintBuilder,
            skippedTabletUids,
            skippedTabletDetails,
            targetEndpointLabel,
            "tablet_marked_skip");
        return true;
      }
      if (tablet.serverAddress.isEmpty()) {
        selectionStats.sawNonExcludedReplica = true;
        selectionStats.hasUnroutableReplica = true;
        // No label: there is no address to format.
        addSkippedTablet(
            tablet, hintBuilder, skippedTabletUids, skippedTabletDetails, null, "missing_address");
        return true;
      }
      if (excludedEndpoints.test(tablet.serverAddress)) {
        // Excluded by the caller; intentionally NOT reported in skipped_tablets.
        selectionStats.sawExcludedReplica = true;
        return true;
      }

      selectionStats.sawNonExcludedReplica = true;
      ChannelEndpoint endpoint = resolveEndpoint(tablet, resolvedEndpoints);
      if (endpoint == null) {
        selectionStats.hasUnavailableReplica = true;
        logger.log(
            Level.FINE,
            "Tablet {0} at {1}: no endpoint present, skipping silently",
            new Object[] {tablet.tabletUid, tablet.serverAddress});
        // Report only if the lifecycle manager recently evicted this address for
        // repeated transient failures.
        maybeAddRecentTransientFailureSkip(
            tablet, targetEndpointLabel, hintBuilder, skippedTabletUids, skippedTabletDetails);
        if (lifecycleManager != null) {
          lifecycleManager.requestEndpointRecreation(tablet.serverAddress);
        }
        return true;
      }
      if (endpoint.isHealthy()) {
        // READY: usable, do not skip.
        return false;
      }
      if (endpoint.isTransientFailure()) {
        selectionStats.hasUnavailableReplica = true;
        logger.log(
            Level.FINE,
            "Tablet {0} at {1}: endpoint in TRANSIENT_FAILURE, adding to skipped_tablets",
            new Object[] {tablet.tabletUid, tablet.serverAddress});
        addSkippedTablet(
            tablet,
            hintBuilder,
            skippedTabletUids,
            skippedTabletDetails,
            targetEndpointLabel,
            "transient_failure");
        return true;
      }

      // IDLE, CONNECTING, SHUTDOWN, or unsupported: skip silently (unless recently
      // evicted for transient failure).
      selectionStats.hasUnavailableReplica = true;
      logger.log(
          Level.FINE,
          "Tablet {0} at {1}: endpoint not ready, skipping silently",
          new Object[] {tablet.tabletUid, tablet.serverAddress});
      maybeAddRecentTransientFailureSkip(
          tablet, targetEndpointLabel, hintBuilder, skippedTabletUids, skippedTabletDetails);
      return true;
    }
- CachedTablet leader() {
- return tablets.get(leaderIndex);
+ private final class SelectionState {
+ private boolean sawMatchingReplica;
+ private boolean sawExcludedReplica;
+ private boolean sawNonExcludedReplica;
+ private boolean hasUnavailableReplica;
+ private boolean hasUnroutableReplica;
+
+ private RouteFailureReason toFailureReason() {
+ if (!sawMatchingReplica) {
+ return RouteFailureReason.NO_MATCHING_REPLICA;
+ }
+ if (sawExcludedReplica && !sawNonExcludedReplica) {
+ return RouteFailureReason.ALL_EXCLUDED_OR_COOLDOWN;
+ }
+ if (hasUnavailableReplica) {
+ return RouteFailureReason.NO_READY_REPLICA;
+ }
+ if (hasUnroutableReplica) {
+ return RouteFailureReason.NO_ROUTABLE_REPLICA;
+ }
+ return RouteFailureReason.NO_ROUTABLE_REPLICA;
+ }
+ }
+
    /**
     * Reports one non-selected tablet if its endpoint is in TRANSIENT_FAILURE, or if the
     * lifecycle manager recently evicted its address for transient failures. Skip-marked,
     * addressless, and excluded tablets are ignored here.
     */
    private void recordKnownTransientFailure(
        GroupSnapshot snapshot,
        TabletSnapshot tablet,
        RoutingHint.Builder hintBuilder,
        Predicate excludedEndpoints,
        Set skippedTabletUids,
        List skippedTabletDetails,
        Map resolvedEndpoints) {
      if (tablet.skip
          || tablet.serverAddress.isEmpty()
          || excludedEndpoints.test(tablet.serverAddress)) {
        return;
      }

      ChannelEndpoint endpoint = resolveEndpoint(tablet, resolvedEndpoints);
      if (endpoint != null && endpoint.isTransientFailure()) {
        addSkippedTablet(
            tablet,
            hintBuilder,
            skippedTabletUids,
            skippedTabletDetails,
            endpointLabel(snapshot, tablet),
            "known_transient_failure");
        return;
      }

      maybeAddRecentTransientFailureSkip(
          tablet,
          endpointLabel(snapshot, tablet),
          hintBuilder,
          skippedTabletUids,
          skippedTabletDetails);
    }
+
+ private String endpointLabel(GroupSnapshot snapshot, TabletSnapshot tablet) {
+ return formatTargetEndpointLabel(
+ tablet.serverAddress, snapshot.hasLeader() && snapshot.leader() == tablet);
+ }
+
    /**
     * Resolves a tablet's endpoint through the per-lookup memo, never creating one.
     *
     * <p>Null results are cached too, so each address is looked up at most once per
     * lookup — hence the containsKey/get pair instead of a single get(). Endpoints whose
     * channel has been shut down (e.g. after idle eviction) are treated as absent.
     */
    private ChannelEndpoint resolveEndpoint(
        TabletSnapshot tablet, Map resolvedEndpoints) {
      if (tablet.serverAddress.isEmpty()) {
        return null;
      }
      if (resolvedEndpoints.containsKey(tablet.serverAddress)) {
        return resolvedEndpoints.get(tablet.serverAddress);
      }
      ChannelEndpoint endpoint = endpointCache.getIfPresent(tablet.serverAddress);
      if (endpoint != null && endpoint.getChannel().isShutdown()) {
        logger.log(
            Level.FINE,
            "Tablet {0} at {1}: cached endpoint is shutdown, clearing stale reference",
            new Object[] {tablet.tabletUid, tablet.serverAddress});
        endpoint = null;
      }
      // Memoize even null so absent endpoints are not re-queried within this lookup.
      resolvedEndpoints.put(tablet.serverAddress, endpoint);
      return endpoint;
    }
+
    /**
     * Reports the tablet in skipped_tablets only when the lifecycle manager recently
     * evicted its address for repeated transient failures; otherwise a no-op.
     */
    private void maybeAddRecentTransientFailureSkip(
        TabletSnapshot tablet,
        @javax.annotation.Nullable String targetEndpointLabel,
        RoutingHint.Builder hintBuilder,
        Set skippedTabletUids,
        List skippedTabletDetails) {
      if (lifecycleManager != null
          && lifecycleManager.wasRecentlyEvictedTransientFailure(tablet.serverAddress)) {
        addSkippedTablet(
            tablet,
            hintBuilder,
            skippedTabletUids,
            skippedTabletDetails,
            targetEndpointLabel,
            "recent_transient_failure_eviction");
      }
    }
+
    /**
     * Appends one skipped-tablet entry to both the routing-hint proto and the diagnostic
     * detail list, deduplicating by tablet UID via {@code skippedTabletUids} so the proto
     * entries and the detail list stay parallel.
     */
    private void addSkippedTablet(
        TabletSnapshot tablet,
        RoutingHint.Builder hintBuilder,
        Set skippedTabletUids,
        List skippedTabletDetails,
        @javax.annotation.Nullable String targetEndpointLabel,
        String reason) {
      if (!skippedTabletUids.add(tablet.tabletUid)) {
        return;
      }
      RoutingHint.SkippedTablet.Builder skipped = hintBuilder.addSkippedTabletUidBuilder();
      skipped.setTabletUid(tablet.tabletUid);
      skipped.setIncarnation(tablet.incarnation);
      skippedTabletDetails.add(new SkippedTabletDetail(targetEndpointLabel, reason));
    }
    /**
     * Human-readable dump of the group: {@code uid:[tablet, ... (leader)]@generation#refs}.
     * Reads a single volatile snapshot so the output is internally consistent.
     */
    String debugString() {
      GroupSnapshot snapshot = this.snapshot;
      StringBuilder sb = new StringBuilder();
      sb.append(groupUid).append(":[");
      for (int i = 0; i < snapshot.tablets.size(); i++) {
        sb.append(snapshot.tablets.get(i).debugString());
        if (snapshot.hasLeader() && i == snapshot.leaderIndex) {
          sb.append(" (leader)");
        }
        if (i < snapshot.tablets.size() - 1) {
          sb.append(", ");
        }
      }
      sb.append("]@").append(snapshot.generation.toStringUtf8());
      // refs is read without the cache lock here; value may be slightly stale.
      sb.append("#").append(refs);
      return sb.toString();
    }
diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/RequestIdTargetTracker.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/RequestIdTargetTracker.java
new file mode 100644
index 000000000000..610bfd2e9310
--- /dev/null
+++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/RequestIdTargetTracker.java
@@ -0,0 +1,107 @@
+/*
+ * Copyright 2026 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.cloud.spanner.spi.v1;
+
+import com.google.cloud.spanner.XGoogSpannerRequestId;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.cache.Cache;
+import com.google.common.cache.CacheBuilder;
+import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
+
+final class RequestIdTargetTracker {
+ @VisibleForTesting static final long MAX_TRACKED_TARGETS = 1_000_000L;
+
+ @VisibleForTesting
+ static final int TARGET_CACHE_CONCURRENCY =
+ Math.max(4, Runtime.getRuntime().availableProcessors());
+
+ private static final Cache<String, RoutingTarget> TARGETS =
+ CacheBuilder.newBuilder()
+ .concurrencyLevel(TARGET_CACHE_CONCURRENCY)
+ .maximumSize(MAX_TRACKED_TARGETS)
+ .expireAfterWrite(10, TimeUnit.MINUTES)
+ .build();
+
+ private RequestIdTargetTracker() {}
+
+ static void record(
+ String requestId,
+ @Nullable String databaseScope,
+ String targetEndpoint,
+ long operationUid,
+ boolean preferLeader) {
+ String trackingKey = normalizeRequestKey(requestId);
+ if (trackingKey == null || targetEndpoint == null || targetEndpoint.isEmpty()) {
+ return;
+ }
+ TARGETS.put(
+ trackingKey, new RoutingTarget(databaseScope, targetEndpoint, operationUid, preferLeader));
+ }
+
+ @Nullable
+ static RoutingTarget get(String requestId) {
+ String trackingKey = normalizeRequestKey(requestId);
+ if (trackingKey == null) {
+ return null;
+ }
+ return TARGETS.getIfPresent(trackingKey);
+ }
+
+ static void remove(String requestId) {
+ String trackingKey = normalizeRequestKey(requestId);
+ if (trackingKey == null) {
+ return;
+ }
+ TARGETS.invalidate(trackingKey);
+ }
+
+ @VisibleForTesting
+ static void clear() {
+ TARGETS.invalidateAll();
+ }
+
+ @VisibleForTesting
+ static String normalizeRequestKey(String requestId) {
+ if (requestId == null || requestId.isEmpty()) {
+ return null;
+ }
+ try {
+ return XGoogSpannerRequestId.of(requestId).getLogicalRequestKey();
+ } catch (IllegalStateException e) {
+ return requestId;
+ }
+ }
+
+ static final class RoutingTarget {
+ @Nullable final String databaseScope;
+ final String targetEndpoint;
+ final long operationUid;
+ final boolean preferLeader;
+
+ private RoutingTarget(
+ @Nullable String databaseScope,
+ String targetEndpoint,
+ long operationUid,
+ boolean preferLeader) {
+ this.databaseScope = databaseScope;
+ this.targetEndpoint = targetEndpoint;
+ this.operationUid = operationUid;
+ this.preferLeader = preferLeader;
+ }
+ }
+}
diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareSharedBackendReplicaHarnessTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareSharedBackendReplicaHarnessTest.java
new file mode 100644
index 000000000000..287862c461bc
--- /dev/null
+++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareSharedBackendReplicaHarnessTest.java
@@ -0,0 +1,843 @@
+/*
+ * Copyright 2026 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.cloud.spanner;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+import com.google.cloud.NoCredentials;
+import com.google.cloud.spanner.MockSpannerServiceImpl.SimulatedExecutionTime;
+import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult;
+import com.google.cloud.spanner.spi.v1.KeyRecipeCache;
+import com.google.protobuf.AbstractMessage;
+import com.google.protobuf.ByteString;
+import com.google.protobuf.ListValue;
+import com.google.protobuf.TextFormat;
+import com.google.protobuf.Value;
+import com.google.rpc.RetryInfo;
+import com.google.spanner.v1.CacheUpdate;
+import com.google.spanner.v1.DirectedReadOptions;
+import com.google.spanner.v1.DirectedReadOptions.IncludeReplicas;
+import com.google.spanner.v1.DirectedReadOptions.ReplicaSelection;
+import com.google.spanner.v1.Group;
+import com.google.spanner.v1.Range;
+import com.google.spanner.v1.ReadRequest;
+import com.google.spanner.v1.RecipeList;
+import com.google.spanner.v1.ResultSetMetadata;
+import com.google.spanner.v1.RoutingHint;
+import com.google.spanner.v1.StructType;
+import com.google.spanner.v1.Tablet;
+import com.google.spanner.v1.Type;
+import com.google.spanner.v1.TypeCode;
+import io.grpc.Metadata;
+import io.grpc.Status;
+import io.grpc.StatusRuntimeException;
+import io.grpc.protobuf.ProtoUtils;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import org.junit.AfterClass;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class LocationAwareSharedBackendReplicaHarnessTest {
+
+ private static final String PROJECT = "fake-project";
+ private static final String INSTANCE = "fake-instance";
+ private static final String DATABASE = "fake-database";
+ private static final String TABLE = "T";
+ private static final String REPLICA_LOCATION = "us-east1";
+ private static final Statement SEED_QUERY = Statement.of("SELECT 1");
+ private static final ByteString RESUME_TOKEN_AFTER_FIRST_ROW =
+ ByteString.copyFromUtf8("000000001");
+ private static final DirectedReadOptions DIRECTED_READ_OPTIONS =
+ DirectedReadOptions.newBuilder()
+ .setIncludeReplicas(
+ IncludeReplicas.newBuilder()
+ .addReplicaSelections(
+ ReplicaSelection.newBuilder()
+ .setLocation(REPLICA_LOCATION)
+ .setType(ReplicaSelection.Type.READ_ONLY)
+ .build())
+ .build())
+ .build();
+ private static SharedBackendReplicaHarness harness;
+
+ @BeforeClass
+ public static void enableLocationAwareRouting() throws Exception {
+ SpannerOptions.useEnvironment(
+ new SpannerOptions.SpannerEnvironment() {
+ @Override
+ public boolean isEnableLocationApi() {
+ return true;
+ }
+ });
+ harness = SharedBackendReplicaHarness.create(2);
+ }
+
+ @Before
+ public void resetHarness() {
+ harness.reset();
+ }
+
+ @AfterClass
+ public static void restoreEnvironment() throws Exception {
+ try {
+ if (harness != null) {
+ harness.close();
+ }
+ } finally {
+ harness = null;
+ SpannerOptions.useDefaultEnvironment();
+ }
+ }
+
+ @Test
+ public void singleUseReadReroutesOnResourceExhaustedForBypassTraffic() throws Exception {
+ try (Spanner spanner = createSpanner(harness)) {
+ configureBackend(harness, singleRowReadResultSet("b"));
+ DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of(PROJECT, INSTANCE, DATABASE));
+
+ seedLocationMetadata(client);
+ int firstReplicaIndex = waitForReplicaRoutedRead(client, harness);
+ int secondReplicaIndex = 1 - firstReplicaIndex;
+ harness.clearRequests();
+
+ harness
+ .replicas
+ .get(firstReplicaIndex)
+ .putMethodErrors(
+ SharedBackendReplicaHarness.METHOD_STREAMING_READ,
+ resourceExhausted("busy-routed-replica"));
+
+ try (ResultSet resultSet =
+ client
+ .singleUse()
+ .read(
+ TABLE,
+ KeySet.singleKey(Key.of("b")),
+ Arrays.asList("k"),
+ Options.directedRead(DIRECTED_READ_OPTIONS))) {
+ assertTrue(resultSet.next());
+ }
+
+ assertEquals(
+ 1,
+ harness
+ .replicas
+ .get(firstReplicaIndex)
+ .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ)
+ .size());
+ assertEquals(
+ 1,
+ harness
+ .replicas
+ .get(secondReplicaIndex)
+ .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ)
+ .size());
+ assertEquals(
+ 0,
+ harness
+ .defaultReplica
+ .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ)
+ .size());
+ ReadRequest replicaARequest =
+ (ReadRequest)
+ harness
+ .replicas
+ .get(firstReplicaIndex)
+ .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ)
+ .get(0);
+ assertTrue(replicaARequest.getResumeToken().isEmpty());
+ assertRetriedOnSameLogicalRequest(
+ harness
+ .replicas
+ .get(firstReplicaIndex)
+ .getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ)
+ .get(0),
+ harness
+ .replicas
+ .get(secondReplicaIndex)
+ .getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ)
+ .get(0));
+ }
+ }
+
+ @Test
+ public void singleUseReadCooldownSkipsReplicaOnNextRequestForBypassTraffic() throws Exception {
+ try (Spanner spanner = createSpanner(harness)) {
+ configureBackend(harness, singleRowReadResultSet("b"));
+ DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of(PROJECT, INSTANCE, DATABASE));
+
+ seedLocationMetadata(client);
+ int firstReplicaIndex = waitForReplicaRoutedRead(client, harness);
+ int secondReplicaIndex = 1 - firstReplicaIndex;
+ harness.clearRequests();
+
+ harness
+ .replicas
+ .get(firstReplicaIndex)
+ .putMethodErrors(
+ SharedBackendReplicaHarness.METHOD_STREAMING_READ,
+ resourceExhaustedWithRetryInfo("busy-routed-replica"));
+
+ try (ResultSet firstRead =
+ client
+ .singleUse()
+ .read(
+ TABLE,
+ KeySet.singleKey(Key.of("b")),
+ Arrays.asList("k"),
+ Options.directedRead(DIRECTED_READ_OPTIONS))) {
+ assertTrue(firstRead.next());
+ }
+
+ try (ResultSet secondRead =
+ client
+ .singleUse()
+ .read(
+ TABLE,
+ KeySet.singleKey(Key.of("b")),
+ Arrays.asList("k"),
+ Options.directedRead(DIRECTED_READ_OPTIONS))) {
+ assertTrue(secondRead.next());
+ }
+
+ assertEquals(
+ 1,
+ harness
+ .replicas
+ .get(firstReplicaIndex)
+ .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ)
+ .size());
+ assertEquals(
+ 2,
+ harness
+ .replicas
+ .get(secondReplicaIndex)
+ .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ)
+ .size());
+ assertEquals(
+ 0,
+ harness
+ .defaultReplica
+ .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ)
+ .size());
+ List<AbstractMessage> replicaBRequests =
+ harness
+ .replicas
+ .get(secondReplicaIndex)
+ .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ);
+ for (AbstractMessage request : replicaBRequests) {
+ assertTrue(((ReadRequest) request).getResumeToken().isEmpty());
+ }
+ List<String> replicaBRequestIds =
+ harness
+ .replicas
+ .get(secondReplicaIndex)
+ .getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ);
+ assertRetriedOnSameLogicalRequest(
+ harness
+ .replicas
+ .get(firstReplicaIndex)
+ .getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ)
+ .get(0),
+ replicaBRequestIds.get(0));
+ assertNotEquals(
+ XGoogSpannerRequestId.of(replicaBRequestIds.get(0)).getLogicalRequestKey(),
+ XGoogSpannerRequestId.of(replicaBRequestIds.get(1)).getLogicalRequestKey());
+ }
+ }
+
+ @Test
+ public void singleUseReadReroutesOnUnavailableForBypassTraffic() throws Exception {
+ try (Spanner spanner = createSpanner(harness)) {
+ configureBackend(harness, singleRowReadResultSet("b"));
+ DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of(PROJECT, INSTANCE, DATABASE));
+
+ seedLocationMetadata(client);
+ int firstReplicaIndex = waitForReplicaRoutedRead(client, harness);
+ int secondReplicaIndex = 1 - firstReplicaIndex;
+ harness.clearRequests();
+
+ harness
+ .replicas
+ .get(firstReplicaIndex)
+ .putMethodErrors(
+ SharedBackendReplicaHarness.METHOD_STREAMING_READ, unavailable("isolated-replica"));
+
+ try (ResultSet resultSet =
+ client
+ .singleUse()
+ .read(
+ TABLE,
+ KeySet.singleKey(Key.of("b")),
+ Arrays.asList("k"),
+ Options.directedRead(DIRECTED_READ_OPTIONS))) {
+ assertTrue(resultSet.next());
+ }
+
+ assertEquals(
+ 1,
+ harness
+ .replicas
+ .get(firstReplicaIndex)
+ .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ)
+ .size());
+ assertEquals(
+ 1,
+ harness
+ .replicas
+ .get(secondReplicaIndex)
+ .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ)
+ .size());
+ assertEquals(
+ 0,
+ harness
+ .defaultReplica
+ .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ)
+ .size());
+ ReadRequest replicaARequest =
+ (ReadRequest)
+ harness
+ .replicas
+ .get(firstReplicaIndex)
+ .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ)
+ .get(0);
+ assertTrue(replicaARequest.getResumeToken().isEmpty());
+ assertRetriedOnSameLogicalRequest(
+ harness
+ .replicas
+ .get(firstReplicaIndex)
+ .getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ)
+ .get(0),
+ harness
+ .replicas
+ .get(secondReplicaIndex)
+ .getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ)
+ .get(0));
+ }
+ }
+
+ @Test
+ public void singleUseReadCooldownSkipsUnavailableReplicaOnNextRequestForBypassTraffic()
+ throws Exception {
+ try (Spanner spanner = createSpanner(harness)) {
+ configureBackend(harness, singleRowReadResultSet("b"));
+ DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of(PROJECT, INSTANCE, DATABASE));
+
+ seedLocationMetadata(client);
+ int firstReplicaIndex = waitForReplicaRoutedRead(client, harness);
+ int secondReplicaIndex = 1 - firstReplicaIndex;
+ harness.clearRequests();
+
+ harness
+ .replicas
+ .get(firstReplicaIndex)
+ .putMethodErrors(
+ SharedBackendReplicaHarness.METHOD_STREAMING_READ, unavailable("isolated-replica"));
+
+ try (ResultSet firstRead =
+ client
+ .singleUse()
+ .read(
+ TABLE,
+ KeySet.singleKey(Key.of("b")),
+ Arrays.asList("k"),
+ Options.directedRead(DIRECTED_READ_OPTIONS))) {
+ assertTrue(firstRead.next());
+ }
+
+ try (ResultSet secondRead =
+ client
+ .singleUse()
+ .read(
+ TABLE,
+ KeySet.singleKey(Key.of("b")),
+ Arrays.asList("k"),
+ Options.directedRead(DIRECTED_READ_OPTIONS))) {
+ assertTrue(secondRead.next());
+ }
+
+ assertEquals(
+ 1,
+ harness
+ .replicas
+ .get(firstReplicaIndex)
+ .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ)
+ .size());
+ assertEquals(
+ 2,
+ harness
+ .replicas
+ .get(secondReplicaIndex)
+ .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ)
+ .size());
+ assertEquals(
+ 0,
+ harness
+ .defaultReplica
+ .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ)
+ .size());
+ List<AbstractMessage> replicaBRequests =
+ harness
+ .replicas
+ .get(secondReplicaIndex)
+ .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ);
+ for (AbstractMessage request : replicaBRequests) {
+ assertTrue(((ReadRequest) request).getResumeToken().isEmpty());
+ }
+ List<String> replicaBRequestIds =
+ harness
+ .replicas
+ .get(secondReplicaIndex)
+ .getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ);
+ assertRetriedOnSameLogicalRequest(
+ harness
+ .replicas
+ .get(firstReplicaIndex)
+ .getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ)
+ .get(0),
+ replicaBRequestIds.get(0));
+ assertNotEquals(
+ XGoogSpannerRequestId.of(replicaBRequestIds.get(0)).getLogicalRequestKey(),
+ XGoogSpannerRequestId.of(replicaBRequestIds.get(1)).getLogicalRequestKey());
+ }
+ }
+
+ @Test
+ public void singleUseReadMidStreamRecvFailureWithoutRetryInfoRetriesForBypassTraffic()
+ throws Exception {
+ try (Spanner spanner = createSpanner(harness)) {
+ configureBackend(harness, multiRowReadResultSet("b", "c", "d"));
+ DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of(PROJECT, INSTANCE, DATABASE));
+
+ seedLocationMetadata(client);
+ int firstReplicaIndex = waitForReplicaRoutedRead(client, harness);
+ int secondReplicaIndex = 1 - firstReplicaIndex;
+ harness.clearRequests();
+
+ harness.backend.setStreamingReadExecutionTime(
+ SimulatedExecutionTime.ofStreamException(resourceExhausted("busy-routed-replica"), 1L));
+
+ List<String> rows = new ArrayList<>();
+ try (ResultSet resultSet =
+ client
+ .singleUse()
+ .read(
+ TABLE,
+ KeySet.singleKey(Key.of("b")),
+ Arrays.asList("k"),
+ Options.directedRead(DIRECTED_READ_OPTIONS))) {
+ while (resultSet.next()) {
+ rows.add(resultSet.getString(0));
+ }
+ }
+
+ assertEquals(Arrays.asList("b", "c", "d"), rows);
+ assertEquals(
+ 1,
+ harness
+ .replicas
+ .get(firstReplicaIndex)
+ .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ)
+ .size());
+ assertEquals(
+ 1,
+ harness
+ .replicas
+ .get(secondReplicaIndex)
+ .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ)
+ .size());
+ assertEquals(
+ 0,
+ harness
+ .defaultReplica
+ .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ)
+ .size());
+
+ ReadRequest replicaARequest =
+ (ReadRequest)
+ harness
+ .replicas
+ .get(firstReplicaIndex)
+ .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ)
+ .get(0);
+ ReadRequest replicaBRequest =
+ (ReadRequest)
+ harness
+ .replicas
+ .get(secondReplicaIndex)
+ .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ)
+ .get(0);
+ assertTrue(replicaARequest.getResumeToken().isEmpty());
+ assertEquals(RESUME_TOKEN_AFTER_FIRST_ROW, replicaBRequest.getResumeToken());
+ assertRetriedOnSameLogicalRequest(
+ harness
+ .replicas
+ .get(firstReplicaIndex)
+ .getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ)
+ .get(0),
+ harness
+ .replicas
+ .get(secondReplicaIndex)
+ .getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ)
+ .get(0));
+ }
+ }
+
+ @Test
+ public void readWriteTransactionAbortedCommitUsesReadAffinityReplicaForBypassTraffic()
+ throws Exception {
+ try (Spanner spanner = createSpanner(harness)) {
+ configureBackend(harness, singleRowReadResultSet("b"), /* leaderReplicaIndex= */ 1);
+ DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of(PROJECT, INSTANCE, DATABASE));
+
+ seedLocationMetadata(client);
+ waitForReplicaRoutedStrongRead(client, harness, /* expectedReplicaIndex= */ 1);
+ harness.clearRequests();
+ AtomicInteger attempts = new AtomicInteger();
+ AtomicInteger firstReplicaIndex = new AtomicInteger(-1);
+
+ client
+ .readWriteTransaction()
+ .run(
+ transaction -> {
+ int attempt = attempts.incrementAndGet();
+ try (ResultSet resultSet =
+ transaction.read(TABLE, KeySet.singleKey(Key.of("b")), Arrays.asList("k"))) {
+ assertTrue(resultSet.next());
+ }
+
+ if (attempt == 1) {
+ int routedReplicaIndex =
+ findReplicaWithRequest(
+ harness, SharedBackendReplicaHarness.METHOD_STREAMING_READ);
+ if (routedReplicaIndex < 0) {
+ fail("Expected read-write transaction read to route to a bypass replica");
+ }
+ firstReplicaIndex.set(routedReplicaIndex);
+ harness
+ .replicas
+ .get(routedReplicaIndex)
+ .putMethodErrors(
+ SharedBackendReplicaHarness.METHOD_COMMIT,
+ Status.ABORTED
+ .withDescription("commit aborted on routed replica")
+ .asRuntimeException());
+ }
+ return null;
+ });
+
+ assertEquals(2, attempts.get());
+ assertEquals(1, firstReplicaIndex.get());
+ int secondReplicaIndex = 1 - firstReplicaIndex.get();
+ assertEquals(
+ 2,
+ harness
+ .replicas
+ .get(firstReplicaIndex.get())
+ .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ)
+ .size());
+ assertEquals(
+ 0,
+ harness
+ .replicas
+ .get(secondReplicaIndex)
+ .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ)
+ .size());
+ assertEquals(
+ 2,
+ harness
+ .replicas
+ .get(firstReplicaIndex.get())
+ .getRequests(SharedBackendReplicaHarness.METHOD_COMMIT)
+ .size());
+ assertEquals(
+ 0,
+ harness
+ .replicas
+ .get(secondReplicaIndex)
+ .getRequests(SharedBackendReplicaHarness.METHOD_COMMIT)
+ .size());
+ assertEquals(
+ 0, harness.defaultReplica.getRequests(SharedBackendReplicaHarness.METHOD_COMMIT).size());
+ }
+ }
+
+ private static Spanner createSpanner(SharedBackendReplicaHarness harness) {
+ return SpannerOptions.newBuilder()
+ .usePlainText()
+ .setExperimentalHost(harness.defaultAddress)
+ .setSessionPoolOption(
+ SessionPoolOptions.newBuilder()
+ .setExperimentalHost()
+ .setUseMultiplexedSession(true)
+ .setUseMultiplexedSessionForRW(true)
+ .build())
+ .setProjectId(PROJECT)
+ .setCredentials(NoCredentials.getInstance())
+ .setChannelEndpointCacheFactory(null)
+ .build()
+ .getService();
+ }
+
+ private static void configureBackend(
+ SharedBackendReplicaHarness harness, com.google.spanner.v1.ResultSet readResultSet)
+ throws TextFormat.ParseException {
+ configureBackend(harness, readResultSet, /* leaderReplicaIndex= */ 0);
+ }
+
+ private static void configureBackend(
+ SharedBackendReplicaHarness harness,
+ com.google.spanner.v1.ResultSet readResultSet,
+ int leaderReplicaIndex)
+ throws TextFormat.ParseException {
+ Statement readStatement =
+ StatementResult.createReadStatement(
+ TABLE, KeySet.singleKey(Key.of("b")), Arrays.asList("k"));
+ harness.backend.putStatementResult(StatementResult.query(readStatement, readResultSet));
+ harness.backend.putStatementResult(
+ StatementResult.query(
+ SEED_QUERY,
+ singleRowReadResultSet("seed").toBuilder()
+ .setCacheUpdate(cacheUpdate(harness, leaderReplicaIndex))
+ .build()));
+ }
+
+ private static void seedLocationMetadata(DatabaseClient client) {
+ try (com.google.cloud.spanner.ResultSet resultSet =
+ client.singleUse().executeQuery(SEED_QUERY)) {
+ while (resultSet.next()) {
+ // Consume the cache update on the first query result.
+ }
+ }
+ }
+
+ private static int waitForReplicaRoutedRead(
+ DatabaseClient client, SharedBackendReplicaHarness harness) throws InterruptedException {
+ long deadlineNanos = System.nanoTime() + TimeUnit.SECONDS.toNanos(10);
+ while (System.nanoTime() < deadlineNanos) {
+ try (ResultSet resultSet =
+ client
+ .singleUse()
+ .read(
+ TABLE,
+ KeySet.singleKey(Key.of("b")),
+ Arrays.asList("k"),
+ Options.directedRead(DIRECTED_READ_OPTIONS))) {
+ if (resultSet.next()) {
+ for (int replicaIndex = 0; replicaIndex < harness.replicas.size(); replicaIndex++) {
+ if (!harness
+ .replicas
+ .get(replicaIndex)
+ .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ)
+ .isEmpty()) {
+ return replicaIndex;
+ }
+ }
+ }
+ }
+ Thread.sleep(50L);
+ }
+ throw new AssertionError("Timed out waiting for location-aware read to route to replica");
+ }
+
+ private static void waitForReplicaRoutedStrongRead(
+ DatabaseClient client, SharedBackendReplicaHarness harness, int expectedReplicaIndex)
+ throws InterruptedException {
+ long deadlineNanos = System.nanoTime() + TimeUnit.SECONDS.toNanos(10);
+ while (System.nanoTime() < deadlineNanos) {
+ harness.clearRequests();
+ try (ResultSet resultSet =
+ client.singleUse().read(TABLE, KeySet.singleKey(Key.of("b")), Arrays.asList("k"))) {
+ if (resultSet.next()) {
+ if (!harness
+ .replicas
+ .get(expectedReplicaIndex)
+ .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ)
+ .isEmpty()) {
+ return;
+ }
+ }
+ }
+ Thread.sleep(50L);
+ }
+ throw new AssertionError(
+ "Timed out waiting for strong read to route to replica " + expectedReplicaIndex);
+ }
+
+ private static int findReplicaWithRequest(SharedBackendReplicaHarness harness, String method) {
+ for (int replicaIndex = 0; replicaIndex < harness.replicas.size(); replicaIndex++) {
+ if (!harness.replicas.get(replicaIndex).getRequests(method).isEmpty()) {
+ return replicaIndex;
+ }
+ }
+ return -1;
+ }
+
+ private static CacheUpdate cacheUpdate(SharedBackendReplicaHarness harness)
+ throws TextFormat.ParseException {
+ return cacheUpdate(harness, /* leaderReplicaIndex= */ 0);
+ }
+
+ private static CacheUpdate cacheUpdate(
+ SharedBackendReplicaHarness harness, int leaderReplicaIndex)
+ throws TextFormat.ParseException {
+ RecipeList recipes = readRecipeList();
+ RoutingHint routingHint = exactReadRoutingHint(recipes);
+ ByteString limitKey = routingHint.getLimitKey();
+ if (limitKey.isEmpty()) {
+ limitKey = routingHint.getKey().concat(ByteString.copyFrom(new byte[] {0}));
+ }
+
+ return CacheUpdate.newBuilder()
+ .setDatabaseId(12345L)
+ .setKeyRecipes(recipes)
+ .addRange(
+ Range.newBuilder()
+ .setStartKey(routingHint.getKey())
+ .setLimitKey(limitKey)
+ .setGroupUid(1L)
+ .setSplitId(1L)
+ .setGeneration(com.google.protobuf.ByteString.copyFromUtf8("gen1")))
+ .addGroup(
+ Group.newBuilder()
+ .setGroupUid(1L)
+ .setGeneration(com.google.protobuf.ByteString.copyFromUtf8("gen1"))
+ .setLeaderIndex(leaderReplicaIndex)
+ .addTablets(
+ Tablet.newBuilder()
+ .setTabletUid(11L)
+ .setServerAddress(harness.replicaAddresses.get(0))
+ .setLocation(REPLICA_LOCATION)
+ .setRole(Tablet.Role.READ_ONLY)
+ .setDistance(0))
+ .addTablets(
+ Tablet.newBuilder()
+ .setTabletUid(12L)
+ .setServerAddress(harness.replicaAddresses.get(1))
+ .setLocation(REPLICA_LOCATION)
+ .setRole(Tablet.Role.READ_ONLY)
+ .setDistance(0)))
+ .build();
+ }
+
+ private static RecipeList readRecipeList() throws TextFormat.ParseException {
+ RecipeList.Builder recipes = RecipeList.newBuilder();
+ TextFormat.merge(
+ "schema_generation: \"1\"\n"
+ + "recipe {\n"
+ + " table_name: \""
+ + TABLE
+ + "\"\n"
+ + " part { tag: 1 }\n"
+ + " part {\n"
+ + " order: ASCENDING\n"
+ + " null_order: NULLS_FIRST\n"
+ + " type { code: STRING }\n"
+ + " identifier: \"k\"\n"
+ + " }\n"
+ + "}\n",
+ recipes);
+ return recipes.build();
+ }
+
+ private static RoutingHint exactReadRoutingHint(RecipeList recipes) {
+ KeyRecipeCache recipeCache = new KeyRecipeCache();
+ recipeCache.addRecipes(recipes);
+ ReadRequest.Builder request =
+ ReadRequest.newBuilder()
+ .setSession(
+ String.format(
+ "projects/%s/instances/%s/databases/%s/sessions/test-session",
+ PROJECT, INSTANCE, DATABASE))
+ .setTable(TABLE)
+ .addAllColumns(Arrays.asList("k"))
+ .setDirectedReadOptions(DIRECTED_READ_OPTIONS);
+ KeySet.singleKey(Key.of("b")).appendToProto(request.getKeySetBuilder());
+ recipeCache.computeKeys(request);
+ return request.getRoutingHint();
+ }
+
+ private static io.grpc.StatusRuntimeException resourceExhaustedWithRetryInfo(String description) {
+ Metadata trailers = new Metadata();
+ trailers.put(
+ ProtoUtils.keyForProto(RetryInfo.getDefaultInstance()),
+ RetryInfo.newBuilder()
+ .setRetryDelay(
+ com.google.protobuf.Duration.newBuilder()
+ .setNanos((int) TimeUnit.MILLISECONDS.toNanos(1L))
+ .build())
+ .build());
+ return Status.RESOURCE_EXHAUSTED.withDescription(description).asRuntimeException(trailers);
+ }
+
+ private static StatusRuntimeException resourceExhausted(String description) {
+ return Status.RESOURCE_EXHAUSTED.withDescription(description).asRuntimeException();
+ }
+
+ private static StatusRuntimeException unavailable(String description) {
+ return Status.UNAVAILABLE.withDescription(description).asRuntimeException();
+ }
+
+ private static void assertRetriedOnSameLogicalRequest(
+ String firstRequestId, String secondRequestId) {
+ XGoogSpannerRequestId first = XGoogSpannerRequestId.of(firstRequestId);
+ XGoogSpannerRequestId second = XGoogSpannerRequestId.of(secondRequestId);
+ assertEquals(first.getLogicalRequestKey(), second.getLogicalRequestKey());
+ assertEquals(first.getAttempt() + 1, second.getAttempt());
+ }
+
+ private static com.google.spanner.v1.ResultSet singleRowReadResultSet(String value) {
+ return readResultSet(Arrays.asList(value));
+ }
+
+ private static com.google.spanner.v1.ResultSet multiRowReadResultSet(String... values) {
+ return readResultSet(Arrays.asList(values));
+ }
+
+ private static com.google.spanner.v1.ResultSet readResultSet(List<String> values) {
+ com.google.spanner.v1.ResultSet.Builder builder =
+ com.google.spanner.v1.ResultSet.newBuilder()
+ .setMetadata(
+ ResultSetMetadata.newBuilder()
+ .setRowType(
+ StructType.newBuilder()
+ .addFields(
+ StructType.Field.newBuilder()
+ .setName("k")
+ .setType(Type.newBuilder().setCode(TypeCode.STRING).build())
+ .build())
+ .build()));
+ for (String value : values) {
+ builder.addRows(
+ ListValue.newBuilder()
+ .addValues(Value.newBuilder().setStringValue(value).build())
+ .build());
+ }
+ return builder.build();
+ }
+}
diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java
index 6f40052d0aed..3ea19ad2422a 100644
--- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java
+++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java
@@ -202,9 +202,15 @@ private static class PartialResultSetsIterator implements Iterator responseObserver,
SimulatedExecutionTime executionTime,
boolean isMultiplexedSession)
@@ -1783,7 +1803,8 @@ private void returnPartialResultSet(
new PartialResultSetsIterator(
resultSet,
isMultiplexedSession && isReadWriteTransaction(transactionId),
- transactionId);
+ transactionId,
+ resumeToken);
long index = 0L;
while (iterator.hasNext()) {
SimulatedExecutionTime.checkStreamException(
diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SharedBackendReplicaHarness.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SharedBackendReplicaHarness.java
new file mode 100644
index 000000000000..7aa5eb88c3e0
--- /dev/null
+++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SharedBackendReplicaHarness.java
@@ -0,0 +1,326 @@
+/*
+ * Copyright 2026 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.cloud.spanner;
+
+import com.google.protobuf.AbstractMessage;
+import com.google.protobuf.Empty;
+import com.google.spanner.v1.BatchCreateSessionsRequest;
+import com.google.spanner.v1.BatchCreateSessionsResponse;
+import com.google.spanner.v1.BeginTransactionRequest;
+import com.google.spanner.v1.CommitRequest;
+import com.google.spanner.v1.CommitResponse;
+import com.google.spanner.v1.CreateSessionRequest;
+import com.google.spanner.v1.DeleteSessionRequest;
+import com.google.spanner.v1.ExecuteSqlRequest;
+import com.google.spanner.v1.GetSessionRequest;
+import com.google.spanner.v1.PartialResultSet;
+import com.google.spanner.v1.ReadRequest;
+import com.google.spanner.v1.ResultSet;
+import com.google.spanner.v1.RollbackRequest;
+import com.google.spanner.v1.Session;
+import com.google.spanner.v1.SpannerGrpc;
+import com.google.spanner.v1.Transaction;
+import io.grpc.Metadata;
+import io.grpc.Server;
+import io.grpc.ServerCall;
+import io.grpc.ServerCallHandler;
+import io.grpc.ServerInterceptor;
+import io.grpc.ServerInterceptors;
+import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder;
+import io.grpc.stub.StreamObserver;
+import java.io.Closeable;
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+
+/** Shared-backend replica harness for end-to-end location-aware routing tests. */
+final class SharedBackendReplicaHarness implements Closeable {
+
+ static final String METHOD_BATCH_CREATE_SESSIONS = "BatchCreateSessions";
+ static final String METHOD_BEGIN_TRANSACTION = "BeginTransaction";
+ static final String METHOD_COMMIT = "Commit";
+ static final String METHOD_CREATE_SESSION = "CreateSession";
+ static final String METHOD_DELETE_SESSION = "DeleteSession";
+ static final String METHOD_EXECUTE_SQL = "ExecuteSql";
+ static final String METHOD_EXECUTE_STREAMING_SQL = "ExecuteStreamingSql";
+ static final String METHOD_GET_SESSION = "GetSession";
+ static final String METHOD_READ = "Read";
+ static final String METHOD_ROLLBACK = "Rollback";
+ static final String METHOD_STREAMING_READ = "StreamingRead";
+
+ static final class HookedReplicaSpannerService extends SpannerGrpc.SpannerImplBase {
+ private final MockSpannerServiceImpl backend;
+ private final Map> methodErrors = new HashMap<>();
+ private final Map> requests = new HashMap<>();
+ private final Map> requestIds = new HashMap<>();
+
+ private HookedReplicaSpannerService(MockSpannerServiceImpl backend) {
+ this.backend = backend;
+ }
+
+ synchronized void putMethodErrors(String method, Throwable... errors) {
+ ArrayDeque queue = new ArrayDeque<>();
+ for (Throwable error : errors) {
+ queue.addLast(error);
+ }
+ methodErrors.put(method, queue);
+ }
+
+ synchronized List getRequests(String method) {
+ return new ArrayList<>(requests.getOrDefault(method, new ArrayList<>()));
+ }
+
+ synchronized List getRequestIds(String method) {
+ return new ArrayList<>(requestIds.getOrDefault(method, new ArrayList<>()));
+ }
+
+ synchronized void clearRequests() {
+ requests.clear();
+ requestIds.clear();
+ }
+
+ synchronized void clearMethodErrors() {
+ methodErrors.clear();
+ }
+
+ private synchronized void recordRequest(String method, AbstractMessage request) {
+ requests.computeIfAbsent(method, ignored -> new ArrayList<>()).add(request);
+ }
+
+ private synchronized void recordRequestId(String method, String requestId) {
+ requestIds.computeIfAbsent(method, ignored -> new ArrayList<>()).add(requestId);
+ }
+
+ private synchronized Throwable nextError(String method) {
+ ArrayDeque queue = methodErrors.get(method);
+ if (queue == null || queue.isEmpty()) {
+ return null;
+ }
+ return queue.removeFirst();
+ }
+
+ private boolean maybeFail(String method, StreamObserver> responseObserver) {
+ Throwable error = nextError(method);
+ if (error == null) {
+ return false;
+ }
+ responseObserver.onError(error);
+ return true;
+ }
+
+ @Override
+ public void batchCreateSessions(
+ BatchCreateSessionsRequest request,
+ StreamObserver responseObserver) {
+ recordRequest(METHOD_BATCH_CREATE_SESSIONS, request);
+ if (!maybeFail(METHOD_BATCH_CREATE_SESSIONS, responseObserver)) {
+ backend.batchCreateSessions(request, responseObserver);
+ }
+ }
+
+ @Override
+ public void beginTransaction(
+ BeginTransactionRequest request, StreamObserver responseObserver) {
+ recordRequest(METHOD_BEGIN_TRANSACTION, request);
+ if (!maybeFail(METHOD_BEGIN_TRANSACTION, responseObserver)) {
+ backend.beginTransaction(request, responseObserver);
+ }
+ }
+
+ @Override
+ public void commit(CommitRequest request, StreamObserver responseObserver) {
+ recordRequest(METHOD_COMMIT, request);
+ if (!maybeFail(METHOD_COMMIT, responseObserver)) {
+ backend.commit(request, responseObserver);
+ }
+ }
+
+ @Override
+ public void createSession(
+ CreateSessionRequest request, StreamObserver responseObserver) {
+ recordRequest(METHOD_CREATE_SESSION, request);
+ if (!maybeFail(METHOD_CREATE_SESSION, responseObserver)) {
+ backend.createSession(request, responseObserver);
+ }
+ }
+
+ @Override
+ public void deleteSession(
+ DeleteSessionRequest request, StreamObserver responseObserver) {
+ recordRequest(METHOD_DELETE_SESSION, request);
+ if (!maybeFail(METHOD_DELETE_SESSION, responseObserver)) {
+ backend.deleteSession(request, responseObserver);
+ }
+ }
+
+ @Override
+ public void executeSql(ExecuteSqlRequest request, StreamObserver responseObserver) {
+ recordRequest(METHOD_EXECUTE_SQL, request);
+ if (!maybeFail(METHOD_EXECUTE_SQL, responseObserver)) {
+ backend.executeSql(request, responseObserver);
+ }
+ }
+
+ @Override
+ public void executeStreamingSql(
+ ExecuteSqlRequest request, StreamObserver responseObserver) {
+ recordRequest(METHOD_EXECUTE_STREAMING_SQL, request);
+ if (!maybeFail(METHOD_EXECUTE_STREAMING_SQL, responseObserver)) {
+ backend.executeStreamingSql(request, responseObserver);
+ }
+ }
+
+ @Override
+ public void getSession(GetSessionRequest request, StreamObserver responseObserver) {
+ recordRequest(METHOD_GET_SESSION, request);
+ if (!maybeFail(METHOD_GET_SESSION, responseObserver)) {
+ backend.getSession(request, responseObserver);
+ }
+ }
+
+ @Override
+ public void read(ReadRequest request, StreamObserver responseObserver) {
+ recordRequest(METHOD_READ, request);
+ if (!maybeFail(METHOD_READ, responseObserver)) {
+ backend.read(request, responseObserver);
+ }
+ }
+
+ @Override
+ public void rollback(RollbackRequest request, StreamObserver responseObserver) {
+ recordRequest(METHOD_ROLLBACK, request);
+ if (!maybeFail(METHOD_ROLLBACK, responseObserver)) {
+ backend.rollback(request, responseObserver);
+ }
+ }
+
+ @Override
+ public void streamingRead(
+ ReadRequest request, StreamObserver responseObserver) {
+ recordRequest(METHOD_STREAMING_READ, request);
+ if (!maybeFail(METHOD_STREAMING_READ, responseObserver)) {
+ backend.streamingRead(request, responseObserver);
+ }
+ }
+ }
+
+ private final List servers;
+ final MockSpannerServiceImpl backend;
+ final HookedReplicaSpannerService defaultReplica;
+ final String defaultAddress;
+ final List replicas;
+ final List replicaAddresses;
+
+ private SharedBackendReplicaHarness(
+ MockSpannerServiceImpl backend,
+ HookedReplicaSpannerService defaultReplica,
+ String defaultAddress,
+ List replicas,
+ List replicaAddresses,
+ List servers) {
+ this.backend = backend;
+ this.defaultReplica = defaultReplica;
+ this.defaultAddress = defaultAddress;
+ this.replicas = replicas;
+ this.replicaAddresses = replicaAddresses;
+ this.servers = servers;
+ }
+
+ static SharedBackendReplicaHarness create(int replicaCount) throws IOException {
+ MockSpannerServiceImpl backend = new MockSpannerServiceImpl();
+ backend.setAbortProbability(0.0D);
+ List servers = new ArrayList<>();
+ HookedReplicaSpannerService defaultReplica = new HookedReplicaSpannerService(backend);
+ List replicas = new ArrayList<>();
+ List replicaAddresses = new ArrayList<>();
+ String defaultAddress = startServer(servers, defaultReplica);
+ for (int i = 0; i < replicaCount; i++) {
+ HookedReplicaSpannerService replica = new HookedReplicaSpannerService(backend);
+ replicas.add(replica);
+ replicaAddresses.add(startServer(servers, replica));
+ }
+ return new SharedBackendReplicaHarness(
+ backend, defaultReplica, defaultAddress, replicas, replicaAddresses, servers);
+ }
+
+ private static String startServer(List servers, HookedReplicaSpannerService service)
+ throws IOException {
+ InetSocketAddress address = new InetSocketAddress("localhost", 0);
+ ServerInterceptor interceptor =
+ new ServerInterceptor() {
+ @Override
+ public ServerCall.Listener interceptCall(
+ ServerCall call, Metadata headers, ServerCallHandler next) {
+ service.recordRequestId(
+ call.getMethodDescriptor().getBareMethodName(),
+ headers.get(XGoogSpannerRequestId.REQUEST_ID_HEADER_KEY));
+ return next.startCall(call, headers);
+ }
+ };
+ Server server =
+ NettyServerBuilder.forAddress(address)
+ .addService(ServerInterceptors.intercept(service, interceptor))
+ .build()
+ .start();
+ servers.add(server);
+ return "localhost:" + server.getPort();
+ }
+
+ void clearRequests() {
+ defaultReplica.clearRequests();
+ for (HookedReplicaSpannerService replica : replicas) {
+ replica.clearRequests();
+ }
+ }
+
+ void reset() {
+ backend.reset();
+ backend.removeAllExecutionTimes();
+ backend.setAbortProbability(0.0D);
+ defaultReplica.clearRequests();
+ defaultReplica.clearMethodErrors();
+ for (HookedReplicaSpannerService replica : replicas) {
+ replica.clearRequests();
+ replica.clearMethodErrors();
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ IOException failure = null;
+ for (Server server : servers) {
+ server.shutdown();
+ }
+ for (Server server : servers) {
+ try {
+ server.awaitTermination(5L, java.util.concurrent.TimeUnit.SECONDS);
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ if (failure == null) {
+ failure = new IOException("Interrupted while stopping replica harness", e);
+ }
+ }
+ }
+ if (failure != null) {
+ throw failure;
+ }
+ }
+}
diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/ChannelFinderGoldenTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/ChannelFinderGoldenTest.java
index 7e9d2476346b..a5b5338b0e06 100644
--- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/ChannelFinderGoldenTest.java
+++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/ChannelFinderGoldenTest.java
@@ -38,6 +38,7 @@
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -149,6 +150,7 @@ public void shutdown() {
private final class FakeEndpoint implements ChannelEndpoint {
private final String address;
+ private final AtomicInteger activeRequests = new AtomicInteger();
private FakeEndpoint(String address) {
this.address = address;
@@ -209,6 +211,21 @@ public String authority() {
}
};
}
+
+ // Tracks in-flight RPCs so request-load-aware routing logic can be observed in tests.
+ @Override
+ public void incrementActiveRequests() {
+ activeRequests.incrementAndGet();
+ }
+
+ @Override
+ public void decrementActiveRequests() {
+ // Clamp at zero: an unmatched decrement must never drive the count negative.
+ activeRequests.updateAndGet(current -> current > 0 ? current - 1 : 0);
+ }
+
+ @Override
+ public int getActiveRequestCount() {
+ // Defensive clamp; the counter is already kept non-negative by decrementActiveRequests.
+ return Math.max(0, activeRequests.get());
+ }
}
}
}
diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/ChannelFinderTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/ChannelFinderTest.java
index d73c946cc2ac..ee388677f416 100644
--- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/ChannelFinderTest.java
+++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/ChannelFinderTest.java
@@ -35,6 +35,7 @@
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -223,6 +224,7 @@ public void shutdown() {
private static final class FakeEndpoint implements ChannelEndpoint {
private final String address;
private final ManagedChannel channel = new FakeManagedChannel();
+ private final AtomicInteger activeRequests = new AtomicInteger();
private FakeEndpoint(String address) {
this.address = address;
@@ -247,6 +249,21 @@ public boolean isTransientFailure() {
public ManagedChannel getChannel() {
return channel;
}
+
+ @Override
+ public void incrementActiveRequests() {
+ activeRequests.incrementAndGet();
+ }
+
+ @Override
+ public void decrementActiveRequests() {
+ activeRequests.updateAndGet(current -> current > 0 ? current - 1 : 0);
+ }
+
+ @Override
+ public int getActiveRequestCount() {
+ return Math.max(0, activeRequests.get());
+ }
}
private static final class FakeManagedChannel extends ManagedChannel {
diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/EndpointLatencyRegistryTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/EndpointLatencyRegistryTest.java
new file mode 100644
index 000000000000..9f85777183d7
--- /dev/null
+++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/EndpointLatencyRegistryTest.java
@@ -0,0 +1,165 @@
+/*
+ * Copyright 2026 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.cloud.spanner.spi.v1;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import com.google.common.base.Ticker;
+import com.google.common.testing.FakeTicker;
+import io.grpc.ManagedChannel;
+import java.time.Duration;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import org.junit.After;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for EndpointLatencyRegistry tracker expiry, scoping, and inflight request counting. */
+@RunWith(JUnit4.class)
+public class EndpointLatencyRegistryTest {
+ private static final String DATABASE_SCOPE = "projects/p/instances/i/databases/d";
+
+ // Minimal ChannelEndpoint whose only stateful behavior is the active-request counter.
+ private static final class TestEndpoint implements ChannelEndpoint {
+ private final AtomicInteger activeRequests = new AtomicInteger();
+
+ @Override
+ public String getAddress() {
+ return "test";
+ }
+
+ @Override
+ public boolean isHealthy() {
+ return true;
+ }
+
+ @Override
+ public boolean isTransientFailure() {
+ return false;
+ }
+
+ @Override
+ public ManagedChannel getChannel() {
+ // No real channel is needed; these tests never issue RPCs through the endpoint.
+ return null;
+ }
+
+ @Override
+ public void incrementActiveRequests() {
+ activeRequests.incrementAndGet();
+ }
+
+ @Override
+ public void decrementActiveRequests() {
+ // Clamp at zero so an unmatched decrement cannot make the count negative.
+ activeRequests.updateAndGet(current -> current > 0 ? current - 1 : 0);
+ }
+
+ @Override
+ public int getActiveRequestCount() {
+ return Math.max(0, activeRequests.get());
+ }
+ }
+
+ @After
+ public void tearDown() {
+ // Restore the real ticker and drop all trackers so tests do not leak static state.
+ EndpointLatencyRegistry.useTrackerTicker(Ticker.systemTicker());
+ EndpointLatencyRegistry.clear();
+ }
+
+ @Test
+ public void trackersExpireAfterAccessWindow() {
+ FakeTicker ticker = new FakeTicker();
+ EndpointLatencyRegistry.useTrackerTicker(ticker);
+
+ EndpointLatencyRegistry.recordLatency(
+ DATABASE_SCOPE, 101L, false, "server-a:1234", Duration.ofMillis(5));
+
+ assertThat(EndpointLatencyRegistry.hasScore(DATABASE_SCOPE, 101L, false, "server-a:1234"))
+ .isTrue();
+
+ // Advance just past the expiry window: the tracker must be gone on the next lookup.
+ ticker.advance(
+ EndpointLatencyRegistry.TRACKER_EXPIRE_AFTER_ACCESS.toNanos() + 1L, TimeUnit.NANOSECONDS);
+
+ assertThat(EndpointLatencyRegistry.hasScore(DATABASE_SCOPE, 101L, false, "server-a:1234"))
+ .isFalse();
+ }
+
+ @Test
+ public void accessKeepsTrackerAliveWithinExpiryWindow() {
+ FakeTicker ticker = new FakeTicker();
+ EndpointLatencyRegistry.useTrackerTicker(ticker);
+
+ EndpointLatencyRegistry.recordLatency(
+ DATABASE_SCOPE, 202L, false, "server-b:1234", Duration.ofMillis(7));
+
+ // Touch the tracker halfway through the window; this read should reset expire-after-access.
+ ticker.advance(
+ EndpointLatencyRegistry.TRACKER_EXPIRE_AFTER_ACCESS.toNanos() / 2L, TimeUnit.NANOSECONDS);
+ assertThat(
+ EndpointLatencyRegistry.getSelectionCost(DATABASE_SCOPE, 202L, false, "server-b:1234"))
+ .isGreaterThan(0.0);
+
+ // Another half window after the access: still within the (reset) window, so still present.
+ ticker.advance(
+ EndpointLatencyRegistry.TRACKER_EXPIRE_AFTER_ACCESS.toNanos() / 2L, TimeUnit.NANOSECONDS);
+
+ assertThat(EndpointLatencyRegistry.hasScore(DATABASE_SCOPE, 202L, false, "server-b:1234"))
+ .isTrue();
+ }
+
+ @Test
+ public void trackersAreIsolatedByDatabaseScope() {
+ EndpointLatencyRegistry.recordLatency(
+ "projects/p1/instances/i1/databases/d1",
+ 303L,
+ false,
+ "server-a:1234",
+ Duration.ofMillis(9));
+
+ // Same channel hint and address, but a different database must not see the score.
+ assertThat(
+ EndpointLatencyRegistry.hasScore(
+ "projects/p1/instances/i1/databases/d1", 303L, false, "server-a:1234"))
+ .isTrue();
+ assertThat(
+ EndpointLatencyRegistry.hasScore(
+ "projects/p2/instances/i2/databases/d2", 303L, false, "server-a:1234"))
+ .isFalse();
+ }
+
+ @Test
+ public void trackersAreIsolatedByPreferLeader() {
+ EndpointLatencyRegistry.recordLatency(
+ DATABASE_SCOPE, 404L, true, "server-a:1234", Duration.ofMillis(9));
+
+ // A score recorded with preferLeader=true must not leak into preferLeader=false lookups.
+ assertThat(EndpointLatencyRegistry.hasScore(DATABASE_SCOPE, 404L, true, "server-a:1234"))
+ .isTrue();
+ assertThat(EndpointLatencyRegistry.hasScore(DATABASE_SCOPE, 404L, false, "server-a:1234"))
+ .isFalse();
+ }
+
+ @Test
+ public void inflightCountDoesNotGoNegativeAndCanBeReusedAfterZero() {
+ TestEndpoint endpoint = new TestEndpoint();
+ endpoint.incrementActiveRequests();
+ assertThat(endpoint.getActiveRequestCount()).isEqualTo(1);
+
+ endpoint.decrementActiveRequests();
+ assertThat(endpoint.getActiveRequestCount()).isEqualTo(0);
+
+ // Extra decrement must clamp at zero rather than going negative.
+ endpoint.decrementActiveRequests();
+ assertThat(endpoint.getActiveRequestCount()).isEqualTo(0);
+
+ // The counter remains usable after hitting zero.
+ endpoint.incrementActiveRequests();
+ assertThat(endpoint.getActiveRequestCount()).isEqualTo(1);
+ }
+}
diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/EndpointLifecycleManagerTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/EndpointLifecycleManagerTest.java
index 552cfd9bd2c8..341974f17fd6 100644
--- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/EndpointLifecycleManagerTest.java
+++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/EndpointLifecycleManagerTest.java
@@ -267,6 +267,73 @@ public void transientFailureEvictionTrackedUntilEndpointReadyAgain() throws Exce
() -> !manager.wasRecentlyEvictedTransientFailure("server1"));
}
+ // Verifies that CONNECTING states interleaved between TRANSIENT_FAILURE probes neither reset
+ // nor advance the consecutive-failure counter, and that the third failure still evicts.
+ @Test
+ public void transientFailureOscillationWithConnectingStillEvictsEndpoint() throws Exception {
+ KeyRangeCacheTest.FakeEndpointCache cache = new KeyRangeCacheTest.FakeEndpointCache();
+ manager =
+ new EndpointLifecycleManager(
+ cache, /* probeIntervalSeconds= */ 60, Duration.ofMinutes(30), Clock.systemUTC());
+
+ registerAddresses(manager, "server1");
+ awaitCondition(
+ "endpoint should be created in background", () -> cache.getIfPresent("server1") != null);
+
+ cache.setState("server1", KeyRangeCacheTest.EndpointHealthState.TRANSIENT_FAILURE);
+ manager.probe("server1");
+ assertEquals(1, manager.getEndpointState("server1").consecutiveTransientFailures);
+
+ // CONNECTING is neutral: it must not reset the counter (only READY does).
+ cache.setState("server1", KeyRangeCacheTest.EndpointHealthState.CONNECTING);
+ manager.probe("server1");
+ assertEquals(1, manager.getEndpointState("server1").consecutiveTransientFailures);
+
+ cache.setState("server1", KeyRangeCacheTest.EndpointHealthState.TRANSIENT_FAILURE);
+ manager.probe("server1");
+ assertEquals(2, manager.getEndpointState("server1").consecutiveTransientFailures);
+
+ cache.setState("server1", KeyRangeCacheTest.EndpointHealthState.CONNECTING);
+ manager.probe("server1");
+ assertEquals(2, manager.getEndpointState("server1").consecutiveTransientFailures);
+
+ // Third consecutive transient failure crosses the eviction threshold.
+ cache.setState("server1", KeyRangeCacheTest.EndpointHealthState.TRANSIENT_FAILURE);
+ manager.probe("server1");
+
+ assertFalse(manager.isManaged("server1"));
+ assertTrue(manager.wasRecentlyEvictedTransientFailure("server1"));
+ assertNull(cache.getIfPresent("server1"));
+ }
+
+ // Verifies that a READY probe zeroes the consecutive-failure counter, so a later failure
+ // starts counting from one again and the endpoint stays managed.
+ @Test
+ public void readyResetsTransientFailureCounterAfterRecovery() throws Exception {
+ KeyRangeCacheTest.FakeEndpointCache cache = new KeyRangeCacheTest.FakeEndpointCache();
+ manager =
+ new EndpointLifecycleManager(
+ cache, /* probeIntervalSeconds= */ 60, Duration.ofMinutes(30), Clock.systemUTC());
+
+ registerAddresses(manager, "server1");
+ awaitCondition(
+ "endpoint should be created in background", () -> cache.getIfPresent("server1") != null);
+
+ // Accumulate two consecutive failures (CONNECTING in between is neutral).
+ cache.setState("server1", KeyRangeCacheTest.EndpointHealthState.TRANSIENT_FAILURE);
+ manager.probe("server1");
+ cache.setState("server1", KeyRangeCacheTest.EndpointHealthState.CONNECTING);
+ manager.probe("server1");
+ cache.setState("server1", KeyRangeCacheTest.EndpointHealthState.TRANSIENT_FAILURE);
+ manager.probe("server1");
+ assertEquals(2, manager.getEndpointState("server1").consecutiveTransientFailures);
+
+ // Recovery: READY must reset the counter and stamp lastReadyAt.
+ cache.setState("server1", KeyRangeCacheTest.EndpointHealthState.READY);
+ manager.probe("server1");
+ EndpointLifecycleManager.EndpointState state = manager.getEndpointState("server1");
+ assertNotNull(state);
+ assertEquals(0, state.consecutiveTransientFailures);
+ assertNotNull(state.lastReadyAt);
+
+ // A failure after recovery counts from one, not three, so no eviction occurs.
+ cache.setState("server1", KeyRangeCacheTest.EndpointHealthState.TRANSIENT_FAILURE);
+ manager.probe("server1");
+ assertEquals(1, manager.getEndpointState("server1").consecutiveTransientFailures);
+ assertTrue(manager.isManaged("server1"));
+ }
+
@Test
public void transientFailureEvictionMarkerRemovedWhenAddressNoLongerActive() throws Exception {
KeyRangeCacheTest.FakeEndpointCache cache = new KeyRangeCacheTest.FakeEndpointCache();
diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/EwmaLatencyTrackerTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/EwmaLatencyTrackerTest.java
index 306628b9bdab..84fa32b4fb7e 100644
--- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/EwmaLatencyTrackerTest.java
+++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/EwmaLatencyTrackerTest.java
@@ -20,6 +20,7 @@
import static org.junit.Assert.assertThrows;
import java.time.Duration;
+import java.util.function.LongSupplier;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -29,7 +30,7 @@ public class EwmaLatencyTrackerTest {
@Test
public void testInitialization() {
- EwmaLatencyTracker tracker = new EwmaLatencyTracker();
+ EwmaLatencyTracker tracker = new EwmaLatencyTracker(Duration.ofSeconds(10), new FakeClock());
tracker.update(Duration.ofNanos(100 * 1000));
assertEquals(100.0, tracker.getScore(), 0.001);
}
@@ -42,7 +43,7 @@ public void testUninitializedScore() {
@Test
public void testOverflowScore() {
- EwmaLatencyTracker tracker = new EwmaLatencyTracker();
+ EwmaLatencyTracker tracker = new EwmaLatencyTracker(Duration.ofSeconds(10), new FakeClock());
tracker.update(Duration.ofSeconds(Long.MAX_VALUE));
assertEquals((double) Long.MAX_VALUE, tracker.getScore(), 0.001);
}
@@ -50,7 +51,7 @@ public void testOverflowScore() {
@Test
public void testEwmaCalculation() {
double alpha = 0.5;
- EwmaLatencyTracker tracker = new EwmaLatencyTracker(alpha);
+ EwmaLatencyTracker tracker = new EwmaLatencyTracker(alpha, new FakeClock());
tracker.update(Duration.ofNanos(100 * 1000)); // Initial score = 100
assertEquals(100.0, tracker.getScore(), 0.001);
@@ -63,19 +64,21 @@ public void testEwmaCalculation() {
}
@Test
- public void testDefaultAlpha() {
- EwmaLatencyTracker tracker = new EwmaLatencyTracker();
+ public void testDefaultDecayUsesTimeBasedAlpha() {
+ FakeClock clock = new FakeClock();
+ EwmaLatencyTracker tracker = new EwmaLatencyTracker(Duration.ofSeconds(10), clock);
tracker.update(Duration.ofNanos(100 * 1000));
+ clock.advance(Duration.ofSeconds(10));
tracker.update(Duration.ofNanos(200 * 1000));
- double expected =
- EwmaLatencyTracker.DEFAULT_ALPHA * 200 + (1 - EwmaLatencyTracker.DEFAULT_ALPHA) * 100;
+ double alpha = 1.0 - Math.exp(-1.0);
+ double expected = alpha * 200 + (1.0 - alpha) * 100;
assertEquals(expected, tracker.getScore(), 0.001);
}
@Test
public void testRecordError() {
- EwmaLatencyTracker tracker = new EwmaLatencyTracker(0.5);
+ EwmaLatencyTracker tracker = new EwmaLatencyTracker(0.5, new FakeClock());
tracker.update(Duration.ofNanos(100 * 1000));
tracker.recordError(Duration.ofNanos(10000 * 1000)); // Score = 0.5 * 10000 + 0.5 * 100 = 5050
@@ -91,11 +94,24 @@ public void testInvalidAlpha() {
@Test
public void testAlphaOne() {
- EwmaLatencyTracker tracker = new EwmaLatencyTracker(1.0);
+ EwmaLatencyTracker tracker = new EwmaLatencyTracker(1.0, new FakeClock());
tracker.update(Duration.ofNanos(100 * 1000));
assertEquals(100.0, tracker.getScore(), 0.001);
tracker.update(Duration.ofNanos(200 * 1000));
assertEquals(200.0, tracker.getScore(), 0.001);
}
+
+ // Deterministic nanosecond clock (as a LongSupplier) for driving time-based EWMA decay.
+ private static final class FakeClock implements LongSupplier {
+ // Current time in nanoseconds; starts at zero and only moves forward via advance().
+ private long currentNanos;
+
+ @Override
+ public long getAsLong() {
+ return currentNanos;
+ }
+
+ // Moves the clock forward by the given duration.
+ void advance(Duration duration) {
+ currentNanos += duration.toNanos();
+ }
+ }
}
diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpcTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpcTest.java
index 165557608ac3..5cd1162cbaf4 100644
--- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpcTest.java
+++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpcTest.java
@@ -34,6 +34,7 @@
import com.google.api.gax.rpc.ApiCallContext;
import com.google.api.gax.rpc.ApiClientHeaderProvider;
import com.google.api.gax.rpc.HeaderProvider;
+import com.google.api.gax.rpc.StatusCode.Code;
import com.google.api.gax.rpc.TransportChannelProvider;
import com.google.auth.Credentials;
import com.google.auth.oauth2.AccessToken;
@@ -1059,6 +1060,100 @@ public boolean isEnableLocationApi() {
}
}
+  @Test
+  public void testReadRetryableCodesIncludeResourceExhaustedWhenLocationApiEnabled() {
+    assertRetryableCodesIncludeResourceExhausted(/* read= */ true);
+  }
+
+  @Test
+  public void testExecuteQueryRetryableCodesIncludeResourceExhaustedWhenLocationApiEnabled() {
+    assertRetryableCodesIncludeResourceExhausted(/* read= */ false);
+  }
+
+  @Test
+  public void testReadRetryableCodesDoNotAddResourceExhaustedWhenLocationApiDisabled() {
+    assertRetryableCodesMatchStubDefaults(/* read= */ true);
+  }
+
+  @Test
+  public void testExecuteQueryRetryableCodesDoNotAddResourceExhaustedWhenLocationApiDisabled() {
+    assertRetryableCodesMatchStubDefaults(/* read= */ false);
+  }
+
+  /** Installs a SpannerEnvironment with the given location-API flag and creates an rpc stub. */
+  private static GapicSpannerRpc createRpcWithLocationApi(boolean enabled) {
+    SpannerOptions.useEnvironment(
+        new SpannerOptions.SpannerEnvironment() {
+          @Override
+          public boolean isEnableLocationApi() {
+            return enabled;
+          }
+        });
+    return new GapicSpannerRpc(createSpannerOptions(), true);
+  }
+
+  // With the location API enabled, both streaming read and streaming query retryable codes
+  // must include RESOURCE_EXHAUSTED. `read` selects which of the two code sets to check.
+  private void assertRetryableCodesIncludeResourceExhausted(boolean read) {
+    try {
+      GapicSpannerRpc rpc = createRpcWithLocationApi(true);
+      try {
+        assertThat(read ? rpc.getReadRetryableCodes() : rpc.getExecuteQueryRetryableCodes())
+            .contains(Code.RESOURCE_EXHAUSTED);
+      } finally {
+        rpc.shutdown();
+      }
+    } finally {
+      // Always restore the default environment so other tests are unaffected.
+      SpannerOptions.useDefaultEnvironment();
+    }
+  }
+
+  // With the location API disabled, the retryable codes must equal the stub settings'
+  // defaults unchanged (i.e. RESOURCE_EXHAUSTED is not silently added).
+  private void assertRetryableCodesMatchStubDefaults(boolean read) {
+    try {
+      GapicSpannerRpc rpc = createRpcWithLocationApi(false);
+      try {
+        assertThat(read ? rpc.getReadRetryableCodes() : rpc.getExecuteQueryRetryableCodes())
+            .isEqualTo(
+                read
+                    ? createSpannerOptions()
+                        .getSpannerStubSettings()
+                        .streamingReadSettings()
+                        .getRetryableCodes()
+                    : createSpannerOptions()
+                        .getSpannerStubSettings()
+                        .executeStreamingSqlSettings()
+                        .getRetryableCodes());
+      } finally {
+        rpc.shutdown();
+      }
+    } finally {
+      SpannerOptions.useDefaultEnvironment();
+    }
+  }
+
@Test
public void testGrpcGcpExtensionPreservesChannelConfigurator() throws Exception {
InstantiatingGrpcChannelProvider.Builder channelProviderBuilder =
diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GrpcChannelEndpointCacheTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GrpcChannelEndpointCacheTest.java
index 74afec18bfc3..cca418a53f99 100644
--- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GrpcChannelEndpointCacheTest.java
+++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GrpcChannelEndpointCacheTest.java
@@ -25,6 +25,9 @@
import org.junit.Test;
public class GrpcChannelEndpointCacheTest {
+ private static final String DEFAULT_ENDPOINT = "default.invalid:1234";
+ private static final String ROUTED_ENDPOINT_A = "replica-a.invalid:1111";
+ private static final String ROUTED_ENDPOINT_B = "replica-b.invalid:2222";
private static InstantiatingGrpcChannelProvider createProvider(String endpoint) {
return InstantiatingGrpcChannelProvider.newBuilder()
@@ -35,7 +38,7 @@ private static InstantiatingGrpcChannelProvider createProvider(String endpoint)
@Test
public void defaultChannelIsCached() throws Exception {
- GrpcChannelEndpointCache cache = new GrpcChannelEndpointCache(createProvider("localhost:1234"));
+ GrpcChannelEndpointCache cache = new GrpcChannelEndpointCache(createProvider(DEFAULT_ENDPOINT));
try {
ChannelEndpoint defaultChannel = cache.defaultChannel();
ChannelEndpoint server = cache.get(defaultChannel.getAddress());
@@ -47,11 +50,11 @@ public void defaultChannelIsCached() throws Exception {
@Test
public void getCachesPerAddress() throws Exception {
- GrpcChannelEndpointCache cache = new GrpcChannelEndpointCache(createProvider("localhost:1234"));
+ GrpcChannelEndpointCache cache = new GrpcChannelEndpointCache(createProvider(DEFAULT_ENDPOINT));
try {
- ChannelEndpoint first = cache.get("localhost:1111");
- ChannelEndpoint second = cache.get("localhost:1111");
- ChannelEndpoint third = cache.get("localhost:2222");
+ ChannelEndpoint first = cache.get(ROUTED_ENDPOINT_A);
+ ChannelEndpoint second = cache.get(ROUTED_ENDPOINT_A);
+ ChannelEndpoint third = cache.get(ROUTED_ENDPOINT_B);
assertThat(second).isSameInstanceAs(first);
assertThat(third).isNotSameInstanceAs(first);
@@ -62,11 +65,63 @@ public void getCachesPerAddress() throws Exception {
@Test
public void routedChannelsReuseDefaultAuthority() throws Exception {
- GrpcChannelEndpointCache cache = new GrpcChannelEndpointCache(createProvider("localhost:1234"));
+ GrpcChannelEndpointCache cache = new GrpcChannelEndpointCache(createProvider(DEFAULT_ENDPOINT));
try {
- ChannelEndpoint routed = cache.get("localhost:1111");
+ ChannelEndpoint routed = cache.get(ROUTED_ENDPOINT_A);
- assertThat(routed.getChannel().authority()).isEqualTo("localhost:1234");
+ assertThat(routed.getChannel().authority()).isEqualTo(DEFAULT_ENDPOINT);
+ } finally {
+ cache.shutdown();
+ }
+ }
+
+ @Test
+ public void routedChannelsUseSingleUnderlyingChannel() throws Exception {
+ InstantiatingGrpcChannelProvider provider =
+ InstantiatingGrpcChannelProvider.newBuilder()
+ .setEndpoint(DEFAULT_ENDPOINT)
+ .setPoolSize(4)
+ .setChannelConfigurator(ManagedChannelBuilder::usePlaintext)
+ .build();
+ GrpcChannelEndpointCache cache = new GrpcChannelEndpointCache(provider);
+ try {
+ InstantiatingGrpcChannelProvider routedProvider =
+ cache.createProviderWithAuthorityOverride(ROUTED_ENDPOINT_A);
+
+ assertThat(provider.toBuilder().getPoolSize()).isEqualTo(4);
+ assertThat(routedProvider.getChannelPoolSettings().getInitialChannelCount()).isEqualTo(1);
+ assertThat(routedProvider.getChannelPoolSettings().getMinChannelCount()).isEqualTo(1);
+ assertThat(routedProvider.getChannelPoolSettings().getMaxChannelCount()).isEqualTo(1);
+ } finally {
+ cache.shutdown();
+ }
+ }
+
+ @Test
+ public void routedChannelsOverrideKeepAliveSettingsOnlyForEndpointProvider() throws Exception {
+ InstantiatingGrpcChannelProvider provider =
+ InstantiatingGrpcChannelProvider.newBuilder()
+ .setEndpoint(DEFAULT_ENDPOINT)
+ .setPoolSize(4)
+ .setKeepAliveTimeDuration(java.time.Duration.ofSeconds(120))
+ .setKeepAliveTimeoutDuration(java.time.Duration.ofSeconds(60))
+ .setKeepAliveWithoutCalls(Boolean.FALSE)
+ .setChannelConfigurator(ManagedChannelBuilder::usePlaintext)
+ .build();
+ GrpcChannelEndpointCache cache = new GrpcChannelEndpointCache(provider);
+ try {
+ InstantiatingGrpcChannelProvider routedProvider =
+ cache.createProviderWithAuthorityOverride(ROUTED_ENDPOINT_A);
+
+ assertThat(provider.getKeepAliveWithoutCalls()).isFalse();
+ assertThat(provider.getKeepAliveTimeDuration()).isEqualTo(java.time.Duration.ofSeconds(120));
+ assertThat(provider.getKeepAliveTimeoutDuration())
+ .isEqualTo(java.time.Duration.ofSeconds(60));
+ assertThat(routedProvider.getKeepAliveWithoutCalls()).isTrue();
+ assertThat(routedProvider.getKeepAliveTimeDuration())
+ .isEqualTo(GrpcChannelEndpointCache.ROUTED_KEEPALIVE_TIME);
+ assertThat(routedProvider.getKeepAliveTimeoutDuration())
+ .isEqualTo(GrpcChannelEndpointCache.ROUTED_KEEPALIVE_TIMEOUT);
} finally {
cache.shutdown();
}
@@ -74,11 +129,11 @@ public void routedChannelsReuseDefaultAuthority() throws Exception {
@Test
public void evictRemovesNonDefaultServer() throws Exception {
- GrpcChannelEndpointCache cache = new GrpcChannelEndpointCache(createProvider("localhost:1234"));
+ GrpcChannelEndpointCache cache = new GrpcChannelEndpointCache(createProvider(DEFAULT_ENDPOINT));
try {
- ChannelEndpoint first = cache.get("localhost:1111");
- cache.evict("localhost:1111");
- ChannelEndpoint second = cache.get("localhost:1111");
+ ChannelEndpoint first = cache.get(ROUTED_ENDPOINT_A);
+ cache.evict(ROUTED_ENDPOINT_A);
+ ChannelEndpoint second = cache.get(ROUTED_ENDPOINT_A);
assertThat(second).isNotSameInstanceAs(first);
} finally {
@@ -88,7 +143,7 @@ public void evictRemovesNonDefaultServer() throws Exception {
@Test
public void evictIgnoresDefaultChannel() throws Exception {
- GrpcChannelEndpointCache cache = new GrpcChannelEndpointCache(createProvider("localhost:1234"));
+ GrpcChannelEndpointCache cache = new GrpcChannelEndpointCache(createProvider(DEFAULT_ENDPOINT));
try {
ChannelEndpoint defaultChannel = cache.defaultChannel();
cache.evict(defaultChannel.getAddress());
@@ -102,18 +157,18 @@ public void evictIgnoresDefaultChannel() throws Exception {
@Test
public void shutdownPreventsNewServers() throws Exception {
- GrpcChannelEndpointCache cache = new GrpcChannelEndpointCache(createProvider("localhost:1234"));
+ GrpcChannelEndpointCache cache = new GrpcChannelEndpointCache(createProvider(DEFAULT_ENDPOINT));
cache.shutdown();
- assertThrows(SpannerException.class, () -> cache.get("localhost:1111"));
+ assertThrows(SpannerException.class, () -> cache.get(ROUTED_ENDPOINT_A));
assertThat(cache.defaultChannel().getChannel().isShutdown()).isTrue();
}
@Test
public void healthReflectsChannelShutdown() throws Exception {
- GrpcChannelEndpointCache cache = new GrpcChannelEndpointCache(createProvider("localhost:1234"));
+ GrpcChannelEndpointCache cache = new GrpcChannelEndpointCache(createProvider(DEFAULT_ENDPOINT));
try {
- ChannelEndpoint server = cache.get("localhost:1111");
+ ChannelEndpoint server = cache.get(ROUTED_ENDPOINT_A);
// Newly created channel is not READY (likely IDLE), so isHealthy is false for location aware.
// isHealthy now requires READY state for location aware routing.
assertThat(server.isHealthy()).isFalse();
diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java
index 1ad3888b4f9d..7a9201322c74 100644
--- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java
+++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java
@@ -24,6 +24,8 @@
import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider;
import com.google.cloud.spanner.XGoogSpannerRequestId;
+import com.google.common.base.Ticker;
+import com.google.common.testing.FakeTicker;
import com.google.protobuf.ByteString;
import com.google.protobuf.Empty;
import com.google.protobuf.ListValue;
@@ -56,12 +58,18 @@
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import java.io.IOException;
+import java.time.Clock;
+import java.time.Duration;
+import java.time.Instant;
+import java.time.ZoneOffset;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
import javax.annotation.Nullable;
+import org.junit.After;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -72,6 +80,12 @@ public class KeyAwareChannelTest {
private static final String SESSION =
"projects/p/instances/i/databases/d/sessions/test-session-id";
+ @After
+ public void clearSharedRoutingState() {
+ EndpointLatencyRegistry.clear();
+ RequestIdTargetTracker.clear();
+ }
+
@Test
public void cancelBeforeStartPreservesTrailersAndSkipsDelegateCreation() throws Exception {
TestHarness harness = createHarness();
@@ -458,9 +472,10 @@ public void singleUseCommitUsesSameMutationSelectionHeuristicAsBeginTransaction(
@Test
public void resourceExhaustedRoutedEndpointIsAvoidedOnRetry() throws Exception {
- TestHarness harness = createHarness();
+ TestHarness harness = createHarness(createDeterministicCooldownTracker());
seedCache(harness, createLeaderAndReplicaCacheUpdate());
- CallOptions retryCallOptions = retryCallOptions(1L);
+ XGoogSpannerRequestId requestId = retryRequestId(1L);
+ CallOptions retryCallOptions = retryCallOptions(requestId);
ExecuteSqlRequest request =
ExecuteSqlRequest.newBuilder()
@@ -481,6 +496,8 @@ public void resourceExhaustedRoutedEndpointIsAvoidedOnRetry() throws Exception {
harness.endpointCache.latestCallForAddress("server-a:1234");
firstDelegate.emitOnClose(Status.RESOURCE_EXHAUSTED, new Metadata());
+ assertThat(harness.channel.isCoolingDown("server-a:1234")).isTrue();
+
ClientCall secondCall =
harness.channel.newCall(SpannerGrpc.getExecuteSqlMethod(), retryCallOptions);
secondCall.start(new CapturingListener(), new Metadata());
@@ -488,6 +505,14 @@ public void resourceExhaustedRoutedEndpointIsAvoidedOnRetry() throws Exception {
assertThat(harness.endpointCache.callCountForAddress("server-a:1234")).isEqualTo(1);
assertThat(harness.endpointCache.callCountForAddress("server-b:1234")).isEqualTo(1);
+ assertThat(harness.channel.isCoolingDown("server-a:1234")).isTrue();
+ }
+
+ @Test
+ public void resourceExhaustedOrUnavailableRoutedEndpointRecordsErrorPenalty() throws Exception {
+ assertRoutedEndpointErrorPenaltyRecorded(Status.RESOURCE_EXHAUSTED, 101L);
+ EndpointLatencyRegistry.clear();
+ assertRoutedEndpointErrorPenaltyRecorded(Status.UNAVAILABLE, 102L);
}
@Test
@@ -555,7 +580,7 @@ public void resourceExhaustedAffinityEndpointIsAvoidedForSubsequentTransactionRe
}
@Test
- public void resourceExhaustedRoutedEndpointFallsBackToDefaultWhenNoReplicaExists()
+ public void resourceExhaustedRoutedEndpointRetriesSameReplicaWhenSingleReplicaIsExcluded()
throws Exception {
TestHarness harness = createHarness();
CallOptions retryCallOptions = retryCallOptions(3L);
@@ -585,16 +610,68 @@ public void resourceExhaustedRoutedEndpointFallsBackToDefaultWhenNoReplicaExists
secondCall.start(new CapturingListener(), new Metadata());
secondCall.sendMessage(request);
- assertThat(harness.endpointCache.callCountForAddress("server-a:1234")).isEqualTo(1);
- assertThat(harness.defaultManagedChannel.callCount()).isEqualTo(2);
+ assertThat(harness.endpointCache.callCountForAddress("server-a:1234")).isEqualTo(2);
+ assertThat(harness.defaultManagedChannel.callCount()).isEqualTo(1);
}
@Test
- public void resourceExhaustedSkipDoesNotAffectDifferentLogicalRequest() throws Exception {
+ public void
+ resourceExhaustedRoutedEndpointRetriesLowestCostExcludedReplicaWhenAllReplicasExcluded()
+ throws Exception {
TestHarness harness = createHarness();
seedCache(harness, createLeaderAndReplicaCacheUpdate());
- CallOptions firstLogicalRequest = retryCallOptions(4L);
- CallOptions secondLogicalRequest = retryCallOptions(5L);
+ CallOptions retryCallOptions = retryCallOptions(100L);
+ ExecuteSqlRequest request =
+ ExecuteSqlRequest.newBuilder()
+ .setSession(SESSION)
+ .setRoutingHint(RoutingHint.newBuilder().setKey(bytes("a")).build())
+ .build();
+
+ ClientCall firstCall =
+ harness.channel.newCall(SpannerGrpc.getExecuteSqlMethod(), retryCallOptions);
+ firstCall.start(new CapturingListener(), new Metadata());
+ firstCall.sendMessage(request);
+
+ assertThat(harness.endpointCache.callCountForAddress("server-a:1234")).isEqualTo(1);
+ assertThat(harness.endpointCache.callCountForAddress("server-b:1234")).isEqualTo(0);
+
+ @SuppressWarnings("unchecked")
+ RecordingClientCall firstDelegate =
+ (RecordingClientCall)
+ harness.endpointCache.latestCallForAddress("server-a:1234");
+ firstDelegate.emitOnClose(Status.RESOURCE_EXHAUSTED, new Metadata());
+
+ ClientCall secondCall =
+ harness.channel.newCall(SpannerGrpc.getExecuteSqlMethod(), retryCallOptions);
+ secondCall.start(new CapturingListener(), new Metadata());
+ secondCall.sendMessage(request);
+
+ assertThat(harness.endpointCache.callCountForAddress("server-a:1234")).isEqualTo(1);
+ assertThat(harness.endpointCache.callCountForAddress("server-b:1234")).isEqualTo(1);
+
+ @SuppressWarnings("unchecked")
+ RecordingClientCall secondDelegate =
+ (RecordingClientCall)
+ harness.endpointCache.latestCallForAddress("server-b:1234");
+ secondDelegate.emitOnClose(Status.RESOURCE_EXHAUSTED, new Metadata());
+
+ ClientCall thirdCall =
+ harness.channel.newCall(SpannerGrpc.getExecuteSqlMethod(), retryCallOptions);
+ thirdCall.start(new CapturingListener(), new Metadata());
+ thirdCall.sendMessage(request);
+
+ assertThat(harness.endpointCache.callCountForAddress("server-a:1234")).isEqualTo(2);
+ assertThat(harness.endpointCache.callCountForAddress("server-b:1234")).isEqualTo(1);
+ }
+
+ @Test
+ public void resourceExhaustedCooldownAffectsDifferentLogicalRequests() throws Exception {
+ TestHarness harness = createHarness(createDeterministicCooldownTracker());
+ seedCache(harness, createLeaderAndReplicaCacheUpdate());
+ XGoogSpannerRequestId firstRequestId = retryRequestId(4L);
+ XGoogSpannerRequestId secondRequestId = retryRequestId(5L);
+ CallOptions firstLogicalRequest = retryCallOptions(firstRequestId);
+ CallOptions secondLogicalRequest = retryCallOptions(secondRequestId);
ExecuteSqlRequest request =
ExecuteSqlRequest.newBuilder()
@@ -613,21 +690,23 @@ public void resourceExhaustedSkipDoesNotAffectDifferentLogicalRequest() throws E
harness.endpointCache.latestCallForAddress("server-a:1234");
firstDelegate.emitOnClose(Status.RESOURCE_EXHAUSTED, new Metadata());
+ assertThat(harness.channel.isCoolingDown("server-a:1234")).isTrue();
+
ClientCall unrelatedCall =
harness.channel.newCall(SpannerGrpc.getExecuteSqlMethod(), secondLogicalRequest);
unrelatedCall.start(new CapturingListener(), new Metadata());
unrelatedCall.sendMessage(request);
- assertThat(harness.endpointCache.callCountForAddress("server-a:1234")).isEqualTo(2);
- assertThat(harness.endpointCache.callCountForAddress("server-b:1234")).isEqualTo(0);
+ assertThat(harness.endpointCache.callCountForAddress("server-a:1234")).isEqualTo(1);
+ assertThat(harness.endpointCache.callCountForAddress("server-b:1234")).isEqualTo(1);
ClientCall retriedFirstCall =
harness.channel.newCall(SpannerGrpc.getExecuteSqlMethod(), firstLogicalRequest);
retriedFirstCall.start(new CapturingListener(), new Metadata());
retriedFirstCall.sendMessage(request);
- assertThat(harness.endpointCache.callCountForAddress("server-a:1234")).isEqualTo(2);
- assertThat(harness.endpointCache.callCountForAddress("server-b:1234")).isEqualTo(1);
+ assertThat(harness.endpointCache.callCountForAddress("server-a:1234")).isEqualTo(1);
+ assertThat(harness.endpointCache.callCountForAddress("server-b:1234")).isEqualTo(2);
}
@Test
@@ -1055,6 +1134,58 @@ public void readOnlyTransactionCleanupOnClose() throws Exception {
harness.channel.clearTransactionAffinity(transactionId);
}
+ @Test
+ public void abandonedReadWriteTransactionAffinityExpiresAfterInactivity() throws Exception {
+ FakeTicker ticker = new FakeTicker();
+ TestHarness harness = createHarness(ticker);
+ ByteString transactionId = ByteString.copyFromUtf8("rw-tx-expired-affinity");
+ seedCache(harness, createTwoRangeCacheUpdate());
+
+ ClientCall beginCall =
+ harness.channel.newCall(SpannerGrpc.getExecuteSqlMethod(), CallOptions.DEFAULT);
+ beginCall.start(new CapturingListener(), new Metadata());
+ beginCall.sendMessage(
+ ExecuteSqlRequest.newBuilder()
+ .setSession(SESSION)
+ .setTransaction(
+ TransactionSelector.newBuilder()
+ .setBegin(
+ TransactionOptions.newBuilder()
+ .setReadWrite(TransactionOptions.ReadWrite.getDefaultInstance())
+ .build()))
+ .setRoutingHint(RoutingHint.newBuilder().setKey(bytes("b")).build())
+ .build());
+
+ assertThat(harness.endpointCache.callCountForAddress("server-a:1234")).isEqualTo(1);
+
+ @SuppressWarnings("unchecked")
+ RecordingClientCall beginDelegate =
+ (RecordingClientCall)
+ harness.endpointCache.latestCallForAddress("server-a:1234");
+ beginDelegate.emitOnMessage(
+ ResultSet.newBuilder()
+ .setMetadata(
+ ResultSetMetadata.newBuilder()
+ .setTransaction(Transaction.newBuilder().setId(transactionId)))
+ .build());
+ beginDelegate.emitOnClose(Status.OK, new Metadata());
+
+ ticker.advance(11, TimeUnit.MINUTES);
+
+ ClientCall nextCall =
+ harness.channel.newCall(SpannerGrpc.getExecuteSqlMethod(), CallOptions.DEFAULT);
+ nextCall.start(new CapturingListener(), new Metadata());
+ nextCall.sendMessage(
+ ExecuteSqlRequest.newBuilder()
+ .setSession(SESSION)
+ .setTransaction(TransactionSelector.newBuilder().setId(transactionId))
+ .setRoutingHint(RoutingHint.newBuilder().setKey(bytes("n")).build())
+ .build());
+
+ assertThat(harness.endpointCache.callCountForAddress("server-a:1234")).isEqualTo(1);
+ assertThat(harness.endpointCache.callCountForAddress("server-b:1234")).isEqualTo(1);
+ }
+
private static CacheUpdate createTwoRangeCacheUpdate() {
return CacheUpdate.newBuilder()
.setDatabaseId(7L)
@@ -1235,13 +1366,37 @@ private static RecipeList parseRecipeList(String text) throws TextFormat.ParseEx
}
private static TestHarness createHarness() throws IOException {
+ return createHarness(new EndpointOverloadCooldownTracker(), Ticker.systemTicker());
+ }
+
+ private static TestHarness createHarness(Ticker ticker) throws IOException {
+ return createHarness(new EndpointOverloadCooldownTracker(), ticker);
+ }
+
+ private static TestHarness createHarness(EndpointOverloadCooldownTracker tracker)
+ throws IOException {
+ return createHarness(tracker, Ticker.systemTicker());
+ }
+
+ private static TestHarness createHarness(EndpointOverloadCooldownTracker tracker, Ticker ticker)
+ throws IOException {
FakeEndpointCache endpointCache = new FakeEndpointCache(DEFAULT_ADDRESS);
InstantiatingGrpcChannelProvider provider =
InstantiatingGrpcChannelProvider.newBuilder().setEndpoint("localhost:9999").build();
- KeyAwareChannel channel = KeyAwareChannel.create(provider, baseProvider -> endpointCache);
+ KeyAwareChannel channel =
+ KeyAwareChannel.create(provider, baseProvider -> endpointCache, tracker, ticker);
return new TestHarness(channel, endpointCache, endpointCache.defaultManagedChannel());
}
+ private static EndpointOverloadCooldownTracker createDeterministicCooldownTracker() {
+ return new EndpointOverloadCooldownTracker(
+ Duration.ofMinutes(1),
+ Duration.ofMinutes(1),
+ Duration.ofMinutes(10),
+ Clock.fixed(Instant.ofEpochSecond(100), ZoneOffset.UTC),
+ bound -> bound - 1L);
+ }
+
private static final class TestHarness {
private final KeyAwareChannel channel;
private final FakeEndpointCache endpointCache;
@@ -1350,6 +1505,7 @@ int callCountForAddress(String address) {
private static final class FakeEndpoint implements ChannelEndpoint {
private final String address;
private final FakeManagedChannel channel;
+ private final AtomicInteger activeRequests = new AtomicInteger();
private FakeEndpoint(String address) {
this.address = address;
@@ -1375,6 +1531,21 @@ public boolean isTransientFailure() {
public ManagedChannel getChannel() {
return channel;
}
+
+ @Override
+ public void incrementActiveRequests() {
+ activeRequests.incrementAndGet();
+ }
+
+ @Override
+ public void decrementActiveRequests() {
+ activeRequests.updateAndGet(current -> current > 0 ? current - 1 : 0);
+ }
+
+ @Override
+ public int getActiveRequestCount() {
+ return Math.max(0, activeRequests.get());
+ }
}
private static final class FakeManagedChannel extends ManagedChannel {
@@ -1483,9 +1654,59 @@ private static ByteString bytes(String value) {
return ByteString.copyFromUtf8(value);
}
+ private static XGoogSpannerRequestId retryRequestId(long nthRequest) {
+ return XGoogSpannerRequestId.of(1L, 0L, nthRequest, 0L);
+ }
+
private static CallOptions retryCallOptions(long nthRequest) {
+ return retryCallOptions(retryRequestId(nthRequest));
+ }
+
+ private static CallOptions retryCallOptions(XGoogSpannerRequestId requestId) {
return CallOptions.DEFAULT.withOption(
- XGoogSpannerRequestId.REQUEST_ID_CALL_OPTIONS_KEY,
- XGoogSpannerRequestId.of(1L, 0L, nthRequest, 0L));
+ XGoogSpannerRequestId.REQUEST_ID_CALL_OPTIONS_KEY, requestId);
+ }
+
+ private static void assertRoutedEndpointErrorPenaltyRecorded(Status status, long operationUid)
+ throws Exception {
+ EndpointLatencyRegistry.clear();
+ TestHarness harness = createHarness();
+ seedCache(harness, createLeaderAndReplicaCacheUpdate());
+
+ ExecuteSqlRequest request =
+ ExecuteSqlRequest.newBuilder()
+ .setSession(SESSION)
+ .setRoutingHint(
+ RoutingHint.newBuilder().setKey(bytes("b")).setOperationUid(operationUid).build())
+ .build();
+
+ ClientCall call =
+ harness.channel.newCall(SpannerGrpc.getExecuteSqlMethod(), retryCallOptions(operationUid));
+ call.start(new CapturingListener(), new Metadata());
+ call.sendMessage(request);
+
+ assertThat(harness.endpointCache.callCountForAddress("server-a:1234")).isEqualTo(1);
+
+ @SuppressWarnings("unchecked")
+ RecordingClientCall delegate =
+ (RecordingClientCall)
+ harness.endpointCache.latestCallForAddress("server-a:1234");
+ long routedOperationUid = delegate.lastMessage.getRoutingHint().getOperationUid();
+ assertThat(routedOperationUid).isGreaterThan(0L);
+ delegate.emitOnClose(status, new Metadata());
+
+ String databaseScope = "projects/p/instances/i/databases/d";
+ assertThat(
+ EndpointLatencyRegistry.hasScore(
+ databaseScope, routedOperationUid, true, "server-a:1234"))
+ .isTrue();
+ assertThat(
+ EndpointLatencyRegistry.getSelectionCost(
+ databaseScope, routedOperationUid, true, "server-a:1234"))
+ .isEqualTo((double) EndpointLatencyRegistry.DEFAULT_ERROR_PENALTY.toNanos() / 1_000D);
+ assertThat(
+ EndpointLatencyRegistry.hasScore(
+ databaseScope, routedOperationUid, true, "server-b:1234"))
+ .isFalse();
}
}
diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyRangeCacheGoldenTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyRangeCacheGoldenTest.java
index 7fa2874ada5b..b4b86cf8f258 100644
--- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyRangeCacheGoldenTest.java
+++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyRangeCacheGoldenTest.java
@@ -34,6 +34,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -145,6 +146,7 @@ public void shutdown() {
private static final class FakeEndpoint implements ChannelEndpoint {
private final String address;
private final ManagedChannel channel = new FakeManagedChannel();
+ private final AtomicInteger activeRequests = new AtomicInteger();
FakeEndpoint(String address) {
this.address = address;
@@ -169,6 +171,21 @@ public boolean isTransientFailure() {
public ManagedChannel getChannel() {
return channel;
}
+
+ @Override
+ public void incrementActiveRequests() {
+ activeRequests.incrementAndGet();
+ }
+
+ @Override
+ public void decrementActiveRequests() {
+ activeRequests.updateAndGet(current -> current > 0 ? current - 1 : 0);
+ }
+
+ @Override
+ public int getActiveRequestCount() {
+ return Math.max(0, activeRequests.get());
+ }
}
private static final class FakeManagedChannel extends ManagedChannel {
diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyRangeCacheTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyRangeCacheTest.java
index b19123daa704..b7645f044a13 100644
--- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyRangeCacheTest.java
+++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyRangeCacheTest.java
@@ -33,16 +33,25 @@
import io.grpc.ConnectivityState;
import io.grpc.ManagedChannel;
import io.grpc.MethodDescriptor;
+import java.time.Duration;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
import javax.annotation.Nullable;
+import org.junit.After;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@RunWith(JUnit4.class)
public class KeyRangeCacheTest {
+ private static final long TEST_OPERATION_UID = 101L;
+
+ @After
+ public void tearDown() {
+ EndpointLatencyRegistry.clear();
+ }
@Test
public void skipsTransientFailureTabletWithSkippedTablet() {
@@ -130,8 +139,100 @@ public void skipsExplicitlyExcludedTablet() {
assertNotNull(server);
assertEquals("server2", server.getAddress());
- assertEquals(1, hint.getSkippedTabletUidCount());
- assertEquals(1L, hint.getSkippedTabletUid(0).getTabletUid());
+ assertEquals(0, hint.getSkippedTabletUidCount());
+ }
+
+ @Test
+ public void lookupRoutingHintReportsCacheMiss() {
+ FakeEndpointCache endpointCache = new FakeEndpointCache();
+ KeyRangeCache cache = new KeyRangeCache(endpointCache);
+
+ RoutingHint.Builder hint = RoutingHint.newBuilder().setKey(bytes("a"));
+ KeyRangeCache.RouteLookupResult result =
+ cache.lookupRoutingHint(
+ false,
+ KeyRangeCache.RangeMode.COVERING_SPLIT,
+ DirectedReadOptions.getDefaultInstance(),
+ hint,
+ address -> false);
+
+ assertNull(result.endpoint);
+ assertEquals(KeyRangeCache.RouteFailureReason.CACHE_MISS, result.failureReason);
+ }
+
+ @Test
+ public void lookupRoutingHintReusesReplicaWhenAllCandidatesAreExcludedOrCoolingDown() {
+ FakeEndpointCache endpointCache = new FakeEndpointCache();
+ KeyRangeCache cache = new KeyRangeCache(endpointCache);
+ cache.addRanges(singleReplicaUpdate("server1"));
+ endpointCache.get("server1");
+
+ RoutingHint.Builder hint = RoutingHint.newBuilder().setKey(bytes("a"));
+ KeyRangeCache.RouteLookupResult result =
+ cache.lookupRoutingHint(
+ false,
+ KeyRangeCache.RangeMode.COVERING_SPLIT,
+ DirectedReadOptions.getDefaultInstance(),
+ hint,
+ "server1"::equals);
+
+ assertNotNull(result.endpoint);
+ assertEquals("server1", result.endpoint.getAddress());
+ assertEquals(KeyRangeCache.RouteFailureReason.NONE, result.failureReason);
+ }
+
+ @Test
+ public void lookupRoutingHintUsesLowestScoreWhenAllCandidatesAreExcludedOrCoolingDown() {
+ FakeEndpointCache endpointCache = new FakeEndpointCache();
+ KeyRangeCache cache = new KeyRangeCache(endpointCache);
+ cache.useDeterministicRandom();
+ cache.addRanges(threeReplicaUpdate());
+
+ endpointCache.get("server1");
+ endpointCache.get("server2");
+ endpointCache.get("server3");
+
+ EndpointLatencyRegistry.recordLatency(
+ null, TEST_OPERATION_UID, false, "server1", Duration.ofNanos(300_000L));
+ EndpointLatencyRegistry.recordLatency(
+ null, TEST_OPERATION_UID, false, "server2", Duration.ofNanos(100_000L));
+ EndpointLatencyRegistry.recordLatency(
+ null, TEST_OPERATION_UID, false, "server3", Duration.ofNanos(200_000L));
+
+ RoutingHint.Builder hint =
+ RoutingHint.newBuilder().setKey(bytes("a")).setOperationUid(TEST_OPERATION_UID);
+ KeyRangeCache.RouteLookupResult result =
+ cache.lookupRoutingHint(
+ false,
+ KeyRangeCache.RangeMode.COVERING_SPLIT,
+ DirectedReadOptions.getDefaultInstance(),
+ hint,
+ address -> true);
+
+ assertNotNull(result.endpoint);
+ assertEquals("server2", result.endpoint.getAddress());
+ assertEquals(KeyRangeCache.RouteFailureReason.NONE, result.failureReason);
+ }
+
+ @Test
+ public void lookupRoutingHintReportsNoReadyReplica() {
+ FakeEndpointCache endpointCache = new FakeEndpointCache();
+ KeyRangeCache cache = new KeyRangeCache(endpointCache);
+ cache.addRanges(singleReplicaUpdate("server1"));
+ endpointCache.get("server1");
+ endpointCache.setState("server1", EndpointHealthState.IDLE);
+
+ RoutingHint.Builder hint = RoutingHint.newBuilder().setKey(bytes("a"));
+ KeyRangeCache.RouteLookupResult result =
+ cache.lookupRoutingHint(
+ false,
+ KeyRangeCache.RangeMode.COVERING_SPLIT,
+ DirectedReadOptions.getDefaultInstance(),
+ hint,
+ address -> false);
+
+ assertNull(result.endpoint);
+ assertEquals(KeyRangeCache.RouteFailureReason.NO_READY_REPLICA, result.failureReason);
}
@Test
@@ -351,6 +452,29 @@ public void connectingEndpointCausesDefaultHostFallbackWithoutSkippedTablet() {
assertEquals(0, hint.getSkippedTabletUidCount());
}
+ @Test
+ public void excludedEndpointDoesNotAddSkippedTablet() {
+ FakeEndpointCache endpointCache = new FakeEndpointCache();
+ KeyRangeCache cache = new KeyRangeCache(endpointCache);
+ cache.addRanges(singleReplicaUpdate("server1"));
+
+ endpointCache.get("server1");
+ endpointCache.setState("server1", EndpointHealthState.READY);
+
+ RoutingHint.Builder hint = RoutingHint.newBuilder().setKey(bytes("a"));
+ ChannelEndpoint server =
+ cache.fillRoutingHint(
+ false,
+ KeyRangeCache.RangeMode.COVERING_SPLIT,
+ DirectedReadOptions.getDefaultInstance(),
+ hint,
+ "server1"::equals);
+
+ assertNotNull(server);
+ assertEquals("server1", server.getAddress());
+ assertEquals(0, hint.getSkippedTabletUidCount());
+ }
+
@Test
public void transientFailureEndpointCausesSkippedTabletPlusDefaultHostFallback() {
FakeEndpointCache endpointCache = new FakeEndpointCache();
@@ -481,6 +605,7 @@ public void transientFailureReplicaSkippedAndReadyReplicaSelected() {
public void laterTransientFailureReplicaReportedWhenEarlierReplicaSelected() {
FakeEndpointCache endpointCache = new FakeEndpointCache();
KeyRangeCache cache = new KeyRangeCache(endpointCache);
+ cache.useDeterministicRandom();
cache.addRanges(threeReplicaUpdate());
endpointCache.get("server1");
@@ -512,6 +637,7 @@ public void laterRecentlyEvictedTransientFailureReplicaReportedWhenEarlierReplic
new RecentTransientFailureLifecycleManager(endpointCache);
try {
KeyRangeCache cache = new KeyRangeCache(endpointCache, lifecycleManager);
+ cache.useDeterministicRandom();
cache.addRanges(threeReplicaUpdate());
endpointCache.get("server1");
@@ -535,6 +661,302 @@ public void laterRecentlyEvictedTransientFailureReplicaReportedWhenEarlierReplic
}
}
+ @Test
+ public void preferLeaderFalseUsesLowestLatencyReplicaWhenScoresAvailable() {
+ FakeEndpointCache endpointCache = new FakeEndpointCache();
+ KeyRangeCache cache = new KeyRangeCache(endpointCache);
+ cache.useDeterministicRandom();
+ cache.addRanges(threeReplicaUpdate());
+
+ endpointCache.get("server1");
+ endpointCache.get("server2");
+ endpointCache.get("server3");
+
+ EndpointLatencyRegistry.recordLatency(
+ null, TEST_OPERATION_UID, false, "server1", Duration.ofNanos(300_000L));
+ EndpointLatencyRegistry.recordLatency(
+ null, TEST_OPERATION_UID, false, "server2", Duration.ofNanos(100_000L));
+ EndpointLatencyRegistry.recordLatency(
+ null, TEST_OPERATION_UID, false, "server3", Duration.ofNanos(200_000L));
+
+ RoutingHint.Builder hint =
+ RoutingHint.newBuilder().setKey(bytes("a")).setOperationUid(TEST_OPERATION_UID);
+ ChannelEndpoint server =
+ cache.fillRoutingHint(
+ false,
+ KeyRangeCache.RangeMode.COVERING_SPLIT,
+ DirectedReadOptions.getDefaultInstance(),
+ hint);
+
+ assertNotNull(server);
+ assertEquals("server2", server.getAddress());
+ }
+
+ @Test
+ public void preferLeaderTrueUsesLatencyScoresWhenOperationUidAvailable() {
+ FakeEndpointCache endpointCache = new FakeEndpointCache();
+ KeyRangeCache cache = new KeyRangeCache(endpointCache);
+ cache.useDeterministicRandom();
+ cache.addRanges(threeReplicaUpdate());
+
+ endpointCache.get("server1");
+ endpointCache.get("server2");
+ endpointCache.get("server3");
+
+ EndpointLatencyRegistry.recordLatency(
+ null, TEST_OPERATION_UID, true, "server1", Duration.ofNanos(300_000L));
+ EndpointLatencyRegistry.recordLatency(
+ null, TEST_OPERATION_UID, true, "server2", Duration.ofNanos(100_000L));
+ EndpointLatencyRegistry.recordLatency(
+ null, TEST_OPERATION_UID, true, "server3", Duration.ofNanos(200_000L));
+
+ RoutingHint.Builder hint =
+ RoutingHint.newBuilder().setKey(bytes("a")).setOperationUid(TEST_OPERATION_UID);
+ ChannelEndpoint server =
+ cache.fillRoutingHint(
+ true,
+ KeyRangeCache.RangeMode.COVERING_SPLIT,
+ DirectedReadOptions.getDefaultInstance(),
+ hint);
+
+ assertNotNull(server);
+ assertEquals("server2", server.getAddress());
+ }
+
+ @Test
+ public void preferLeaderTrueWithoutOperationUidKeepsLeaderSelection() {
+ FakeEndpointCache endpointCache = new FakeEndpointCache();
+ KeyRangeCache cache = new KeyRangeCache(endpointCache);
+ cache.useDeterministicRandom();
+ cache.addRanges(threeReplicaUpdate());
+
+ endpointCache.get("server1");
+ endpointCache.get("server2");
+ endpointCache.get("server3");
+
+ cache.recordReplicaLatency(TEST_OPERATION_UID, "server1", Duration.ofNanos(300_000L));
+ cache.recordReplicaLatency(TEST_OPERATION_UID, "server2", Duration.ofNanos(100_000L));
+ cache.recordReplicaLatency(TEST_OPERATION_UID, "server3", Duration.ofNanos(200_000L));
+
+ RoutingHint.Builder hint = RoutingHint.newBuilder().setKey(bytes("a"));
+ ChannelEndpoint server =
+ cache.fillRoutingHint(
+ true,
+ KeyRangeCache.RangeMode.COVERING_SPLIT,
+ DirectedReadOptions.getDefaultInstance(),
+ hint);
+
+ assertNotNull(server);
+ assertEquals("server1", server.getAddress());
+ }
+
+ @Test
+ public void preferLeaderFalseSkipsBestScoredReplicaWhenItIsNotReady() {
+ FakeEndpointCache endpointCache = new FakeEndpointCache();
+ KeyRangeCache cache = new KeyRangeCache(endpointCache);
+ cache.useDeterministicRandom();
+ cache.addRanges(threeReplicaUpdate());
+
+ endpointCache.get("server1");
+ endpointCache.get("server2");
+ endpointCache.get("server3");
+ endpointCache.setState("server2", EndpointHealthState.IDLE);
+
+ cache.recordReplicaLatency(TEST_OPERATION_UID, "server1", Duration.ofNanos(300_000L));
+ cache.recordReplicaLatency(TEST_OPERATION_UID, "server2", Duration.ofNanos(100_000L));
+ cache.recordReplicaLatency(TEST_OPERATION_UID, "server3", Duration.ofNanos(200_000L));
+
+ RoutingHint.Builder hint =
+ RoutingHint.newBuilder().setKey(bytes("a")).setOperationUid(TEST_OPERATION_UID);
+ ChannelEndpoint server =
+ cache.fillRoutingHint(
+ false,
+ KeyRangeCache.RangeMode.COVERING_SPLIT,
+ DirectedReadOptions.getDefaultInstance(),
+ hint);
+
+ assertNotNull(server);
+ assertEquals("server3", server.getAddress());
+ }
+
+ @Test
+ public void preferLeaderFalseUsesOperationUidScopedScores() {
+ FakeEndpointCache endpointCache = new FakeEndpointCache();
+ KeyRangeCache cache = new KeyRangeCache(endpointCache);
+ cache.useDeterministicRandom();
+ cache.addRanges(threeReplicaUpdate());
+
+ endpointCache.get("server1");
+ endpointCache.get("server2");
+ endpointCache.get("server3");
+
+ cache.recordReplicaLatency(201L, "server1", Duration.ofNanos(100_000L));
+ cache.recordReplicaLatency(201L, "server2", Duration.ofNanos(300_000L));
+ cache.recordReplicaLatency(201L, "server3", Duration.ofNanos(200_000L));
+ cache.recordReplicaLatency(202L, "server1", Duration.ofNanos(300_000L));
+ cache.recordReplicaLatency(202L, "server2", Duration.ofNanos(100_000L));
+ cache.recordReplicaLatency(202L, "server3", Duration.ofNanos(200_000L));
+
+ ChannelEndpoint firstOperationServer =
+ cache.fillRoutingHint(
+ false,
+ KeyRangeCache.RangeMode.COVERING_SPLIT,
+ DirectedReadOptions.getDefaultInstance(),
+ RoutingHint.newBuilder().setKey(bytes("a")).setOperationUid(201L));
+ ChannelEndpoint secondOperationServer =
+ cache.fillRoutingHint(
+ false,
+ KeyRangeCache.RangeMode.COVERING_SPLIT,
+ DirectedReadOptions.getDefaultInstance(),
+ RoutingHint.newBuilder().setKey(bytes("a")).setOperationUid(202L));
+
+ assertNotNull(firstOperationServer);
+ assertNotNull(secondOperationServer);
+ assertEquals("server1", firstOperationServer.getAddress());
+ assertEquals("server2", secondOperationServer.getAddress());
+ }
+
+ @Test
+ public void preferLeaderFalseUsesInflightCostForColdReplicaSelection() {
+ FakeEndpointCache endpointCache = new FakeEndpointCache();
+ KeyRangeCache cache = new KeyRangeCache(endpointCache);
+ cache.useDeterministicRandom();
+ cache.addRanges(threeReplicaUpdate());
+
+ endpointCache.get("server1");
+ endpointCache.get("server2");
+ endpointCache.get("server3");
+
+ endpointCache.get("server1").incrementActiveRequests();
+
+ ChannelEndpoint server =
+ cache.fillRoutingHint(
+ false,
+ KeyRangeCache.RangeMode.COVERING_SPLIT,
+ DirectedReadOptions.getDefaultInstance(),
+ RoutingHint.newBuilder().setKey(bytes("a")).setOperationUid(TEST_OPERATION_UID));
+
+ assertNotNull(server);
+ assertEquals("server2", server.getAddress());
+ }
+
+ @Test
+ public void coldReplicaSelectionEmitsFiniteDefaultCost() {
+ FakeEndpointCache endpointCache = new FakeEndpointCache();
+ KeyRangeCache cache = new KeyRangeCache(endpointCache);
+ cache.useDeterministicRandom();
+ cache.addRanges(threeReplicaUpdate());
+
+ endpointCache.get("server1");
+ endpointCache.get("server2");
+ endpointCache.get("server3");
+
+ KeyRangeCache.RouteLookupResult result =
+ cache.lookupRoutingHint(
+ false,
+ KeyRangeCache.RangeMode.COVERING_SPLIT,
+ DirectedReadOptions.getDefaultInstance(),
+ RoutingHint.newBuilder().setKey(bytes("a")).setOperationUid(TEST_OPERATION_UID),
+ address -> false);
+
+ assertNotNull(result.endpoint);
+ }
+
+ @Test
+ public void preferLeaderFalseInflightCostCanOutweighLowerLatency() {
+ FakeEndpointCache endpointCache = new FakeEndpointCache();
+ KeyRangeCache cache = new KeyRangeCache(endpointCache);
+ cache.useDeterministicRandom();
+ cache.addRanges(threeReplicaUpdate());
+
+ endpointCache.get("server1");
+ endpointCache.get("server2");
+ endpointCache.get("server3");
+
+ cache.recordReplicaLatency(TEST_OPERATION_UID, "server1", Duration.ofNanos(100_000L));
+ cache.recordReplicaLatency(TEST_OPERATION_UID, "server2", Duration.ofNanos(300_000L));
+ endpointCache.get("server1").incrementActiveRequests();
+ endpointCache.get("server1").incrementActiveRequests();
+ endpointCache.get("server1").incrementActiveRequests();
+
+ ChannelEndpoint server =
+ cache.fillRoutingHint(
+ false,
+ KeyRangeCache.RangeMode.COVERING_SPLIT,
+ DirectedReadOptions.getDefaultInstance(),
+ RoutingHint.newBuilder().setKey(bytes("a")).setOperationUid(TEST_OPERATION_UID));
+
+ assertNotNull(server);
+ assertEquals("server2", server.getAddress());
+ }
+
+ @Test
+ public void preferLeaderFalseErrorPenaltySteersSelectionAwayFromPenalizedReplica() {
+ FakeEndpointCache baselineEndpointCache = new FakeEndpointCache();
+ KeyRangeCache baselineCache = new KeyRangeCache(baselineEndpointCache);
+ baselineCache.useDeterministicRandom();
+ baselineCache.addRanges(threeReplicaUpdate());
+
+ baselineEndpointCache.get("server1");
+ baselineEndpointCache.get("server2");
+ baselineEndpointCache.get("server3");
+
+ ChannelEndpoint baselineServer =
+ baselineCache.fillRoutingHint(
+ false,
+ KeyRangeCache.RangeMode.COVERING_SPLIT,
+ DirectedReadOptions.getDefaultInstance(),
+ RoutingHint.newBuilder().setKey(bytes("a")).setOperationUid(TEST_OPERATION_UID));
+
+ assertNotNull(baselineServer);
+
+ EndpointLatencyRegistry.clear();
+
+ FakeEndpointCache penalizedEndpointCache = new FakeEndpointCache();
+ KeyRangeCache penalizedCache = new KeyRangeCache(penalizedEndpointCache);
+ penalizedCache.useDeterministicRandom();
+ penalizedCache.addRanges(threeReplicaUpdate());
+
+ penalizedEndpointCache.get("server1");
+ penalizedEndpointCache.get("server2");
+ penalizedEndpointCache.get("server3");
+ penalizedCache.recordReplicaError(TEST_OPERATION_UID, baselineServer.getAddress());
+
+ ChannelEndpoint penalizedSelection =
+ penalizedCache.fillRoutingHint(
+ false,
+ KeyRangeCache.RangeMode.COVERING_SPLIT,
+ DirectedReadOptions.getDefaultInstance(),
+ RoutingHint.newBuilder().setKey(bytes("a")).setOperationUid(TEST_OPERATION_UID));
+
+ assertNotNull(penalizedSelection);
+    assertTrue("Expected error penalty to steer selection away from the penalized replica", !baselineServer.getAddress().equals(penalizedSelection.getAddress()));
+ }
+
+ @Test
+ public void preferLeaderFalseIgnoresPreferLeaderTrueScoreBucket() {
+ FakeEndpointCache endpointCache = new FakeEndpointCache();
+ KeyRangeCache cache = new KeyRangeCache(endpointCache);
+ cache.useDeterministicRandom();
+ cache.addRanges(twoReplicaUpdate());
+
+ endpointCache.get("server1");
+ endpointCache.get("server2");
+
+ EndpointLatencyRegistry.recordLatency(
+ null, TEST_OPERATION_UID, true, "server2", Duration.ofMillis(1));
+
+ ChannelEndpoint server =
+ cache.fillRoutingHint(
+ false,
+ KeyRangeCache.RangeMode.COVERING_SPLIT,
+ DirectedReadOptions.getDefaultInstance(),
+ RoutingHint.newBuilder().setKey(bytes("a")).setOperationUid(TEST_OPERATION_UID));
+
+ assertNotNull(server);
+ assertEquals("server1", server.getAddress());
+ }
+
// --- Eviction and recreation tests ---
@Test
@@ -826,6 +1248,7 @@ void setHealthy(String address, boolean healthy) {
static final class FakeEndpoint implements ChannelEndpoint {
private final String address;
private final FakeManagedChannel channel = new FakeManagedChannel();
+ private final AtomicInteger activeRequests = new AtomicInteger();
private EndpointHealthState state = EndpointHealthState.READY;
FakeEndpoint(String address) {
@@ -852,6 +1275,21 @@ public ManagedChannel getChannel() {
return channel;
}
+ @Override
+ public void incrementActiveRequests() {
+ activeRequests.incrementAndGet();
+ }
+
+ @Override
+ public void decrementActiveRequests() {
+ activeRequests.updateAndGet(current -> current > 0 ? current - 1 : 0);
+ }
+
+ @Override
+ public int getActiveRequestCount() {
+ return Math.max(0, activeRequests.get());
+ }
+
void setState(EndpointHealthState state) {
this.state = state;
channel.setConnectivityState(toConnectivityState(state));
diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelectorTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelectorTest.java
index 424efb363df6..9e7c7159bd77 100644
--- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelectorTest.java
+++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelectorTest.java
@@ -24,6 +24,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -33,6 +34,7 @@ public class PowerOfTwoReplicaSelectorTest {
private static class TestEndpoint implements ChannelEndpoint {
private final String address;
+ private final AtomicInteger activeRequests = new AtomicInteger();
TestEndpoint(String address) {
this.address = address;
@@ -57,6 +59,21 @@ public boolean isTransientFailure() {
public io.grpc.ManagedChannel getChannel() {
return null;
}
+
+ @Override
+ public void incrementActiveRequests() {
+ activeRequests.incrementAndGet();
+ }
+
+ @Override
+ public void decrementActiveRequests() {
+ activeRequests.updateAndGet(current -> current > 0 ? current - 1 : 0);
+ }
+
+ @Override
+ public int getActiveRequestCount() {
+ return Math.max(0, activeRequests.get());
+ }
}
@Test
diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/ReplicaSelectionMockServerTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/ReplicaSelectionMockServerTest.java
index 7ac5faf2e16e..05b773ad4f60 100644
--- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/ReplicaSelectionMockServerTest.java
+++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/ReplicaSelectionMockServerTest.java
@@ -16,6 +16,7 @@
package com.google.cloud.spanner.spi.v1;
+import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
@@ -23,11 +24,13 @@
import com.google.cloud.spanner.DatabaseClient;
import com.google.cloud.spanner.DatabaseId;
import com.google.cloud.spanner.MockSpannerServiceImpl;
+import com.google.cloud.spanner.MockSpannerServiceImpl.SimulatedExecutionTime;
import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult;
import com.google.cloud.spanner.Options;
import com.google.cloud.spanner.Spanner;
import com.google.cloud.spanner.SpannerOptions;
import com.google.cloud.spanner.Statement;
+import com.google.cloud.spanner.TimestampBound;
import com.google.common.base.Stopwatch;
import com.google.protobuf.ByteString;
import com.google.protobuf.ListValue;
@@ -37,6 +40,7 @@
import com.google.spanner.v1.DirectedReadOptions;
import com.google.spanner.v1.DirectedReadOptions.IncludeReplicas;
import com.google.spanner.v1.DirectedReadOptions.ReplicaSelection;
+import com.google.spanner.v1.ExecuteSqlRequest;
import com.google.spanner.v1.Group;
import com.google.spanner.v1.Range;
import com.google.spanner.v1.ReadRequest;
@@ -85,6 +89,8 @@ public class ReplicaSelectionMockServerTest {
.build())
.setMetadata(SELECT1_METADATA)
.build();
+ private static final String QUERY_SQL = "SELECT * FROM Table WHERE Column = @p1";
+ private static final String QUERY_PARAM = "p1";
private List servers;
private final int numServers = 2;
@@ -149,6 +155,8 @@ public void onCompleted() {
@After
public void tearDown() throws InterruptedException {
+ EndpointLatencyRegistry.clear();
+ RequestIdTargetTracker.clear();
for (ServerInstance si : servers) {
si.server.shutdown();
}
@@ -330,4 +338,508 @@ public void testEndToEndWithSpannerOptions() throws Exception {
server1ReceivedSuccessfulRead);
}
}
+
+ @Test
+ public void testStaleSingleUseReadBootstrapsScoresAndConvergesToLowerLatencyReplica()
+ throws Exception {
+ SpannerOptions options =
+ SpannerOptions.newBuilder()
+ .usePlainText()
+ .setExperimentalHost("localhost:" + servers.get(0).port)
+ .setProjectId("fake-project")
+ .setChannelEndpointCacheFactory(null)
+ .build();
+
+ RecipeList.Builder recipeListBuilder = RecipeList.newBuilder();
+ try {
+ TextFormat.merge(
+ "recipe {\n"
+ + " table_name: \"Table\"\n"
+ + " part { tag: 1 }\n"
+ + " part {\n"
+ + " order: ASCENDING\n"
+ + " null_order: NULLS_FIRST\n"
+ + " type { code: STRING }\n"
+ + " }\n"
+ + "}\n",
+ recipeListBuilder);
+ } catch (TextFormat.ParseException e) {
+ throw new RuntimeException(e);
+ }
+
+ CacheUpdate cacheUpdate =
+ CacheUpdate.newBuilder()
+ .setDatabaseId(12345L)
+ .setKeyRecipes(recipeListBuilder.build())
+ .addGroup(
+ Group.newBuilder()
+ .setGroupUid(1L)
+ .addTablets(
+ Tablet.newBuilder()
+ .setTabletUid(101L)
+ .setServerAddress("localhost:" + servers.get(0).port)
+ .setRole(Tablet.Role.READ_ONLY)
+ .setDistance(0)
+ .build())
+ .addTablets(
+ Tablet.newBuilder()
+ .setTabletUid(102L)
+ .setServerAddress("localhost:" + servers.get(1).port)
+ .setRole(Tablet.Role.READ_ONLY)
+ .setDistance(0)
+ .build())
+ .build())
+ .addRange(
+ Range.newBuilder()
+ .setStartKey(ByteString.EMPTY)
+ .setLimitKey(ByteString.copyFromUtf8("\u00FF"))
+ .setGroupUid(1L)
+ .setSplitId(1L)
+ .setGeneration(ByteString.copyFromUtf8("gen1"))
+ .build())
+ .build();
+
+ ResultSet resultSetWithUpdate =
+ SELECT1_RESULTSET.toBuilder().setCacheUpdate(cacheUpdate).build();
+
+ servers
+ .get(0)
+ .mockSpanner
+ .putStatementResult(StatementResult.query(Statement.of("SELECT 1"), resultSetWithUpdate));
+
+ com.google.cloud.spanner.Statement readStatement =
+ StatementResult.createReadStatement(
+ "Table",
+ com.google.cloud.spanner.KeySet.singleKey(com.google.cloud.spanner.Key.of()),
+ Arrays.asList("Column"));
+
+ servers
+ .get(0)
+ .mockSpanner
+ .putStatementResult(StatementResult.query(readStatement, SELECT1_RESULTSET));
+ servers
+ .get(1)
+ .mockSpanner
+ .putStatementResult(StatementResult.query(readStatement, SELECT1_RESULTSET));
+ servers
+ .get(0)
+ .mockSpanner
+ .setStreamingReadExecutionTime(SimulatedExecutionTime.ofMinimumAndRandomTime(40, 0));
+ servers
+ .get(1)
+ .mockSpanner
+ .setStreamingReadExecutionTime(SimulatedExecutionTime.ofMinimumAndRandomTime(0, 0));
+
+ try (Spanner spanner = options.getService()) {
+ DatabaseClient client =
+ spanner.getDatabaseClient(
+ DatabaseId.of("fake-project", "fake-instance", "fake-database"));
+
+ try (com.google.cloud.spanner.ResultSet rs =
+ client.singleUse().executeQuery(Statement.of("SELECT 1"))) {
+ while (rs.next()) {
+ /* consume */
+ }
+ }
+
+ clearServerRequests();
+ long operationUid = 0L;
+
+ for (int attempt = 1; attempt <= 3; attempt++) {
+ String key = "bootstrap-key-" + attempt;
+ try (com.google.cloud.spanner.ResultSet rs =
+ client
+ .singleUse(TimestampBound.ofExactStaleness(15L, TimeUnit.SECONDS))
+ .read(
+ "Table",
+ com.google.cloud.spanner.KeySet.singleKey(com.google.cloud.spanner.Key.of(key)),
+ Arrays.asList("Column"))) {
+ while (rs.next()) {
+ /* consume */
+ }
+ }
+
+ long currentOperationUid = findReadOperationUid(key);
+ assertTrue("Expected stale read to carry operation_uid", currentOperationUid > 0L);
+ if (operationUid == 0L) {
+ operationUid = currentOperationUid;
+ } else {
+ assertEquals(
+ "Expected stale reads to reuse the same operation_uid",
+ operationUid,
+ currentOperationUid);
+ }
+ }
+
+ assertTrue("Expected stale reads to reuse the same operation_uid", operationUid > 0L);
+
+ clearServerRequests();
+ Stopwatch watch = Stopwatch.createStarted();
+ boolean routedToLowerLatencyReplica = false;
+ int convergenceAttempt = 0;
+ while (watch.elapsed(TimeUnit.SECONDS) < 10 && !routedToLowerLatencyReplica) {
+ convergenceAttempt++;
+ String key = "convergence-key-" + convergenceAttempt;
+ try (com.google.cloud.spanner.ResultSet rs =
+ client
+ .singleUse(TimestampBound.ofExactStaleness(15L, TimeUnit.SECONDS))
+ .read(
+ "Table",
+ com.google.cloud.spanner.KeySet.singleKey(com.google.cloud.spanner.Key.of(key)),
+ Arrays.asList("Column"))) {
+ while (rs.next()) {
+ /* consume */
+ }
+ }
+
+ routedToLowerLatencyReplica =
+ hasReadRequestForKey(servers.get(1).mockSpanner, key)
+ && !hasReadRequestForKey(servers.get(0).mockSpanner, key);
+ }
+
+ assertTrue(
+ "Expected latency-aware routing to converge to the faster replica",
+ routedToLowerLatencyReplica);
+ }
+ }
+
+ @Test
+ public void testStrongSingleUseReadConvergesToLowerLatencyReplica() throws Exception {
+ SpannerOptions options =
+ SpannerOptions.newBuilder()
+ .usePlainText()
+ .setExperimentalHost("localhost:" + servers.get(0).port)
+ .setProjectId("fake-project")
+ .setChannelEndpointCacheFactory(null)
+ .build();
+
+ ResultSet resultSetWithUpdate =
+ SELECT1_RESULTSET.toBuilder()
+ .setCacheUpdate(createReplicaCacheUpdate(readRecipeList()))
+ .build();
+
+ servers
+ .get(0)
+ .mockSpanner
+ .putStatementResult(StatementResult.query(Statement.of("SELECT 1"), resultSetWithUpdate));
+
+ com.google.cloud.spanner.Statement readStatement =
+ StatementResult.createReadStatement(
+ "Table",
+ com.google.cloud.spanner.KeySet.singleKey(com.google.cloud.spanner.Key.of()),
+ Arrays.asList("Column"));
+
+ servers
+ .get(0)
+ .mockSpanner
+ .putStatementResult(StatementResult.query(readStatement, SELECT1_RESULTSET));
+ servers
+ .get(1)
+ .mockSpanner
+ .putStatementResult(StatementResult.query(readStatement, SELECT1_RESULTSET));
+ servers
+ .get(0)
+ .mockSpanner
+ .setStreamingReadExecutionTime(SimulatedExecutionTime.ofMinimumAndRandomTime(40, 0));
+ servers
+ .get(1)
+ .mockSpanner
+ .setStreamingReadExecutionTime(SimulatedExecutionTime.ofMinimumAndRandomTime(0, 0));
+
+ try (Spanner spanner = options.getService()) {
+ DatabaseClient client =
+ spanner.getDatabaseClient(
+ DatabaseId.of("fake-project", "fake-instance", "fake-database"));
+
+ try (com.google.cloud.spanner.ResultSet rs =
+ client.singleUse().executeQuery(Statement.of("SELECT 1"))) {
+ while (rs.next()) {
+ /* consume */
+ }
+ }
+
+ clearServerRequests();
+ long operationUid = 0L;
+
+ for (int attempt = 1; attempt <= 3; attempt++) {
+ String key = "strong-read-bootstrap-" + attempt;
+ try (com.google.cloud.spanner.ResultSet rs =
+ client
+ .singleUse()
+ .read(
+ "Table",
+ com.google.cloud.spanner.KeySet.singleKey(com.google.cloud.spanner.Key.of(key)),
+ Arrays.asList("Column"))) {
+ while (rs.next()) {
+ /* consume */
+ }
+ }
+
+ long currentOperationUid = findReadOperationUid(key);
+ assertTrue("Expected strong read to carry operation_uid", currentOperationUid > 0L);
+ if (operationUid == 0L) {
+ operationUid = currentOperationUid;
+ } else {
+ assertEquals(
+ "Expected strong reads to reuse the same operation_uid",
+ operationUid,
+ currentOperationUid);
+ }
+ }
+
+ assertTrue("Expected strong reads to reuse the same operation_uid", operationUid > 0L);
+
+ clearServerRequests();
+ Stopwatch watch = Stopwatch.createStarted();
+ boolean routedToLowerLatencyReplica = false;
+ int convergenceAttempt = 0;
+ while (watch.elapsed(TimeUnit.SECONDS) < 10 && !routedToLowerLatencyReplica) {
+ convergenceAttempt++;
+ String key = "strong-read-convergence-" + convergenceAttempt;
+ try (com.google.cloud.spanner.ResultSet rs =
+ client
+ .singleUse()
+ .read(
+ "Table",
+ com.google.cloud.spanner.KeySet.singleKey(com.google.cloud.spanner.Key.of(key)),
+ Arrays.asList("Column"))) {
+ while (rs.next()) {
+ /* consume */
+ }
+ }
+
+ routedToLowerLatencyReplica =
+ hasReadRequestForKey(servers.get(1).mockSpanner, key)
+ && !hasReadRequestForKey(servers.get(0).mockSpanner, key);
+ }
+
+ assertTrue(
+ "Expected strong read routing to converge to the faster replica",
+ routedToLowerLatencyReplica);
+ }
+ }
+
+ @Test
+ public void testStrongSingleUseQueryConvergesToLowerLatencyReplica() throws Exception {
+ SpannerOptions options =
+ SpannerOptions.newBuilder()
+ .usePlainText()
+ .setExperimentalHost("localhost:" + servers.get(0).port)
+ .setProjectId("fake-project")
+ .setChannelEndpointCacheFactory(null)
+ .build();
+
+ servers
+ .get(0)
+ .mockSpanner
+ .setExecuteStreamingSqlExecutionTime(SimulatedExecutionTime.ofMinimumAndRandomTime(40, 0));
+ servers
+ .get(1)
+ .mockSpanner
+ .setExecuteStreamingSqlExecutionTime(SimulatedExecutionTime.ofMinimumAndRandomTime(0, 0));
+
+ try (Spanner spanner = options.getService()) {
+ DatabaseClient client =
+ spanner.getDatabaseClient(
+ DatabaseId.of("fake-project", "fake-instance", "fake-database"));
+ assertStrongQueryConvergesToLowerLatencyReplica(
+ statement -> {
+ try (com.google.cloud.spanner.ResultSet rs =
+ client.singleUse().executeQuery(statement)) {
+ while (rs.next()) {
+ /* consume */
+ }
+ }
+ });
+ }
+ }
+
+ @FunctionalInterface
+ private interface QueryExecutor {
+ void execute(Statement statement) throws Exception;
+ }
+
+ private void assertStrongQueryConvergesToLowerLatencyReplica(QueryExecutor queryExecutor)
+ throws Exception {
+ String seedKey = "query-seed";
+ installQueryResultOnAllServers(seedKey, SELECT1_RESULTSET);
+
+ queryExecutor.execute(queryStatement(seedKey));
+ long operationUid = findQueryOperationUid(seedKey);
+ assertTrue("Expected strong query to carry operation_uid", operationUid > 0L);
+
+ installQueryResultOnAllServers(
+ seedKey,
+ SELECT1_RESULTSET.toBuilder()
+ .setCacheUpdate(createReplicaCacheUpdate(queryRecipeList(operationUid)))
+ .build());
+ queryExecutor.execute(queryStatement(seedKey));
+ clearServerRequests();
+
+ for (int attempt = 1; attempt <= 3; attempt++) {
+ String key = "strong-query-bootstrap-" + attempt;
+ installQueryResultOnAllServers(key, SELECT1_RESULTSET);
+ queryExecutor.execute(queryStatement(key));
+
+ long currentOperationUid = findQueryOperationUid(key);
+ assertEquals(
+ "Expected strong queries to reuse the same operation_uid",
+ operationUid,
+ currentOperationUid);
+ }
+
+ clearServerRequests();
+ Stopwatch watch = Stopwatch.createStarted();
+ boolean routedToLowerLatencyReplica = false;
+ int convergenceAttempt = 0;
+ while (watch.elapsed(TimeUnit.SECONDS) < 10 && !routedToLowerLatencyReplica) {
+ convergenceAttempt++;
+ String key = "strong-query-convergence-" + convergenceAttempt;
+ installQueryResultOnAllServers(key, SELECT1_RESULTSET);
+ queryExecutor.execute(queryStatement(key));
+
+ routedToLowerLatencyReplica =
+ hasQueryRequestForKey(servers.get(1).mockSpanner, key)
+ && !hasQueryRequestForKey(servers.get(0).mockSpanner, key);
+ }
+
+ assertTrue(
+ "Expected strong query routing to converge to the faster replica",
+ routedToLowerLatencyReplica);
+ }
+
+ private void clearServerRequests() {
+ for (ServerInstance server : servers) {
+ server.mockSpanner.clearRequests();
+ }
+ }
+
+ private CacheUpdate createReplicaCacheUpdate(RecipeList keyRecipes) {
+ return CacheUpdate.newBuilder()
+ .setDatabaseId(12345L)
+ .setKeyRecipes(keyRecipes)
+ .addGroup(
+ Group.newBuilder()
+ .setGroupUid(1L)
+ .setLeaderIndex(0)
+ .addTablets(
+ Tablet.newBuilder()
+ .setTabletUid(101L)
+ .setServerAddress("localhost:" + servers.get(0).port)
+ .setRole(Tablet.Role.READ_ONLY)
+ .setDistance(0)
+ .build())
+ .addTablets(
+ Tablet.newBuilder()
+ .setTabletUid(102L)
+ .setServerAddress("localhost:" + servers.get(1).port)
+ .setRole(Tablet.Role.READ_ONLY)
+ .setDistance(0)
+ .build())
+ .build())
+ .addRange(
+ Range.newBuilder()
+ .setStartKey(ByteString.EMPTY)
+ .setLimitKey(ByteString.copyFromUtf8("\u00FF"))
+ .setGroupUid(1L)
+ .setSplitId(1L)
+ .setGeneration(ByteString.copyFromUtf8("gen1"))
+ .build())
+ .build();
+ }
+
+ private RecipeList readRecipeList() throws TextFormat.ParseException {
+ RecipeList.Builder recipeListBuilder = RecipeList.newBuilder();
+ TextFormat.merge(
+ "recipe {\n"
+ + " table_name: \"Table\"\n"
+ + " part { tag: 1 }\n"
+ + " part {\n"
+ + " order: ASCENDING\n"
+ + " null_order: NULLS_FIRST\n"
+ + " type { code: STRING }\n"
+ + " identifier: \"k\"\n"
+ + " }\n"
+ + "}\n",
+ recipeListBuilder);
+ return recipeListBuilder.build();
+ }
+
+ private RecipeList queryRecipeList(long operationUid) throws TextFormat.ParseException {
+ RecipeList.Builder recipeListBuilder = RecipeList.newBuilder();
+ TextFormat.merge(
+ "recipe {\n"
+ + " operation_uid: "
+ + operationUid
+ + "\n"
+ + " part { tag: 1 }\n"
+ + " part {\n"
+ + " order: ASCENDING\n"
+ + " null_order: NULLS_FIRST\n"
+ + " type { code: STRING }\n"
+ + " identifier: \""
+ + QUERY_PARAM
+ + "\"\n"
+ + " }\n"
+ + "}\n",
+ recipeListBuilder);
+ return recipeListBuilder.build();
+ }
+
+ private Statement queryStatement(String key) {
+ return Statement.newBuilder(QUERY_SQL).bind(QUERY_PARAM).to(key).build();
+ }
+
+ private void installQueryResultOnAllServers(String key, ResultSet resultSet) {
+ Statement statement = queryStatement(key);
+ for (ServerInstance server : servers) {
+ server.mockSpanner.putStatementResult(StatementResult.query(statement, resultSet));
+ }
+ }
+
+ private long findReadOperationUid(String key) {
+ for (ServerInstance server : servers) {
+ for (ReadRequest request : server.mockSpanner.getRequestsOfType(ReadRequest.class)) {
+ if (request.getKeySet().getKeysCount() == 0
+ || request.getKeySet().getKeys(0).getValuesCount() == 0) {
+ continue;
+ }
+ if (key.equals(request.getKeySet().getKeys(0).getValues(0).getStringValue())) {
+ return request.getRoutingHint().getOperationUid();
+ }
+ }
+ }
+ return 0L;
+ }
+
+ private long findQueryOperationUid(String key) {
+ for (ServerInstance server : servers) {
+ for (ExecuteSqlRequest request :
+ server.mockSpanner.getRequestsOfType(ExecuteSqlRequest.class)) {
+ if (request.getParams().getFieldsMap().containsKey(QUERY_PARAM)
+ && key.equals(request.getParams().getFieldsOrThrow(QUERY_PARAM).getStringValue())) {
+ return request.getRoutingHint().getOperationUid();
+ }
+ }
+ }
+ return 0L;
+ }
+
+ private boolean hasReadRequestForKey(MockSpannerServiceImpl mockSpanner, String key) {
+ return mockSpanner.getRequestsOfType(ReadRequest.class).stream()
+ .anyMatch(
+ request ->
+ request.getKeySet().getKeysCount() > 0
+ && request.getKeySet().getKeys(0).getValuesCount() > 0
+ && key.equals(request.getKeySet().getKeys(0).getValues(0).getStringValue()));
+ }
+
+ private boolean hasQueryRequestForKey(MockSpannerServiceImpl mockSpanner, String key) {
+ return mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).stream()
+ .anyMatch(
+ request ->
+ request.getParams().getFieldsMap().containsKey(QUERY_PARAM)
+ && key.equals(
+ request.getParams().getFieldsOrThrow(QUERY_PARAM).getStringValue()));
+ }
}