
fix: add ollama and nvidia embedding #8104

Open
counhopig wants to merge 2 commits into AstrBotDevs:master from counhopig:feature/add-provider

Conversation

@counhopig

@counhopig counhopig commented May 9, 2026

This PR adds two new Embedding providers to extend AstrBot's knowledge base vectorization capabilities:

  1. NVIDIA NIM Embedding — supports NVIDIA's high-performance Embedding API (e.g., NV-Embed-QA, llama-nemotron-embed-1b-v2). Users can leverage NVIDIA's free developer tier for faster vectorization, or connect to privately deployed NIM microservices.
  2. Ollama Embedding — supports locally self-hosted embedding models via Ollama, giving users a fully offline, privacy-preserving vectorization option.
    Closes [Feature] Add support for Nvidia Embedding among the Embedding providers #7829.
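Both providers speak plain JSON over HTTP. As context for the changes below, here is a minimal sketch of the request/response shapes for an OpenAI-compatible `/embeddings` endpoint, assuming NVIDIA's `input_type` hint; the function names are illustrative, not the PR's actual code:

```python
def build_embedding_payload(
    model: str, texts: list[str], input_type: str = "passage"
) -> dict:
    """Request body for an OpenAI-compatible /embeddings endpoint.

    NVIDIA's retrieval models additionally take an input_type hint
    ("passage" for documents, "query" for search queries).
    """
    return {"model": model, "input": texts, "input_type": input_type}


def parse_embedding_response(response_data: dict) -> list[list[float]]:
    """Extract one vector per input from the OpenAI-style "data" array."""
    return [item.get("embedding", []) for item in response_data.get("data", [])]
```

The same parse step works for both providers because NVIDIA's response mirrors OpenAI's embedding structure, as noted in the file list below.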

Modifications / 改动点

Core files modified:

  • astrbot/core/provider/sources/nvidia_embedding_source.py (new, 140 lines)

    • NvidiaEmbeddingProvider class implementing EmbeddingProvider interface
    • Supports configurable api_key, base_url, model, input_type, timeout, proxy
    • Handles both single and batch embedding via get_embedding() / get_embeddings()
    • Parses NVIDIA API response format (OpenAI-compatible embedding structure)
  • astrbot/core/provider/sources/ollama_embedding_source.py (new, 120 lines)

    • OllamaEmbeddingProvider class implementing EmbeddingProvider interface
    • Supports configurable base_url, model, embedding_dimensions, timeout, proxy
    • Communicates with Ollama's /api/embed endpoint
  • astrbot/core/provider/manager.py (+8 lines)

    • Registers both providers in the lazy import dispatch (match/case block)
  • astrbot/core/config/default.py (+28 lines)

    • Default provider configurations for both NVIDIA Embedding and Ollama Embedding
  • This is NOT a breaking change. / 这不是一个破坏性变更。

Screenshots or Test Results / 运行截图或测试结果


Checklist / 检查清单

  • 😊 If there are new features added in the PR, I have discussed it with the authors through issues/emails, etc.
    / 如果 PR 中有新加入的功能,已经通过 Issue / 邮件等方式和作者讨论过。

  • 👀 My changes have been well-tested, and "Verification Steps" and "Screenshots" have been provided above.
    / 我的更改经过了良好的测试,并已在上方提供了“验证步骤”和“运行截图”

  • 🤓 I have ensured that no new dependencies are introduced, OR if new dependencies are introduced, they have been added to the appropriate locations in requirements.txt and pyproject.toml.
    / 我确保没有引入新依赖库,或者引入了新依赖库的同时将其添加到 requirements.txt 和 pyproject.toml 文件相应位置。

  • 😮 My changes do not introduce malicious code.
    / 我的更改没有引入恶意代码。

Summary by Sourcery

Add new NVIDIA NIM and Ollama embedding providers and wire them into the provider manager and default configuration.

New Features:

  • Introduce a NVIDIA NIM embedding provider for vectorizing text via NVIDIA's embedding API.
  • Introduce an Ollama embedding provider for generating embeddings from locally hosted models.

Enhancements:

  • Register the new NVIDIA and Ollama embedding providers in the dynamic provider manager for lazy loading.
  • Extend default configuration with presets for NVIDIA and Ollama embedding providers, including model, endpoint, and timeout settings.
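As a sketch of what such presets might look like: the key names below follow the adapter code quoted in the review (embedding_api_base, embedding_model, timeout, embedding_dimensions), but the exact keys and default values in default.py are assumptions, and the Ollama model name is purely illustrative:

```python
# Hypothetical default presets; real entries live in astrbot/core/config/default.py.
NVIDIA_EMBEDDING_DEFAULT = {
    "type": "nvidia_embedding",
    "provider": "nvidia",
    "embedding_api_key": "",
    "embedding_api_base": "https://integrate.api.nvidia.com/v1",
    "embedding_model": "nvidia/llama-nemotron-embed-1b-v2",
    "input_type": "passage",
    "timeout": 20,
}

OLLAMA_EMBEDDING_DEFAULT = {
    "type": "ollama_embedding",
    "provider": "ollama",
    "embedding_api_base": "http://localhost:11434",  # Ollama's default port
    "embedding_model": "nomic-embed-text",  # illustrative model name
    "embedding_dimensions": 768,
    "timeout": 20,
}
```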

Copilot AI review requested due to automatic review settings May 9, 2026 10:45
@auto-assign auto-assign Bot requested review from LIghtJUNction and anka-afk May 9, 2026 10:45
@dosubot dosubot Bot added the size:L This PR changes 100-499 lines, ignoring generated files. label May 9, 2026
@dosubot dosubot Bot added area:provider The bug / feature is about AI Provider, Models, LLM Agent, LLM Agent Runner. feature:knowledge-base The bug / feature is about knowledge base labels May 9, 2026
Contributor

@sourcery-ai sourcery-ai Bot left a comment


Hey - I've found 6 issues, and left some high level feedback:

  • In NvidiaEmbeddingProvider._get_client the if self.proxy branch currently creates the same ClientSession as the else branch and never configures the proxy at session level, so either wire the proxy into the session (e.g., via a connector) or simplify the condition to avoid dead code.
  • Both providers consistently raise bare Exception for network and API errors; consider using more specific exception types or a shared provider error class so callers can distinguish between network issues, API errors, and logical errors more easily.
  • There is repeated logic between the two embedding providers (client/session setup, get_dim, error handling); factoring this into shared helpers or a small base mixin would reduce duplication and make future provider additions easier to maintain.

## Individual Comments

### Comment 1
<location path="astrbot/core/provider/sources/nvidia_embedding_source.py" line_range="44-61" />
<code_context>
+                "Content-Type": "application/json",
+                "Accept": "application/json",
+            }
+            timeout = aiohttp.ClientTimeout(total=self.timeout)
+            if self.proxy:
+                self.client = aiohttp.ClientSession(
+                    headers=headers,
</code_context>
<issue_to_address>
**suggestion:** The proxy-specific branch in `_get_client` is redundant and can be simplified.

In `_get_client`, both the `if self.proxy:` and `else:` branches build `aiohttp.ClientSession` identically, and the proxy is only applied when calling `client.post(...)`. You can collapse this into a single `ClientSession` creation without the conditional to keep the logic simpler and prevent the two paths from drifting apart over time.

```suggestion
        if self.client is None or self.client.closed:
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
                "Accept": "application/json",
            }
            timeout = aiohttp.ClientTimeout(total=self.timeout)
            self.client = aiohttp.ClientSession(
                headers=headers,
                timeout=timeout,
            )
        return self.client
```
</issue_to_address>

### Comment 2
<location path="astrbot/core/provider/sources/nvidia_embedding_source.py" line_range="119-124" />
<code_context>
+
+                return embeddings
+
+        except aiohttp.ClientError as e:
+            logger.error(f"[NVIDIA Embedding] Network error: {e}")
+            raise Exception(f"Network error: {e}") from e
+        except Exception as e:
+            logger.error(f"[NVIDIA Embedding] Error: {e}")
+            raise
+
</code_context>
<issue_to_address>
**suggestion (bug_risk):** Consider preserving the original traceback or using more specific exception types for error propagation.

In the `aiohttp.ClientError` branch you wrap the error in a generic `Exception`, which hides the original type from callers. Instead, either re-raise the original error after logging (`raise` with no args) or raise a dedicated custom exception that callers can use to distinguish network failures from other provider/API errors.

```suggestion
        except aiohttp.ClientError as e:
            logger.error(f"[NVIDIA Embedding] Network error: {e}")
            # Preserve the original exception type and traceback
            raise
        except Exception as e:
            # Log unexpected errors with full traceback for easier debugging
            logger.error(f"[NVIDIA Embedding] Error: {e}", exc_info=True)
            raise
```
</issue_to_address>
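One way to realize the "dedicated custom exception" idea from this comment is a small error hierarchy that preserves the original cause. A sketch only; none of these class names are part of the PR:

```python
class EmbeddingProviderError(Exception):
    """Base class for all embedding provider failures."""


class EmbeddingNetworkError(EmbeddingProviderError):
    """The HTTP layer failed before any response was obtained."""


class EmbeddingAPIError(EmbeddingProviderError):
    """The API answered with a non-success status."""

    def __init__(self, status: int, detail: str) -> None:
        super().__init__(f"HTTP {status}: {detail}")
        self.status = status


def wrap_network_error(exc: Exception) -> EmbeddingNetworkError:
    # Use as `raise wrap_network_error(e) from e` so the original
    # aiohttp.ClientError stays reachable via __cause__.
    return EmbeddingNetworkError(f"Network error: {exc}")
```

Callers could then `except EmbeddingNetworkError` for retries while letting `EmbeddingAPIError` surface configuration problems.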

### Comment 3
<location path="astrbot/core/provider/sources/ollama_embedding_source.py" line_range="55-62" />
<code_context>
+            raise
+
+    def get_dim(self) -> int:
+        if "embedding_dimensions" in self.provider_config:
+            try:
+                return int(self.provider_config["embedding_dimensions"])
</code_context>
<issue_to_address>
**suggestion (bug_risk):** Invalid `embedding_dimensions` values are silently ignored; consider logging or surfacing the issue.

In `_build_payload`, invalid `embedding_dimensions` values are swallowed by `except (ValueError, TypeError): pass`, which hides config errors. Instead, surface this with at least a warning (as in the NVIDIA provider) so operators know their setting is being ignored.
</issue_to_address>

### Comment 4
<location path="astrbot/core/provider/sources/ollama_embedding_source.py" line_range="99-101" />
<code_context>
+
+                return embeddings
+
+        except aiohttp.ClientError as e:
+            logger.error(f"[NVIDIA Embedding] Network error: {e}")
+            raise Exception(f"Network error: {e}") from e
</code_context>
<issue_to_address>
**suggestion:** Align error handling with NVIDIA provider and avoid wrapping `ClientError` in a generic `Exception`.

In this branch we lose the original `aiohttp.ClientError` type by wrapping it in `Exception`. After logging, either re-raise `e` directly or wrap it in a dedicated provider-specific exception so callers can distinguish network failures from other error types.

```suggestion
        except aiohttp.ClientError as e:
            logger.error(f"[Ollama Embedding] Network error: {e}")
            raise
```
</issue_to_address>

### Comment 5
<location path="astrbot/core/provider/sources/nvidia_embedding_source.py" line_range="15" />
<code_context>
+    "NVIDIA NIM Embedding 提供商适配器",
+    provider_type=ProviderType.EMBEDDING,
+)
+class NvidiaEmbeddingProvider(EmbeddingProvider):
+    def __init__(self, provider_config: dict, provider_settings: dict) -> None:
+        super().__init__(provider_config, provider_settings)
</code_context>
<issue_to_address>
**issue (complexity):** Consider simplifying NvidiaEmbeddingProvider by removing redundant state/branches, centralizing shared logic, and using helper methods as the single source of truth to keep the class focused on NVIDIA-specific behavior.

A few small changes would reduce complexity without changing behavior:

1. **Remove redundant branching in `_get_client` and duplicate state**

You’re already calling `super().__init__(provider_config, provider_settings)`, so if the base class stores these, you don’t need local copies. Also the `if self.proxy` branch is identical.

```python
class NvidiaEmbeddingProvider(EmbeddingProvider):
    def __init__(self, provider_config: dict, provider_settings: dict) -> None:
        super().__init__(provider_config, provider_settings)

        # If EmbeddingProvider already exposes these, don't shadow them:
        # self.provider_config = provider_config
        # self.provider_settings = provider_settings

        self.api_key = provider_config.get("embedding_api_key", "")
        self.base_url = (
            provider_config.get(
                "embedding_api_base", "https://integrate.api.nvidia.com/v1"
            )
            .rstrip("/")
            .removesuffix("/embeddings")
        )
        self.timeout = int(provider_config.get("timeout", 20))
        self.model = provider_config.get(
            "embedding_model", "nvidia/llama-nemotron-embed-1b-v2"
        )
        self.input_type = provider_config.get("input_type", "passage")

        self.proxy = provider_config.get("proxy", "")
        if self.proxy:
            logger.info(f"[NVIDIA Embedding] Using proxy: {self.proxy}")

        self.client = None
        self.set_model(self.model)

    async def _get_client(self) -> aiohttp.ClientSession:
        if self.client is None or self.client.closed:
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
                "Accept": "application/json",
            }
            timeout = aiohttp.ClientTimeout(total=self.timeout)
            self.client = aiohttp.ClientSession(
                headers=headers,
                timeout=timeout,
            )
        return self.client
```

2. **Rely on `_get_client` as the single source of truth**

Since `_get_client` guarantees a live session (or raises), you don’t need to re‑check `client`/`client.closed` in `get_embeddings`:

```python
    async def get_embeddings(self, text: list[str]) -> list[list[float]]:
        client = await self._get_client()

        payload = self._build_payload(text)
        request_url = f"{self.base_url}/embeddings"

        try:
            async with client.post(
                request_url, json=payload, proxy=self.proxy or None
            ) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(
                        f"[NVIDIA Embedding] API Error: {response.status} - {error_text}"
                    )
                    raise Exception(
                        f"NVIDIA Embedding API request failed: HTTP {response.status} - {error_text}"
                    )

                response_data = await response.json()
                embeddings = self._parse_response(response_data)

                usage = response_data.get("usage", {})
                total_tokens = usage.get("total_tokens", 0)
                if total_tokens > 0:
                    logger.debug(
                        f"[NVIDIA Embedding] Token usage: {total_tokens}"
                    )

                return embeddings

        except aiohttp.ClientError as e:
            logger.error(f"[NVIDIA Embedding] Network error: {e}")
            raise Exception(f"Network error: {e}") from e
        except Exception as e:
            logger.error(f"[NVIDIA Embedding] Error: {e}")
            raise
```

3. **Simplify response parsing**

The loop in `_parse_response` can be made more concise:

```python
    def _parse_response(self, response_data: dict) -> list[list[float]]:
        return [
            item.get("embedding", [])
            for item in response_data.get("data", [])
        ]
```

4. **Consider centralizing `get_dim` and client lifecycle**

If other embedding providers implement the same `get_dim` and `terminate` patterns, it may be worth moving them into `EmbeddingProvider` (or a mixin), so this provider only overrides behavior that’s truly specific to NVIDIA:

```python
# In EmbeddingProvider (conceptual sketch):
class EmbeddingProvider:
    def get_dim(self) -> int:
        if "embedding_dimensions" in self.provider_config:
            try:
                return int(self.provider_config["embedding_dimensions"])
            except (ValueError, TypeError):
                logger.warning(
                    "embedding_dimensions in embedding configs is not a valid "
                    f"integer: '{self.provider_config['embedding_dimensions']}', ignored."
                )
        return 0

    async def terminate(self):
        if getattr(self, "client", None) and not self.client.closed:
            await self.client.close()
            self.client = None
```

Then `NvidiaEmbeddingProvider` can drop its local `get_dim` and `terminate` if they’re identical.
</issue_to_address>

### Comment 6
<location path="astrbot/core/provider/sources/ollama_embedding_source.py" line_range="50" />
<code_context>
+                )
+        return self.client
+
+    def _build_payload(self, text: str | list[str]) -> dict:
+        if isinstance(text, str):
+            input_text = [text]
</code_context>
<issue_to_address>
**issue (complexity):** Consider simplifying the Ollama embedding provider by tightening `_build_payload`'s type, removing redundant client checks, and factoring shared client/dimension logic into the base `EmbeddingProvider` to reduce duplication.

- You can narrow `_build_payload` to the type you actually use and drop the unused union, which simplifies the method and its signature:

  ```diff
-    def _build_payload(self, text: str | list[str]) -> dict:
+    def _build_payload(self, text: list[str]) -> dict:
         payload = {
             "model": self.model,
             "input": text,
         }
  ```

- `_get_client()` already guarantees a valid session, so the extra defensive check in `get_embeddings` is redundant and can be removed without changing behavior:

  ```diff
-        client = await self._get_client()
-        if not client or client.closed:
-            raise Exception("[Ollama Embedding] Client session not initialized")
+        client = await self._get_client()
  ```

- To avoid duplicating lifecycle and dimension logic across providers (e.g., NVIDIA and Ollama), consider promoting the common pieces into `EmbeddingProvider` and letting concrete providers just supply config and headers. For example:

  ```python
  # in EmbeddingProvider
  class EmbeddingProvider:
      def __init__(self, provider_config: dict, provider_settings: dict) -> None:
          self.provider_config = provider_config
          self.provider_settings = provider_settings
          self.client = None

      async def _ensure_client(self, *, headers: dict, timeout_s: int):
          if self.client is None or self.client.closed:
              timeout = aiohttp.ClientTimeout(total=timeout_s)
              self.client = aiohttp.ClientSession(headers=headers, timeout=timeout)
          return self.client

      def get_dim(self) -> int:
          raw = self.provider_config.get("embedding_dimensions")
          if raw is None:
              return 0
          try:
              return int(raw)
          except (ValueError, TypeError):
              logger.warning(
                  f"embedding_dimensions in embedding configs is not a valid integer: "
                  f"'{raw}', ignored."
              )
              return 0

      async def terminate(self):
          if self.client and not self.client.closed:
              await self.client.close()
              self.client = None
  ```

  Then Ollama-specific code becomes thinner and avoids duplication:

  ```python
  class OllamaEmbeddingProvider(EmbeddingProvider):
      ...

      async def _get_client(self):
          headers = {
              "Content-Type": "application/json",
              "Accept": "application/json",
          }
          return await self._ensure_client(headers=headers, timeout_s=self.timeout)

-     def get_dim(self) -> int:
-         ...
-
-     async def terminate(self):
-         ...
+     # Inherit get_dim() and terminate() from EmbeddingProvider
  ```
</issue_to_address>


Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request adds support for NVIDIA and Ollama embedding providers, including their default configurations and adapter implementations. The review identifies several instances of redundant code, specifically in the handling of proxy settings and client session verification within the new provider source files.

Comment on lines +51 to +60

```python
if self.proxy:
    self.client = aiohttp.ClientSession(
        headers=headers,
        timeout=timeout,
    )
else:
    self.client = aiohttp.ClientSession(
        headers=headers,
        timeout=timeout,
    )
```
Contributor


medium

The `if self.proxy` block is redundant here because the `aiohttp.ClientSession` initialization is identical in both branches. The proxy is correctly handled later, in the `post` request call on line 98.

```python
self.client = aiohttp.ClientSession(
    headers=headers,
    timeout=timeout,
)
```

Comment on lines +89 to +91

```python
client = await self._get_client()
if not client or client.closed:
    raise Exception("[NVIDIA Embedding] Client session not initialized")
```
Contributor


medium

This check is redundant because _get_client() ensures that self.client is initialized and open before returning it. If it fails to initialize, it would have raised an exception within _get_client or during the ClientSession creation.

Suggested change

```diff
-client = await self._get_client()
-if not client or client.closed:
-    raise Exception("[NVIDIA Embedding] Client session not initialized")
+client = await self._get_client()
```

Comment on lines +69 to +71

```python
client = await self._get_client()
if not client or client.closed:
    raise Exception("[Ollama Embedding] Client session not initialized")
```
Contributor


medium

This check is redundant because _get_client() ensures that self.client is initialized and open before returning it.

Suggested change

```diff
-client = await self._get_client()
-if not client or client.closed:
-    raise Exception("[Ollama Embedding] Client session not initialized")
+client = await self._get_client()
```

Contributor

Copilot AI left a comment

Pull request overview

This PR extends AstrBot’s knowledge-base vectorization options by adding two new Embedding providers (NVIDIA NIM Embedding and Ollama Embedding) and wiring them into the provider manager and default configuration.

Changes:

  • Added NvidiaEmbeddingProvider (aiohttp-based) for NVIDIA’s OpenAI-compatible /embeddings API.
  • Added OllamaEmbeddingProvider (aiohttp-based) targeting Ollama’s /api/embed endpoint.
  • Registered both providers for lazy import in ProviderManager and added default configs.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 5 comments.

File Description
astrbot/core/provider/sources/nvidia_embedding_source.py New NVIDIA embedding provider implementation (request building, response parsing, session management).
astrbot/core/provider/sources/ollama_embedding_source.py New Ollama embedding provider implementation (payload building, request/response handling, session management).
astrbot/core/provider/manager.py Adds lazy-import cases so the new provider types can be instantiated by config.
astrbot/core/config/default.py Adds default provider entries for NVIDIA Embedding and Ollama Embedding.

Comment on lines +109 to +117
response_data = await response.json()
embeddings = self._parse_response(response_data)

usage = response_data.get("usage", {})
total_tokens = usage.get("total_tokens", 0)
if total_tokens > 0:
    logger.debug(f"[NVIDIA Embedding] Token usage: {total_tokens}")

return embeddings
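The snippet above reads usage.total_tokens from an OpenAI-compatible embeddings response. A sketch of what _parse_response plausibly does with that shape (the helper name comes from the diff; its body here is an assumption):

```python
def parse_embedding_response(response_data: dict) -> list[list[float]]:
    # OpenAI-compatible embeddings arrive under "data", one item per input,
    # each carrying an "embedding" vector and its batch "index"; sort by
    # index so the output order matches the input order.
    items = sorted(response_data.get("data", []), key=lambda d: d.get("index", 0))
    return [item["embedding"] for item in items]
```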
Comment on lines +51 to +60
if self.proxy:
    self.client = aiohttp.ClientSession(
        headers=headers,
        timeout=timeout,
    )
else:
    self.client = aiohttp.ClientSession(
        headers=headers,
        timeout=timeout,
    )
    if dimensions > 0:
        payload["dimensions"] = dimensions
except (ValueError, TypeError):
    pass
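The hunk above only shows the tail of the dimensions guard; the full pattern presumably looks like the sketch below (a reconstruction around the visible lines, with an illustrative helper name):

```python
def add_dimensions(payload: dict, raw_dimensions) -> dict:
    # Forward "dimensions" only when it coerces to a positive int; anything
    # missing or non-numeric is skipped, matching the except clause above.
    try:
        dimensions = int(raw_dimensions)
    except (ValueError, TypeError):
        return payload
    if dimensions > 0:
        payload["dimensions"] = dimensions
    return payload
```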
"type": "nvidia_embedding",
"provider": "nvidia",
"provider_type": "embedding",
"hint": "provider_group.provider.nvidia_embedding.hint",

"type": "ollama_embedding",
"provider": "ollama",
"provider_type": "embedding",
"hint": "provider_group.provider.ollama_embedding.hint",
 - Remove redundant proxy branch in NvidiaEmbeddingProvider._get_client

 - Change ClientError handling to re-raise instead of wrapping in Exception

 - Add exc_info=True for better error diagnostics

 - Remove redundant isinstance check in OllamaEmbeddingProvider._build_payload
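The "change ClientError handling to re-raise" bullet boils down to the pattern below, sketched with a stand-in exception class rather than aiohttp.ClientError:

```python
import logging

logger = logging.getLogger(__name__)

class ClientError(Exception):
    """Stand-in for aiohttp.ClientError in this sketch."""

def call_with_reraise(do_request):
    # Log with exc_info=True so the traceback is preserved in diagnostics,
    # then re-raise the original exception instead of wrapping it in a
    # generic Exception, keeping the specific type visible to callers.
    try:
        return do_request()
    except ClientError:
        logger.error("[Embedding] request failed", exc_info=True)
        raise
```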
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.

Comment on lines +37 to +48
async def _get_client(self):
    if self.client is None or self.client.closed:
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        timeout = aiohttp.ClientTimeout(total=self.timeout)
        self.client = aiohttp.ClientSession(
            headers=headers,
            timeout=timeout,
        )
    return self.client
Comment on lines +43 to +55
async def _get_client(self):
    if self.client is None or self.client.closed:
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        timeout = aiohttp.ClientTimeout(total=self.timeout)
        self.client = aiohttp.ClientSession(
            headers=headers,
            timeout=timeout,
        )
    return self.client
Comment on lines +46 to +49
    "Authorization": f"Bearer {self.api_key}",
    "Content-Type": "application/json",
    "Accept": "application/json",
}
Comment on lines +1806 to +1824
    "provider": "nvidia",
    "provider_type": "embedding",
    "hint": "provider_group.provider.nvidia_embedding.hint",
    "enable": True,
    "embedding_api_key": "",
    "embedding_api_base": "https://integrate.api.nvidia.com/v1",
    "embedding_model": "nvidia/llama-nemotron-embed-1b-v2",
    "input_type": "passage",
    "embedding_dimensions": 1024,
    "timeout": 20,
    "proxy": "",
},
"Ollama Embedding": {
    "id": "ollama_embedding",
    "type": "ollama_embedding",
    "provider": "ollama",
    "provider_type": "embedding",
    "hint": "provider_group.provider.ollama_embedding.hint",
    "enable": True,
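The Ollama entry is cut off above; for orientation, a complete entry might look like the sketch below. Every field past "enable" is an assumption mirroring the NVIDIA defaults, not the PR's actual values.

```python
# Hypothetical completion of the "Ollama Embedding" default entry; the
# base URL, model, and dimension values below are illustrative only.
OLLAMA_EMBEDDING_DEFAULT = {
    "id": "ollama_embedding",
    "type": "ollama_embedding",
    "provider": "ollama",
    "provider_type": "embedding",
    "hint": "provider_group.provider.ollama_embedding.hint",
    "enable": True,
    "embedding_api_base": "http://localhost:11434",  # assumed Ollama default port
    "embedding_model": "nomic-embed-text",           # assumed example model
    "embedding_dimensions": 768,                     # assumed; model-dependent
    "timeout": 20,
    "proxy": "",
}
```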

Labels

area:provider The bug / feature is about AI Provider, Models, LLM Agent, LLM Agent Runner. feature:knowledge-base The bug / feature is about knowledge base size:L This PR changes 100-499 lines, ignoring generated files.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Feature]在 Embedding provider 中添加对 Nvidia Embedding 的支持

2 participants