Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added

- Add `QeueueBundle.Remove` to remove an already added queue/producer. [PR #1235](https://github.com/riverqueue/river/pull/1235) and [PR #1240](https://github.com/riverqueue/river/pull/1240).

### Fixed

- Fix unsafe concurrent producer map access in client. [PR #1236](https://github.com/riverqueue/river/pull/1236).
Expand Down
44 changes: 26 additions & 18 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -767,11 +767,11 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client
}

client.queues = &QueueBundle{
addProducer: client.addProducer,
removeProducer: client.removeProducer,
clientFetchCooldown: config.FetchCooldown,
clientFetchPollInterval: config.FetchPollInterval,
clientWillExecuteJobs: config.willExecuteJobs(),
producerAdd: client.producerAdd,
producerRemove: client.producerRemove,
}

baseservice.Init(archetype, &client.baseService)
Expand Down Expand Up @@ -879,7 +879,7 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client
client.services = append(client.services, client.elector)

for queue, queueConfig := range config.Queues {
if _, err := client.addProducer(queue, queueConfig); err != nil {
if _, err := client.producerAdd(queue, queueConfig); err != nil {
return nil, err
}
}
Expand Down Expand Up @@ -2177,7 +2177,7 @@ func (c *Client[TTx]) validateJobArgs(args JobArgs) error {
return nil
}

func (c *Client[TTx]) addProducer(queueName string, queueConfig QueueConfig) (*producer, error) {
func (c *Client[TTx]) producerAdd(queueName string, queueConfig QueueConfig) (*producer, error) {
c.producersMu.Lock()
defer c.producersMu.Unlock()

Expand Down Expand Up @@ -2210,7 +2210,7 @@ func (c *Client[TTx]) addProducer(queueName string, queueConfig QueueConfig) (*p
return producer, nil
}

func (c *Client[TTx]) removeProducer(queueName string) error {
func (c *Client[TTx]) producerRemove(ctx context.Context, queueName string) error {
c.producersMu.Lock()
defer c.producersMu.Unlock()

Expand All @@ -2219,7 +2219,17 @@ func (c *Client[TTx]) removeProducer(queueName string) error {
return &QueueNotFoundError{Name: queueName}
}

producer.Stop()
shouldStop, stopped, finalizeStop := producer.StopInit()
if shouldStop {
select {
case <-ctx.Done():
finalizeStop(false)
return ctx.Err()
case <-stopped:
finalizeStop(true)
}
}

delete(c.producersByQueueName, queueName)

return nil
Expand Down Expand Up @@ -2812,17 +2822,14 @@ func (c *Client[TTx]) Schema() string { return c.config.Schema }
// QueueBundle is a bundle for adding additional queues. It's made accessible
// through Client.Queues.
type QueueBundle struct {
// Function that adds a producer to the associated client.
addProducer func(queueName string, queueConfig QueueConfig) (*producer, error)

removeProducer func(queueName string) error

clientFetchCooldown time.Duration
clientFetchPollInterval time.Duration

clientWillExecuteJobs bool

fetchCtx context.Context //nolint:containedctx
fetchCtx context.Context //nolint:containedctx
producerAdd func(queueName string, queueConfig QueueConfig) (*producer, error) // add producer to associated client
producerRemove func(ctx context.Context, queueName string) error // remove producer from associated client

// Mutex that's acquired when client is starting and stopping and when a
// queue is being added so that we can be sure that a client is fully
Expand All @@ -2847,7 +2854,7 @@ func (b *QueueBundle) Add(queueName string, queueConfig QueueConfig) error {
b.startStopMu.Lock()
defer b.startStopMu.Unlock()

producer, err := b.addProducer(queueName, queueConfig)
producer, err := b.producerAdd(queueName, queueConfig)
if err != nil {
return err
}
Expand All @@ -2863,21 +2870,22 @@ func (b *QueueBundle) Add(queueName string, queueConfig QueueConfig) error {
}

// Remove removes a queue from the client, stopping the producer if the client
// is running. The function will block until all jobs currently being worked in
// the queue have completed. This blocking behavior may affect other operations,
// including shutdown timing.
// is running. It waits for any jobs currently being worked in the queue to
// complete before returning. If the provided context is done before the
// producer has stopped, Remove returns the context's error and does not remove
// the queue.
//
// Returns an error if the client is not configured to execute jobs or if the
// specified queue does not exist.
func (b *QueueBundle) Remove(queueName string) error {
func (b *QueueBundle) Remove(ctx context.Context, queueName string) error {
if !b.clientWillExecuteJobs {
return errors.New("client is not configured to execute jobs, cannot remove queue")
}

b.startStopMu.Lock()
defer b.startStopMu.Unlock()

return b.removeProducer(queueName)
return b.producerRemove(ctx, queueName)
}

// Generates a default client ID using the current hostname and time.
Expand Down
86 changes: 79 additions & 7 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,34 @@ func Test_Client_Common(t *testing.T) {
wg.Wait()
})

t.Run("Queues_Remove_Stress", func(t *testing.T) {
t.Parallel()

client, _ := setup(t)

startClient(ctx, t, client)
riversharedtest.WaitOrTimeout(t, client.baseStartStop.Started())

var wg sync.WaitGroup

for i := range 5 {
wg.Add(1)
workerNum := i
go func() {
defer wg.Done()

for j := range 5 {
queueName := fmt.Sprintf("stress_queue_%d_%d", workerNum, j)

require.NoError(t, client.Queues().Add(queueName, QueueConfig{MaxWorkers: 1}))
require.NoError(t, client.Queues().Remove(ctx, queueName))
}
}()
}

wg.Wait()
})

t.Run("Queues_Remove_BeforeStart", func(t *testing.T) {
t.Parallel()

Expand All @@ -427,7 +455,7 @@ func Test_Client_Common(t *testing.T) {
})
require.NoError(t, err)

err = client.Queues().Remove(queueName)
err = client.Queues().Remove(ctx, queueName)
require.NoError(t, err)

startClient(ctx, t, client)
Expand Down Expand Up @@ -481,7 +509,7 @@ func Test_Client_Common(t *testing.T) {
event := riversharedtest.WaitOrTimeout(t, subscribeChan)
require.Equal(t, EventKindJobCompleted, event.Kind)

err = client.Queues().Remove(queueName)
err = client.Queues().Remove(ctx, queueName)
require.NoError(t, err)

insertRes, err := client.Insert(ctx, &JobArgs{}, &InsertOpts{
Expand All @@ -502,12 +530,56 @@ func Test_Client_Common(t *testing.T) {
require.Equal(t, rivertype.JobStateAvailable, job.State)
})

t.Run("Queues_Remove_ContextDone", func(t *testing.T) {
t.Parallel()

client, _ := setup(t)

type JobArgs struct {
testutil.JobArgsReflectKind[JobArgs]
}

jobStartedChan := make(chan struct{})
AddWorker(client.config.Workers, WorkFunc(func(ctx context.Context, job *Job[JobArgs]) error {
close(jobStartedChan)
<-ctx.Done()
return nil
}))

queueName := "remove_context_done_queue"
require.NoError(t, client.Queues().Add(queueName, QueueConfig{MaxWorkers: 2}))

startClient(ctx, t, client)
riversharedtest.WaitOrTimeout(t, client.baseStartStop.Started())

_, err := client.Insert(ctx, &JobArgs{}, &InsertOpts{Queue: queueName})
require.NoError(t, err)

riversharedtest.WaitOrTimeout(t, jobStartedChan)

// Remove with an already-cancelled context should return immediately
// without removing the queue.
cancelledCtx, cancel := context.WithCancel(ctx)
cancel()

err = client.Queues().Remove(cancelledCtx, queueName)
require.ErrorIs(t, err, context.Canceled)

// Queue should still exist and be functional since Remove bailed out.
// Verify by successfully removing it with a valid context after
// cancelling the job via StopAndCancel.
require.NoError(t, client.StopAndCancel(ctx))

// Re-start so startClient's cleanup Stop doesn't fail.
require.NoError(t, client.Start(ctx))
})

t.Run("Queues_Remove_NonExistentQueue", func(t *testing.T) {
t.Parallel()

client, _ := setup(t)

err := client.Queues().Remove("non_existent_queue")
err := client.Queues().Remove(ctx, "non_existent_queue")
require.Error(t, err)
var queueNotFoundErr *QueueNotFoundError
require.ErrorAs(t, err, &queueNotFoundErr)
Expand All @@ -522,7 +594,7 @@ func Test_Client_Common(t *testing.T) {
config.Workers = nil
client := newTestClient(t, bundle.dbPool, config)

err := client.Queues().Remove("any_queue")
err := client.Queues().Remove(ctx, "any_queue")
require.Error(t, err)
require.Contains(t, err.Error(), "client is not configured to execute jobs, cannot remove queue")
})
Expand Down Expand Up @@ -551,7 +623,7 @@ func Test_Client_Common(t *testing.T) {
event := riversharedtest.WaitOrTimeout(t, subscribeChan)
require.Equal(t, EventKindJobCompleted, event.Kind)

err = client.Queues().Remove(QueueDefault)
err = client.Queues().Remove(ctx, QueueDefault)
require.NoError(t, err)

insertRes, err := client.Insert(ctx, &JobArgs{}, nil)
Expand Down Expand Up @@ -601,7 +673,7 @@ func Test_Client_Common(t *testing.T) {
event := riversharedtest.WaitOrTimeout(t, subscribeChan)
require.Equal(t, EventKindJobCompleted, event.Kind)

err = client.Queues().Remove(queueName)
err = client.Queues().Remove(ctx, queueName)
require.NoError(t, err)

err = client.Queues().Add(queueName, QueueConfig{
Expand Down Expand Up @@ -634,7 +706,7 @@ func Test_Client_Common(t *testing.T) {
require.Equal(t, EventKindJobCompleted, event.Kind)
require.Equal(t, insertRes1.Job.ID, event.Job.ID)

require.NoError(t, client.Queues().Remove("test_queue"))
require.NoError(t, client.Queues().Remove(ctx, "test_queue"))

insertRes2, err := client.Insert(ctx, &noOpArgs{}, &InsertOpts{Queue: "test_queue"})
require.NoError(t, err)
Expand Down
Loading