Skip to content

Commit 48556c0

Browse files
authored
fix(bedrock): handle ARN and inference profile identifiers without transformation (#48)
Signed-off-by: Roei Savion <[email protected]>
1 parent 28f98d6 commit 48556c0

File tree

2 files changed

+154
-20
lines changed

2 files changed

+154
-20
lines changed

src/providers/bedrock/provider.rs

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,31 @@ impl BedrockProvider {
9292

9393
provider_implementation
9494
}
95+
96+
fn transform_model_identifier(&self, model: String, model_config: &ModelConfig) -> String {
97+
// Check if the model is already an ARN or inference profile ID
98+
if model.starts_with("arn:aws:bedrock:") || model.contains("inference-profile") {
99+
// Use the model identifier as-is for ARNs and inference profiles
100+
model
101+
} else {
102+
// Transform model name to include provider prefix for regular model IDs
103+
let model_provider = model_config.params.get("model_provider").unwrap();
104+
let inference_profile_id = self.config.params.get("inference_profile_id");
105+
let model_version = model_config
106+
.params
107+
.get("model_version")
108+
.map_or("v1:0", |s| &**s);
109+
110+
if let Some(profile_id) = inference_profile_id {
111+
format!(
112+
"{}.{}.{}-{}",
113+
profile_id, model_provider, model, model_version
114+
)
115+
} else {
116+
format!("{}.{}-{}", model_provider, model, model_version)
117+
}
118+
}
119+
}
95120
}
96121

97122
#[async_trait]
@@ -120,25 +145,10 @@ impl Provider for BedrockProvider {
120145
StatusCode::INTERNAL_SERVER_ERROR
121146
})?;
122147

123-
// Transform model name to include provider prefix
124-
let model_provider = model_config.params.get("model_provider").unwrap();
125-
let inference_profile_id = self.config.params.get("inference_profile_id");
126148
let mut transformed_payload = payload;
127-
let model_version = model_config
128-
.params
129-
.get("model_version")
130-
.map_or("v1:0", |s| &**s);
131-
transformed_payload.model = if let Some(profile_id) = inference_profile_id {
132-
format!(
133-
"{}.{}.{}-{}",
134-
profile_id, model_provider, transformed_payload.model, model_version
135-
)
136-
} else {
137-
format!(
138-
"{}.{}-{}",
139-
model_provider, transformed_payload.model, model_version
140-
)
141-
};
149+
150+
transformed_payload.model =
151+
self.transform_model_identifier(transformed_payload.model, model_config);
142152

143153
self.get_provider_implementation(model_config)
144154
.chat_completion(&client, transformed_payload)
@@ -155,8 +165,13 @@ impl Provider for BedrockProvider {
155165
StatusCode::INTERNAL_SERVER_ERROR
156166
})?;
157167

168+
let mut transformed_payload = payload;
169+
170+
transformed_payload.model =
171+
self.transform_model_identifier(transformed_payload.model, model_config);
172+
158173
self.get_provider_implementation(model_config)
159-
.completion(&client, payload)
174+
.completion(&client, transformed_payload)
160175
.await
161176
}
162177

@@ -170,8 +185,13 @@ impl Provider for BedrockProvider {
170185
StatusCode::INTERNAL_SERVER_ERROR
171186
})?;
172187

188+
let mut transformed_payload = payload;
189+
190+
transformed_payload.model =
191+
self.transform_model_identifier(transformed_payload.model, model_config);
192+
173193
self.get_provider_implementation(model_config)
174-
.embedding(&client, payload)
194+
.embedding(&client, transformed_payload)
175195
.await
176196
}
177197
}

src/providers/bedrock/test.rs

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,120 @@ mod ai21_tests {
386386
}
387387
}
388388
}
389+
390+
#[cfg(test)]
mod arn_tests {
    use crate::models::chat::ChatCompletionRequest;
    use crate::models::content::{ChatCompletionMessage, ChatMessageContent};
    use crate::providers::bedrock::test::{get_test_model_config, get_test_provider_config};
    use crate::providers::bedrock::BedrockProvider;
    use crate::providers::provider::Provider;

    /// Build a minimal single-message chat request targeting `model`.
    /// Shared by both tests so the 20-odd optional fields live in one place.
    fn joke_request(model: &str) -> ChatCompletionRequest {
        ChatCompletionRequest {
            model: model.to_string(),
            messages: vec![ChatCompletionMessage {
                role: "user".to_string(),
                content: Some(ChatMessageContent::String(
                    "Tell me a short joke".to_string(),
                )),
                name: None,
                tool_calls: None,
                refusal: None,
            }],
            temperature: None,
            top_p: None,
            n: None,
            stream: None,
            stop: None,
            max_tokens: None,
            max_completion_tokens: None,
            parallel_tool_calls: None,
            presence_penalty: None,
            frequency_penalty: None,
            logit_bias: None,
            tool_choice: None,
            tools: None,
            user: None,
            logprobs: None,
            top_logprobs: None,
            response_format: None,
        }
    }

    #[tokio::test]
    async fn test_arn_model_identifier_not_transformed() {
        let config = get_test_provider_config("us-east-1", "anthropic_chat_completion");
        let provider = BedrockProvider::new(&config);

        // A full Bedrock ARN must be forwarded untouched; the mock backend
        // supplies the response, so success here means no transformation error.
        let arn_model =
            "arn:aws:bedrock:us-east-1:123456789012:inference-profile/us.example.test-model-v1:0";
        let model_config = get_test_model_config(arn_model, "anthropic");

        let outcome = provider
            .chat_completions(joke_request(arn_model), &model_config)
            .await;

        assert!(
            outcome.is_ok(),
            "ARN model identifier should be handled correctly: {:?}",
            outcome.err()
        );
    }

    #[tokio::test]
    async fn test_inference_profile_identifier_not_transformed() {
        let config = get_test_provider_config("us-east-1", "anthropic_chat_completion");
        let provider = BedrockProvider::new(&config);

        // An inference-profile identifier must likewise pass through as-is.
        let profile_model = "us-east-1-inference-profile-123";
        let model_config = get_test_model_config(profile_model, "anthropic");

        let outcome = provider
            .chat_completions(joke_request(profile_model), &model_config)
            .await;

        assert!(
            outcome.is_ok(),
            "Inference profile identifier should be handled correctly: {:?}",
            outcome.err()
        );
    }
}
502+
389503
/**
390504
391505
Helper functions for creating test clients and mock responses

0 commit comments

Comments
 (0)