fix: add ollama and nvidia embedding #8104
Conversation
Hey - I've found 6 issues and left some high-level feedback:
- In `NvidiaEmbeddingProvider._get_client` the `if self.proxy` branch currently creates the same `ClientSession` as the `else` branch and never configures the proxy at session level, so either wire the proxy into the session (e.g., via a connector) or simplify the condition to avoid dead code.
- Both providers consistently raise bare `Exception` for network and API errors; consider using more specific exception types or a shared provider error class so callers can distinguish between network issues, API errors, and logical errors more easily.
- There is repeated logic between the two embedding providers (client/session setup, `get_dim`, error handling); factoring this into shared helpers or a small base mixin would reduce duplication and make future provider additions easier to maintain.
Prompt for AI Agents
Please address the comments from this code review:
## Overall Comments
- In `NvidiaEmbeddingProvider._get_client` the `if self.proxy` branch currently creates the same `ClientSession` as the `else` branch and never configures the proxy at session level, so either wire the proxy into the session (e.g., via a connector) or simplify the condition to avoid dead code.
- Both providers consistently raise bare `Exception` for network and API errors; consider using more specific exception types or a shared provider error class so callers can distinguish between network issues, API errors, and logical errors more easily.
- There is repeated logic between the two embedding providers (client/session setup, `get_dim`, error handling); factoring this into shared helpers or a small base mixin would reduce duplication and make future provider additions easier to maintain.
## Individual Comments
### Comment 1
<location path="astrbot/core/provider/sources/nvidia_embedding_source.py" line_range="44-61" />
<code_context>
+ "Content-Type": "application/json",
+ "Accept": "application/json",
+ }
+ timeout = aiohttp.ClientTimeout(total=self.timeout)
+ if self.proxy:
+ self.client = aiohttp.ClientSession(
+ headers=headers,
</code_context>
<issue_to_address>
**suggestion:** The proxy-specific branch in `_get_client` is redundant and can be simplified.
In `_get_client`, both the `if self.proxy:` and `else:` branches build `aiohttp.ClientSession` identically, and the proxy is only applied when calling `client.post(...)`. You can collapse this into a single `ClientSession` creation without the conditional to keep the logic simpler and prevent the two paths from drifting apart over time.
```suggestion
if self.client is None or self.client.closed:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
"Accept": "application/json",
}
timeout = aiohttp.ClientTimeout(total=self.timeout)
self.client = aiohttp.ClientSession(
headers=headers,
timeout=timeout,
)
return self.client
```
</issue_to_address>
### Comment 2
<location path="astrbot/core/provider/sources/nvidia_embedding_source.py" line_range="119-124" />
<code_context>
+
+ return embeddings
+
+ except aiohttp.ClientError as e:
+ logger.error(f"[NVIDIA Embedding] Network error: {e}")
+ raise Exception(f"Network error: {e}") from e
+ except Exception as e:
+ logger.error(f"[NVIDIA Embedding] Error: {e}")
+ raise
+
</code_context>
<issue_to_address>
**suggestion (bug_risk):** Consider preserving the original traceback or using more specific exception types for error propagation.
In the `aiohttp.ClientError` branch you wrap the error in a generic `Exception`, which hides the original type from callers. Instead, either re-raise the original error after logging (`raise` with no args) or raise a dedicated custom exception that callers can use to distinguish network failures from other provider/API errors.
```suggestion
except aiohttp.ClientError as e:
logger.error(f"[NVIDIA Embedding] Network error: {e}")
# Preserve the original exception type and traceback
raise
except Exception as e:
# Log unexpected errors with full traceback for easier debugging
logger.error(f"[NVIDIA Embedding] Error: {e}", exc_info=True)
raise
```
</issue_to_address>
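To make the "dedicated custom exception" option concrete, here is a minimal sketch of a shared provider error hierarchy. The class names (`ProviderError`, `ProviderNetworkError`, `ProviderAPIError`) are assumptions for illustration, not part of the PR:

```python
class ProviderError(Exception):
    """Base class for embedding provider failures (hypothetical name)."""


class ProviderNetworkError(ProviderError):
    """Transport-level failure: DNS, connect, timeout, reset."""


class ProviderAPIError(ProviderError):
    """Upstream API returned a non-2xx response."""

    def __init__(self, status: int, body: str) -> None:
        super().__init__(f"HTTP {status}: {body}")
        self.status = status
        self.body = body


# The ClientError branch in the provider could then become:
#     except aiohttp.ClientError as e:
#         logger.error(f"[NVIDIA Embedding] Network error: {e}")
#         raise ProviderNetworkError(str(e)) from e
# `from e` preserves the original exception as __cause__ with its traceback.
```

Callers can then catch `ProviderNetworkError` for retry logic while still using `ProviderError` as a catch-all.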
### Comment 3
<location path="astrbot/core/provider/sources/ollama_embedding_source.py" line_range="55-62" />
<code_context>
+ raise
+
+ def get_dim(self) -> int:
+ if "embedding_dimensions" in self.provider_config:
+ try:
+ return int(self.provider_config["embedding_dimensions"])
</code_context>
<issue_to_address>
**suggestion (bug_risk):** Invalid `embedding_dimensions` values are silently ignored; consider logging or surfacing the issue.
In `_build_payload`, invalid `embedding_dimensions` values are swallowed by `except (ValueError, TypeError): pass`, which hides config errors. Instead, surface this with at least a warning (as in the NVIDIA provider) so operators know their setting is being ignored.
</issue_to_address>
### Comment 4
<location path="astrbot/core/provider/sources/ollama_embedding_source.py" line_range="99-101" />
<code_context>
+
+ return embeddings
+
+ except aiohttp.ClientError as e:
+ logger.error(f"[NVIDIA Embedding] Network error: {e}")
+ raise Exception(f"Network error: {e}") from e
</code_context>
<issue_to_address>
**suggestion:** Align error handling with NVIDIA provider and avoid wrapping `ClientError` in a generic `Exception`.
In this branch we lose the original `aiohttp.ClientError` type by wrapping it in `Exception`. After logging, either re-raise `e` directly or wrap it in a dedicated provider-specific exception so callers can distinguish network failures from other error types.
```suggestion
except aiohttp.ClientError as e:
logger.error(f"[Ollama Embedding] Network error: {e}")
raise
```
</issue_to_address>
### Comment 5
<location path="astrbot/core/provider/sources/nvidia_embedding_source.py" line_range="15" />
<code_context>
+ "NVIDIA NIM Embedding 提供商适配器",
+ provider_type=ProviderType.EMBEDDING,
+)
+class NvidiaEmbeddingProvider(EmbeddingProvider):
+ def __init__(self, provider_config: dict, provider_settings: dict) -> None:
+ super().__init__(provider_config, provider_settings)
</code_context>
<issue_to_address>
**issue (complexity):** Consider simplifying NvidiaEmbeddingProvider by removing redundant state/branches, centralizing shared logic, and using helper methods as the single source of truth to keep the class focused on NVIDIA-specific behavior.
A few small changes would reduce complexity without changing behavior:
1. **Remove redundant branching in `_get_client` and duplicate state**
You’re already calling `super().__init__(provider_config, provider_settings)`, so if the base class stores these, you don’t need local copies. Also the `if self.proxy` branch is identical.
```python
class NvidiaEmbeddingProvider(EmbeddingProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config, provider_settings)
# If EmbeddingProvider already exposes these, don't shadow them:
# self.provider_config = provider_config
# self.provider_settings = provider_settings
self.api_key = provider_config.get("embedding_api_key", "")
self.base_url = (
provider_config.get(
"embedding_api_base", "https://integrate.api.nvidia.com/v1"
)
.rstrip("/")
.removesuffix("/embeddings")
)
self.timeout = int(provider_config.get("timeout", 20))
self.model = provider_config.get(
"embedding_model", "nvidia/llama-nemotron-embed-1b-v2"
)
self.input_type = provider_config.get("input_type", "passage")
self.proxy = provider_config.get("proxy", "")
if self.proxy:
logger.info(f"[NVIDIA Embedding] Using proxy: {self.proxy}")
self.client = None
self.set_model(self.model)
async def _get_client(self) -> aiohttp.ClientSession:
if self.client is None or self.client.closed:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
"Accept": "application/json",
}
timeout = aiohttp.ClientTimeout(total=self.timeout)
self.client = aiohttp.ClientSession(
headers=headers,
timeout=timeout,
)
return self.client
```
2. **Rely on `_get_client` as the single source of truth**
Since `_get_client` guarantees a live session (or raises), you don’t need to re‑check `client`/`client.closed` in `get_embeddings`:
```python
async def get_embeddings(self, text: list[str]) -> list[list[float]]:
client = await self._get_client()
payload = self._build_payload(text)
request_url = f"{self.base_url}/embeddings"
try:
async with client.post(
request_url, json=payload, proxy=self.proxy or None
) as response:
if response.status != 200:
error_text = await response.text()
logger.error(
f"[NVIDIA Embedding] API Error: {response.status} - {error_text}"
)
raise Exception(
f"NVIDIA Embedding API request failed: HTTP {response.status} - {error_text}"
)
response_data = await response.json()
embeddings = self._parse_response(response_data)
usage = response_data.get("usage", {})
total_tokens = usage.get("total_tokens", 0)
if total_tokens > 0:
logger.debug(
f"[NVIDIA Embedding] Token usage: {total_tokens}"
)
return embeddings
except aiohttp.ClientError as e:
logger.error(f"[NVIDIA Embedding] Network error: {e}")
raise Exception(f"Network error: {e}") from e
except Exception as e:
logger.error(f"[NVIDIA Embedding] Error: {e}")
raise
```
3. **Simplify response parsing**
The loop in `_parse_response` can be made more concise:
```python
def _parse_response(self, response_data: dict) -> list[list[float]]:
return [
item.get("embedding", [])
for item in response_data.get("data", [])
]
```
4. **Consider centralizing `get_dim` and client lifecycle**
If other embedding providers implement the same `get_dim` and `terminate` patterns, it may be worth moving them into `EmbeddingProvider` (or a mixin), so this provider only overrides behavior that’s truly specific to NVIDIA:
```python
# In EmbeddingProvider (conceptual sketch):
class EmbeddingProvider:
def get_dim(self) -> int:
if "embedding_dimensions" in self.provider_config:
try:
return int(self.provider_config["embedding_dimensions"])
except (ValueError, TypeError):
logger.warning(
"embedding_dimensions in embedding configs is not a valid "
f"integer: '{self.provider_config['embedding_dimensions']}', ignored."
)
return 0
async def terminate(self):
if getattr(self, "client", None) and not self.client.closed:
await self.client.close()
self.client = None
```
Then `NvidiaEmbeddingProvider` can drop its local `get_dim` and `terminate` if they’re identical.
</issue_to_address>
### Comment 6
<location path="astrbot/core/provider/sources/ollama_embedding_source.py" line_range="50" />
<code_context>
+ )
+ return self.client
+
+ def _build_payload(self, text: str | list[str]) -> dict:
+ if isinstance(text, str):
+ input_text = [text]
</code_context>
<issue_to_address>
**issue (complexity):** Consider simplifying the Ollama embedding provider by tightening `_build_payload`'s type, removing redundant client checks, and factoring shared client/dimension logic into the base `EmbeddingProvider` to reduce duplication.
- You can narrow `_build_payload` to the type you actually use and drop the unused union, which simplifies the method and its signature:
```python
- def _build_payload(self, text: str | list[str]) -> dict:
+ def _build_payload(self, text: list[str]) -> dict:
payload = {
"model": self.model,
"input": text,
}
```
- `_get_client()` already guarantees a valid session, so the extra defensive check in `get_embeddings` is redundant and can be removed without changing behavior:
```python
- client = await self._get_client()
- if not client or client.closed:
- raise Exception("[Ollama Embedding] Client session not initialized")
+ client = await self._get_client()
```
- To avoid duplicating lifecycle and dimension logic across providers (e.g., NVIDIA and Ollama), consider promoting the common pieces into `EmbeddingProvider` and letting concrete providers just supply config and headers. For example:
```python
# in EmbeddingProvider
class EmbeddingProvider:
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
self.provider_config = provider_config
self.provider_settings = provider_settings
self.client = None
async def _ensure_client(self, *, headers: dict, timeout_s: int):
if self.client is None or self.client.closed:
timeout = aiohttp.ClientTimeout(total=timeout_s)
self.client = aiohttp.ClientSession(headers=headers, timeout=timeout)
return self.client
def get_dim(self) -> int:
raw = self.provider_config.get("embedding_dimensions")
if raw is None:
return 0
try:
return int(raw)
except (ValueError, TypeError):
logger.warning(
f"embedding_dimensions in embedding configs is not a valid integer: "
f"'{raw}', ignored."
)
return 0
async def terminate(self):
if self.client and not self.client.closed:
await self.client.close()
self.client = None
```
Then Ollama-specific code becomes thinner and avoids duplication:
```python
class OllamaEmbeddingProvider(EmbeddingProvider):
...
async def _get_client(self):
headers = {
"Content-Type": "application/json",
"Accept": "application/json",
}
return await self._ensure_client(headers=headers, timeout_s=self.timeout)
- def get_dim(self) -> int:
- ...
-
- async def terminate(self):
- ...
+ # Inherit get_dim() and terminate() from EmbeddingProvider
```
</issue_to_address>
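Comments 5 and 6 both propose moving the client lifecycle into the base class. The sketch below exercises that structure offline; `FakeSession` is a stand-in for `aiohttp.ClientSession` purely so the example runs without a network or the aiohttp dependency, and all names are assumptions about the proposed design:

```python
import asyncio


class FakeSession:
    """Stand-in for aiohttp.ClientSession so the sketch runs offline."""

    def __init__(self, headers: dict, timeout_s: int) -> None:
        self.headers = headers
        self.timeout_s = timeout_s
        self.closed = False

    async def close(self) -> None:
        self.closed = True


class EmbeddingProviderBase:
    def __init__(self, provider_config: dict) -> None:
        self.provider_config = provider_config
        self.client: FakeSession | None = None

    async def _ensure_client(self, *, headers: dict, timeout_s: int) -> FakeSession:
        # Lazily create, and recreate only if the old session was closed
        if self.client is None or self.client.closed:
            self.client = FakeSession(headers, timeout_s)
        return self.client

    async def terminate(self) -> None:
        if self.client and not self.client.closed:
            await self.client.close()
        self.client = None


async def demo() -> tuple[bool, bool]:
    p = EmbeddingProviderBase({})
    a = await p._ensure_client(headers={}, timeout_s=5)
    b = await p._ensure_client(headers={}, timeout_s=5)
    reused = a is b  # an open session is reused, not rebuilt
    await p.terminate()
    return reused, p.client is None
```

Concrete providers would then only supply headers and timeout, as in the `_ensure_client(headers=..., timeout_s=...)` call shown in the review.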
Code Review
This pull request adds support for NVIDIA and Ollama embedding providers, including their default configurations and adapter implementations. The review identifies several instances of redundant code, specifically in the handling of proxy settings and client session verification within the new provider source files.
```python
if self.proxy:
    self.client = aiohttp.ClientSession(
        headers=headers,
        timeout=timeout,
    )
else:
    self.client = aiohttp.ClientSession(
        headers=headers,
        timeout=timeout,
    )
```
```python
client = await self._get_client()
if not client or client.closed:
    raise Exception("[NVIDIA Embedding] Client session not initialized")
```

This check is redundant because `_get_client()` ensures that `self.client` is initialized and open before returning it. If it fails to initialize, it would have raised an exception within `_get_client` or during the `ClientSession` creation.

```suggestion
client = await self._get_client()
```
```python
client = await self._get_client()
if not client or client.closed:
    raise Exception("[Ollama Embedding] Client session not initialized")
```

This check is redundant because `_get_client()` ensures that `self.client` is initialized and open before returning it.

```suggestion
client = await self._get_client()
```
Pull request overview
This PR extends AstrBot’s knowledge-base vectorization options by adding two new Embedding providers (NVIDIA NIM Embedding and Ollama Embedding) and wiring them into the provider manager and default configuration.
Changes:
- Added `NvidiaEmbeddingProvider` (aiohttp-based) for NVIDIA’s OpenAI-compatible `/embeddings` API.
- Added `OllamaEmbeddingProvider` (aiohttp-based) targeting Ollama’s `/api/embed` endpoint.
- Registered both providers for lazy import in `ProviderManager` and added default configs.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| `astrbot/core/provider/sources/nvidia_embedding_source.py` | New NVIDIA embedding provider implementation (request building, response parsing, session management). |
| `astrbot/core/provider/sources/ollama_embedding_source.py` | New Ollama embedding provider implementation (payload building, request/response handling, session management). |
| `astrbot/core/provider/manager.py` | Adds lazy-import cases so the new provider types can be instantiated by config. |
| `astrbot/core/config/default.py` | Adds default provider entries for NVIDIA Embedding and Ollama Embedding. |
```python
response_data = await response.json()
embeddings = self._parse_response(response_data)

usage = response_data.get("usage", {})
total_tokens = usage.get("total_tokens", 0)
if total_tokens > 0:
    logger.debug(f"[NVIDIA Embedding] Token usage: {total_tokens}")

return embeddings
```
```python
if self.proxy:
    self.client = aiohttp.ClientSession(
        headers=headers,
        timeout=timeout,
    )
else:
    self.client = aiohttp.ClientSession(
        headers=headers,
        timeout=timeout,
    )
```
```python
    if dimensions > 0:
        payload["dimensions"] = dimensions
except (ValueError, TypeError):
    pass
```
```python
"type": "nvidia_embedding",
"provider": "nvidia",
"provider_type": "embedding",
"hint": "provider_group.provider.nvidia_embedding.hint",
```

```python
"type": "ollama_embedding",
"provider": "ollama",
"provider_type": "embedding",
"hint": "provider_group.provider.ollama_embedding.hint",
```
- Remove redundant proxy branch in `NvidiaEmbeddingProvider._get_client`
- Change `ClientError` handling to re-raise instead of wrapping in `Exception`
- Add `exc_info=True` for better error diagnostics
- Remove redundant isinstance check in `OllamaEmbeddingProvider._build_payload`
```python
async def _get_client(self):
    if self.client is None or self.client.closed:
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        timeout = aiohttp.ClientTimeout(total=self.timeout)
        self.client = aiohttp.ClientSession(
            headers=headers,
            timeout=timeout,
        )
    return self.client
```
```python
async def _get_client(self):
    if self.client is None or self.client.closed:
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        timeout = aiohttp.ClientTimeout(total=self.timeout)
        self.client = aiohttp.ClientSession(
            headers=headers,
            timeout=timeout,
        )
    return self.client
```
```python
"provider": "nvidia",
"provider_type": "embedding",
"hint": "provider_group.provider.nvidia_embedding.hint",
"enable": True,
"embedding_api_key": "",
"embedding_api_base": "https://integrate.api.nvidia.com/v1",
"embedding_model": "nvidia/llama-nemotron-embed-1b-v2",
"input_type": "passage",
"embedding_dimensions": 1024,
"timeout": 20,
"proxy": "",
},
"Ollama Embedding": {
    "id": "ollama_embedding",
    "type": "ollama_embedding",
    "provider": "ollama",
    "provider_type": "embedding",
    "hint": "provider_group.provider.ollama_embedding.hint",
    "enable": True,
```
This PR adds two new Embedding providers to extend AstrBot's knowledge base vectorization capabilities:

- NVIDIA NIM Embedding: supports models such as `NV-Embed-QA` and `llama-nemotron-embed-1b-v2`. Users can leverage NVIDIA's free developer tier for faster vectorization, or connect to privately deployed NIM microservices.

Closes #7829 ([Feature] Add support for Nvidia Embedding in the Embedding provider).
Modifications

Core files modified:

- `astrbot/core/provider/sources/nvidia_embedding_source.py` (new, 140 lines): `NvidiaEmbeddingProvider` class implementing the `EmbeddingProvider` interface; config options `api_key`, `base_url`, `model`, `input_type`, `timeout`, `proxy`; implements `get_embedding()` / `get_embeddings()`.
- `astrbot/core/provider/sources/ollama_embedding_source.py` (new, 120 lines): `OllamaEmbeddingProvider` class implementing the `EmbeddingProvider` interface; config options `base_url`, `model`, `embedding_dimensions`, `timeout`, `proxy`; targets the `/api/embed` endpoint.
- `astrbot/core/provider/manager.py` (+8 lines): registers the new provider types (match/case block).
- `astrbot/core/config/default.py` (+28 lines): adds default provider entries.

This is NOT a breaking change.
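For reference, the request/response shape of Ollama's `/api/embed` endpoint that `OllamaEmbeddingProvider` targets looks roughly like the sketch below. The model name is only an example, and the helpers are illustrative rather than the PR's actual `_build_payload`/`_parse_response`:

```python
def build_embed_payload(model: str, texts: list[str]) -> dict:
    # /api/embed accepts a list of inputs and returns one vector per input
    return {"model": model, "input": texts}


def parse_embed_response(data: dict) -> list[list[float]]:
    # Ollama responds with {"model": ..., "embeddings": [[...], [...]]}
    return data.get("embeddings", [])
```

Because `input` is already a list, a single string never needs a special case, which is why the redundant isinstance check could be dropped.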
Screenshots or Test Results

Checklist

- 😊 If there are new features added in the PR, I have discussed it with the authors through issues/emails, etc.
- 👀 My changes have been well-tested, and "Verification Steps" and "Screenshots" have been provided above.
- 🤓 I have ensured that no new dependencies are introduced, OR if new dependencies are introduced, they have been added to the appropriate locations in `requirements.txt` and `pyproject.toml`.
- 😮 My changes do not introduce malicious code.
Summary by Sourcery
Add new NVIDIA NIM and Ollama embedding providers and wire them into the provider manager and default configuration.
New Features:
Enhancements: