-
Notifications
You must be signed in to change notification settings - Fork 53
feature: image compression #30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,15 +11,14 @@ | |
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== | ||
| import base64 | ||
| from io import BytesIO | ||
|
|
||
| from PIL import Image | ||
|
|
||
| from crab.core import action | ||
|
|
||
| from pydantic import Field | ||
| from crab.core.decorators import action | ||
| from crab.utils.common import base64_to_image | ||
|
|
||
| @action | ||
| def save_base64_image(image: str, path: str = "image.png") -> None: | ||
| image = Image.open(BytesIO(base64.b64decode(image))) | ||
| image.save(path) | ||
| def save_image(image: str = Field(..., description="Base64 encoded image string"), path: str = Field(..., description="Path to save the image")): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think the pydantic `Field` cannot be used as an annotation. What did you want to achieve by this modification? |
||
| """Save a base64 encoded image to a file.""" | ||
| img = base64_to_image(image) | ||
| img.save(path) | ||
| return f"Image saved to {path}" | ||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,158 @@ | ||||||
| # =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== | ||||||
| # Licensed under the Apache License, Version 2.0 (the “License”); | ||||||
| # you may not use this file except in compliance with the License. | ||||||
| # You may obtain a copy of the License at | ||||||
| # | ||||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||||
| # | ||||||
| # Unless required by applicable law or agreed to in writing, software | ||||||
| # distributed under the License is distributed on an “AS IS” BASIS, | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
| # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | ||||||
| # =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== | ||||||
|
|
||||||
| import json | ||||||
| from typing import Any | ||||||
| from time import sleep | ||||||
|
|
||||||
| from crab import Action, ActionOutput, BackendModel, BackendOutput, MessageType | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Remove unused import. The import `BackendOutput` is unused. Apply this diff to remove the unused import:
-from crab import Action, ActionOutput, BackendModel, BackendOutput, MessageType
+from crab import Action, ActionOutput, BackendModel, MessageType
Committable suggestion
Suggested change
ToolsRuff
|
||||||
|
|
||||||
| try: | ||||||
| from zhipuai import ZhipuAI | ||||||
| glm_model_enable = True | ||||||
| except ImportError: | ||||||
| glm_model_enable = False | ||||||
|
|
||||||
| class GLMModel(BackendModel): | ||||||
| def __init__( | ||||||
| self, | ||||||
| model: str, | ||||||
| parameters: dict[str, Any] = dict(), | ||||||
| history_messages_len: int = 0, | ||||||
| ) -> None: | ||||||
| if not glm_model_enable: | ||||||
| raise ImportError("Please install zhipuai to use GLMModel") | ||||||
| super().__init__( | ||||||
| model, | ||||||
| parameters, | ||||||
| history_messages_len, | ||||||
| ) | ||||||
| self.client = ZhipuAI() | ||||||
|
|
||||||
| def reset(self, system_message: str, action_space: list[Action] | None) -> None: | ||||||
| self.system_message = system_message | ||||||
| self.glm_system_message = { | ||||||
| "role": "system", | ||||||
| "content": system_message, | ||||||
| } | ||||||
| self.action_space = action_space | ||||||
| self.action_schema = self._convert_action_to_schema(self.action_space) | ||||||
| self.token_usage = 0 | ||||||
| self.chat_history = [] | ||||||
|
|
||||||
| def chat(self, message: tuple[str, MessageType]): | ||||||
| request_messages = self._convert_to_request_messages(message) | ||||||
| response = self.call_api(request_messages) | ||||||
|
|
||||||
| assistant_message = response.choices[0].message | ||||||
| action_list = self._convert_tool_calls_to_action_list(assistant_message) | ||||||
|
|
||||||
| output = ChatOutput( | ||||||
| message=assistant_message.content if not action_list else None, | ||||||
| action_list=action_list, | ||||||
| ) | ||||||
|
|
||||||
| self.record_message(request_messages[-1], assistant_message) | ||||||
| return output | ||||||
|
Comment on lines
+54
to
+67
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Chat interaction method. This method handles the chat interaction by converting messages, calling the API, and processing the response. The implementation seems robust, but ensure that the `ChatOutput` class is defined or imported; it is used here but does not appear in the file's imports. Define or import `ChatOutput`:
from crab import ChatOutput  # Assuming it's part of the crab module
Tools
Ruff
|
||||||
|
|
||||||
| def get_token_usage(self): | ||||||
| return self.token_usage | ||||||
|
|
||||||
| def record_message(self, new_message: dict, response_message: dict) -> None: | ||||||
| self.chat_history.append([new_message]) | ||||||
| self.chat_history[-1].append(response_message) | ||||||
|
|
||||||
| if self.action_schema: | ||||||
| tool_calls = response_message.tool_calls | ||||||
| for tool_call in tool_calls: | ||||||
| self.chat_history[-1].append( | ||||||
| { | ||||||
| "tool_call_id": tool_call.id, | ||||||
| "role": "tool", | ||||||
| "name": tool_call.function.name, | ||||||
| "content": "", | ||||||
| } | ||||||
| ) | ||||||
|
|
||||||
| def call_api(self, request_messages: list): | ||||||
| while True: | ||||||
| try: | ||||||
| response = self.client.chat.completions.create( | ||||||
| model=self.model, | ||||||
| messages=request_messages, | ||||||
| **self.parameters, | ||||||
| ) | ||||||
| except Exception as e: | ||||||
| print(f"API call failed: {str(e)}. Retrying in 10 seconds...") | ||||||
| sleep(10) | ||||||
| else: | ||||||
| break | ||||||
|
|
||||||
| self.token_usage += response.usage.total_tokens | ||||||
| return response | ||||||
|
Comment on lines
+88
to
+103
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. API call method with retry logic. The method includes retry logic for handling API call failures, which is crucial for maintaining robustness in network interactions. However, consider adding a limit to the number of retries to avoid potential infinite loops. Add a retry limit to the API call method:
max_retries = 5
attempts = 0
while attempts < max_retries:
    try:
        ...
    except Exception as e:
        if attempts < max_retries - 1:
            print(f"API call failed: {str(e)}. Retrying in 10 seconds...")
            sleep(10)
            attempts += 1
        else:
            raise Exception("API call failed after maximum retries.")
    else:
        break |
||||||
|
|
||||||
| @staticmethod | ||||||
| def _convert_action_to_schema(action_space: list[Action] | None): | ||||||
| if action_space is None: | ||||||
| return None | ||||||
|
|
||||||
| tools = [] | ||||||
| for action in action_space: | ||||||
| tool = { | ||||||
| "type": "function", | ||||||
| "function": { | ||||||
| "name": action.name, | ||||||
| "description": action.description, | ||||||
| "parameters": { | ||||||
| "type": "object", | ||||||
| "properties": {}, | ||||||
| "required": [], | ||||||
| }, | ||||||
| }, | ||||||
| } | ||||||
| for param in action.parameters: | ||||||
| tool["function"]["parameters"]["properties"][param.name] = { | ||||||
| "type": param.type, | ||||||
| "description": param.description, | ||||||
| } | ||||||
| if param.required: | ||||||
| tool["function"]["parameters"]["required"].append(param.name) | ||||||
| tools.append(tool) | ||||||
| return tools | ||||||
|
|
||||||
| @staticmethod | ||||||
| def _convert_tool_calls_to_action_list(self, message): | ||||||
| if not message.content or not message.content.startswith("arguments="): | ||||||
| return None | ||||||
|
|
||||||
| action_list = [] | ||||||
| parts = message.content.split(", name=") | ||||||
| arguments = json.loads(parts[0].replace("arguments=", "").strip("'")) | ||||||
| name = parts[1].strip("'") | ||||||
| action_output = ActionOutput( | ||||||
| name=name, | ||||||
| args=arguments, | ||||||
| ) | ||||||
| action_list.append(action_output) | ||||||
| return action_list | ||||||
|
|
||||||
| @staticmethod | ||||||
| def _convert_message(message: tuple[str, MessageType]): | ||||||
| content, message_type = message | ||||||
| if message_type == MessageType.TEXT: | ||||||
| return {"type": "text", "text": content} | ||||||
| elif message_type == MessageType.IMAGE_URL: | ||||||
| return {"type": "image_url", "image_url": {"url": content}} | ||||||
| else: | ||||||
| raise ValueError(f"Unsupported message type: {message_type}") | ||||||
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,104 @@ | ||||||||
| import pytest | ||||||||
| from PIL import Image | ||||||||
| import io | ||||||||
| import base64 | ||||||||
| import os | ||||||||
|
Comment on lines
+3
to
+5
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Remove unused imports. The imports `io`, `base64`, and `os` are unused. Apply this diff to remove the unused imports:
-import io
-import base64
-import os
Committable suggestion
Suggested change
ToolsRuff
|
||||||||
| from unittest.mock import patch, MagicMock | ||||||||
| from crab.utils.common import base64_to_image, image_to_base64 | ||||||||
| from crab.actions.file_actions import save_image | ||||||||
|
|
||||||||
| import sys | ||||||||
| # Mock the entire crab.agents.backend_models module | ||||||||
| sys.modules['crab.agents.backend_models'] = MagicMock() | ||||||||
| sys.modules['crab.agents.backend_models.openai_model'] = MagicMock() | ||||||||
| sys.modules['crab.actions.desktop_actions'] = MagicMock() | ||||||||
|
|
||||||||
| # Create mock classes/functions | ||||||||
| class MockOpenAIModel: | ||||||||
| def _convert_message(self, message): | ||||||||
| return {"type": "image_url", "image_url": {"url": ""}} | ||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Address line length issues. Several lines exceed the recommended line length of 88 characters. Refactoring these lines will improve code readability and maintainability. Apply these diffs to address the line length issues:
- return {"type": "image_url", "image_url": {"url": ""}}
+ return {
+     "type": "image_url",
+     "image_url": {"url": ""}
+ }
- converted_message = openai_model._convert_message((received_image, MockMessageType.IMAGE_JPG_BASE64))
+ converted_message = openai_model._convert_message(
+     (received_image, MockMessageType.IMAGE_JPG_BASE64)
+ )
- assert converted_message["image_url"]["url"].startswith("data:image/png;base64,")
+ assert converted_message["image_url"]["url"].startswith(
+     "data:image/png;base64,"
+ )
Also applies to: 77-77, 79-79
Tools
Ruff
|
||||||||
|
|
||||||||
| class MockMessageType: | ||||||||
| IMAGE_JPG_BASE64 = "image_jpg_base64" | ||||||||
|
|
||||||||
| def mock_screenshot(): | ||||||||
| return Image.new('RGB', (100, 100), color='red') | ||||||||
|
|
||||||||
| # Apply mocks | ||||||||
| patch('crab.agents.backend_models.openai_model.OpenAIModel', MockOpenAIModel).start() | ||||||||
| patch('crab.agents.backend_models.openai_model.MessageType', MockMessageType).start() | ||||||||
| patch('crab.actions.desktop_actions.screenshot', mock_screenshot).start() | ||||||||
|
|
||||||||
| class TestImageHandling: | ||||||||
| @pytest.fixture(autouse=True) | ||||||||
| def setup(self): | ||||||||
| self.test_image = Image.new('RGB', (100, 100), color='red') | ||||||||
|
|
||||||||
| def test_image_processing_path(self): | ||||||||
| print("\n--- Image Processing Path Test ---") | ||||||||
|
|
||||||||
| # Use self.test_image instead of taking a screenshot | ||||||||
| screenshot_image = self.test_image | ||||||||
|
|
||||||||
| # 1. Start with a PIL Image (using self.test_image) | ||||||||
| print("1. Starting with a PIL Image") | ||||||||
| assert isinstance(self.test_image, Image.Image) | ||||||||
|
|
||||||||
| # 2. Simulate saving the image | ||||||||
| print("2. Saving the image") | ||||||||
| save_image(self.test_image, "test_image.png") | ||||||||
| print(" Image saved successfully") | ||||||||
|
|
||||||||
| # 3. Using self.test_image instead of taking a screenshot | ||||||||
| print("3. Using self.test_image instead of taking a screenshot") | ||||||||
| screenshot_image = self.test_image | ||||||||
| assert isinstance(screenshot_image, Image.Image) | ||||||||
| print(" Using self.test_image as PIL Image") | ||||||||
|
|
||||||||
| # 4. Prepare for network transfer (serialize to base64) | ||||||||
| print("4. Serializing image for network transfer") | ||||||||
| base64_string = image_to_base64(self.test_image) | ||||||||
| assert isinstance(base64_string, str) | ||||||||
| print(" Image serialized to base64 string") | ||||||||
|
|
||||||||
| # 5. Simulate network transfer | ||||||||
| print("5. Simulating network transfer") | ||||||||
| received_base64 = base64_string # In reality, this would be sent and received | ||||||||
|
|
||||||||
| # 6. Deserialize after network transfer | ||||||||
| print("6. Deserializing image after network transfer") | ||||||||
| received_image = base64_to_image(received_base64) | ||||||||
| assert isinstance(received_image, Image.Image) | ||||||||
| print(" Image deserialized back to PIL Image") | ||||||||
|
|
||||||||
| # 7. Use the image in a backend model (e.g., OpenAI) | ||||||||
| print("7. Using image in backend model") | ||||||||
| openai_model = MockOpenAIModel() | ||||||||
| converted_message = openai_model._convert_message((received_image, MockMessageType.IMAGE_JPG_BASE64)) | ||||||||
| assert converted_message["type"] == "image_url" | ||||||||
| assert converted_message["image_url"]["url"].startswith("data:image/png;base64,") | ||||||||
| print(" Image successfully converted for use in OpenAI model") | ||||||||
|
|
||||||||
| print("--- Image Processing Path Test Completed Successfully ---") | ||||||||
|
|
||||||||
| def test_base64_to_image(self): | ||||||||
| # Convert image to base64 | ||||||||
| base64_string = image_to_base64(self.test_image) | ||||||||
|
|
||||||||
| # Test base64_to_image function | ||||||||
| converted_image = base64_to_image(base64_string) | ||||||||
| assert isinstance(converted_image, Image.Image) | ||||||||
| assert converted_image.size == (100, 100) | ||||||||
|
|
||||||||
| def test_image_to_base64(self): | ||||||||
| # Test image_to_base64 function | ||||||||
| base64_string = image_to_base64(self.test_image) | ||||||||
| assert isinstance(base64_string, str) | ||||||||
|
|
||||||||
| # Verify that the base64 string can be converted back to an image | ||||||||
| converted_image = base64_to_image(base64_string) | ||||||||
| assert converted_image.size == (100, 100) | ||||||||
|
|
||||||||
| # Make sure to stop all patches after the tests | ||||||||
|
Comment on lines
+32
to
+102
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Comprehensive review of the `TestImageHandling` class. The class covers the image processing path and the base64 conversion helpers. However, consider the following improvements:
Consider replacing print statements with logging and enhancing assertions for more robust testing. Example of replacing print with logging:
import logging
logging.basicConfig(level=logging.INFO)
# Replace print statements
logging.info("1. Starting with a PIL Image")
Example of enhanced assertions:
# After converting image to base64 and back
original_data = self.test_image.tobytes()
restored_data = converted_image.tobytes()
assert original_data == restored_data, "Image data must remain consistent after conversions."
Tools
Ruff
|
||||||||
| def teardown_module(module): | ||||||||
| patch.stopall() | ||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider breaking the line to adhere to style guidelines.
The line defining the function signature is too long (145 characters), which exceeds the recommended limit of 88 characters. Consider breaking it into multiple lines for better readability.
Apply this diff to break the line:
Committable suggestion
Tools
Ruff