Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions replicate/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class Prediction(Resource):
version: str
"""An identifier for the version of the model used to create the prediction."""

status: Literal["starting", "processing", "succeeded", "failed", "canceled"]
status: Literal["starting", "processing", "succeeded", "failed", "canceled", "aborted"]
"""The status of the prediction."""

input: Optional[Dict[str, Any]]
Expand Down Expand Up @@ -141,7 +141,7 @@ def wait(self) -> None:
Wait for prediction to finish.
"""

while self.status not in ["succeeded", "failed", "canceled"]:
while self.status not in ["succeeded", "failed", "canceled", "aborted"]:
time.sleep(self._client.poll_interval)
self.reload()

Expand All @@ -150,7 +150,7 @@ async def async_wait(self) -> None:
Wait for prediction to finish asynchronously.
"""

while self.status not in ["succeeded", "failed", "canceled"]:
while self.status not in ["succeeded", "failed", "canceled", "aborted"]:
await asyncio.sleep(self._client.poll_interval)
await self.async_reload()

Expand Down Expand Up @@ -251,15 +251,15 @@ def output_iterator(self) -> Iterator[Any]:

# TODO: check output is list
previous_output = self.output or []
while self.status not in ["succeeded", "failed", "canceled"]:
while self.status not in ["succeeded", "failed", "canceled", "aborted"]:
output = self.output or []
new_output = output[len(previous_output) :]
yield from new_output
previous_output = output
time.sleep(self._client.poll_interval) # pylint: disable=no-member
self.reload()

if self.status == "failed":
if self.status in ("failed", "aborted"):
raise ModelError(self)

output = self.output or []
Expand All @@ -273,7 +273,7 @@ async def async_output_iterator(self) -> AsyncIterator[Any]:

# TODO: check output is list
previous_output = self.output or []
while self.status not in ["succeeded", "failed", "canceled"]:
while self.status not in ["succeeded", "failed", "canceled", "aborted"]:
output = self.output or []
new_output = output[len(previous_output) :]
for item in new_output:
Expand All @@ -282,7 +282,7 @@ async def async_output_iterator(self) -> AsyncIterator[Any]:
await asyncio.sleep(self._client.poll_interval) # pylint: disable=no-member
await self.async_reload()

if self.status == "failed":
if self.status in ("failed", "aborted"):
raise ModelError(self)

output = self.output or []
Expand Down
4 changes: 2 additions & 2 deletions replicate/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def run(

prediction.wait()

if prediction.status == "failed":
if prediction.status in ("failed", "aborted"):
raise ModelError(prediction)

# Return an iterator for the completed prediction when needed.
Expand Down Expand Up @@ -147,7 +147,7 @@ async def async_run(

await prediction.async_wait()

if prediction.status == "failed":
if prediction.status in ("failed", "aborted"):
raise ModelError(prediction)

# Return an iterator for completed output if the model has an output iterator array type.
Expand Down
85 changes: 85 additions & 0 deletions tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -1076,3 +1076,88 @@ def _version_with_schema(id: str = "v1", output_schema: Optional[object] = None)
},
},
}


@pytest.mark.asyncio
@pytest.mark.parametrize("async_flag", [True, False])
async def test_run_raises_on_aborted_prediction(async_flag, mock_replicate_api_token):
"""
Regression test: an 'aborted' prediction (server-side termination) must surface as
ModelError and must NOT cause wait() / async_wait() to poll forever.

Before the fix, 'aborted' was not in the terminal-state list, so wait() looped
until the test timed out (issue #431) and run() silently returned None output.
"""
router = respx.Router(base_url="https://api.replicate.com/v1")
router.route(method="POST", path="/predictions").mock(
return_value=httpx.Response(
201,
json=_prediction_with_status("starting"),
)
)
router.route(method="GET", path="/predictions/p1").mock(
return_value=httpx.Response(
200,
json={**_prediction_with_status("aborted"), "error": "Prediction was aborted"},
)
)
router.route(
method="GET",
path="/models/test/example/versions/v1",
).mock(
return_value=httpx.Response(
201,
json=_version_with_schema(),
)
)
router.route(host="api.replicate.com").pass_through()

client = Client(
api_token="test-token", transport=httpx.MockTransport(router.handler)
)
client.poll_interval = 0.001

with pytest.raises(ModelError) as excinfo:
if async_flag:
await client.async_run("test/example:v1", input={"text": "Hello, world!"})
else:
client.run("test/example:v1", input={"text": "Hello, world!"})

assert excinfo.value.prediction.status == "aborted"


@pytest.mark.asyncio
async def test_prediction_wait_terminates_on_aborted(mock_replicate_api_token):
"""
Regression test: Prediction.wait() and async_wait() must exit immediately when
a prediction transitions to 'aborted', not loop forever.
"""
import replicate
from replicate.prediction import Prediction

router = respx.Router(base_url="https://api.replicate.com/v1")
router.route(method="GET", path="/predictions/p1").mock(
return_value=httpx.Response(
200,
json={**_prediction_with_status("aborted"), "error": "aborted by server"},
)
)
router.route(host="api.replicate.com").pass_through()

client = Client(
api_token="test-token", transport=httpx.MockTransport(router.handler)
)
client.poll_interval = 0.001

prediction = Prediction(**_prediction_with_status("processing"))
prediction._client = client

# wait() must return (not loop forever) when status flips to "aborted"
prediction.wait()
assert prediction.status == "aborted"

# Reset and verify async_wait() also exits
prediction2 = Prediction(**_prediction_with_status("processing"))
prediction2._client = client
await prediction2.async_wait()
assert prediction2.status == "aborted"