diff --git a/pyrefly/lib/commands/infer.rs b/pyrefly/lib/commands/infer.rs index 6ae8dd84e..c558d55f1 100644 --- a/pyrefly/lib/commands/infer.rs +++ b/pyrefly/lib/commands/infer.rs @@ -26,6 +26,7 @@ use crate::commands::files::get_project_config_for_current_dir; use crate::commands::util::CommandExitStatus; use crate::config::error_kind::ErrorKind; use crate::lsp::wasm::inlay_hints::ParameterAnnotation; +use crate::state::ide::ImportEdit; use crate::state::ide::insert_import_edit_with_forced_import_format; use crate::state::lsp::AnnotationKind; use crate::state::require::Require; @@ -292,7 +293,7 @@ impl InferArgs { if let Some(ast) = transaction.get_ast(&handle) { let error_range = error.range(); let unknown_name = module_info.code_at(error_range); - let imports: Vec<(TextSize, String, String)> = transaction + let imports: Vec = transaction .search_exports_exact(unknown_name) .into_iter() .map(|handle_to_import_from| { @@ -302,6 +303,7 @@ impl InferArgs { handle_to_import_from.dupe(), unknown_name, true, + /*merge_with_existing=*/ false, ) }) .collect(); @@ -337,16 +339,16 @@ impl InferArgs { fs_anyhow::write(file_path, result) } - fn add_imports_to_file( - file_path: &Path, - imports: Vec<(TextSize, String, String)>, - ) -> anyhow::Result<()> { + fn add_imports_to_file(file_path: &Path, imports: Vec) -> anyhow::Result<()> { let file_content = fs_anyhow::read_to_string(file_path)?; let mut result = file_content; - for (position, import, _) in imports { - let offset = (position).into(); - if !result.contains(&import) { - result.insert_str(offset, &import); + for import_edit in imports { + if import_edit.insert_text.is_empty() { + continue; + } + let offset = (import_edit.position).into(); + if offset <= result.len() && !result.contains(&import_edit.insert_text) { + result.insert_str(offset, &import_edit.insert_text); } } fs_anyhow::write(file_path, result) diff --git a/pyrefly/lib/state/ide.rs b/pyrefly/lib/state/ide.rs index a1a56f2a4..49446285e 100644 --- a/pyrefly/lib/state/ide.rs +++ b/pyrefly/lib/state/ide.rs @@ -15,6 +15,8 @@ use pyrefly_python::symbol_kind::SymbolKind; use pyrefly_util::gas::Gas; use ruff_python_ast::Expr; use ruff_python_ast::ModModule; +use ruff_python_ast::Stmt; +use ruff_python_ast::StmtImportFrom; use ruff_python_ast::helpers::is_docstring_stmt; use ruff_python_ast::name::Name; use ruff_text_size::Ranged; @@ -33,6 +35,14 @@ use crate::state::lsp::ImportFormat; const KEY_TO_DEFINITION_INITIAL_GAS: Gas = Gas::new(100); +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ImportEdit { + pub position: TextSize, + pub insert_text: String, + pub display_text: String, + pub module_name: String, +} + pub enum IntermediateDefinition { Local(Export), NamedImport(TextRange, ModuleName, Name, Option), @@ -203,7 +213,7 @@ pub fn insert_import_edit( handle_to_import_from: Handle, export_name: &str, import_format: ImportFormat, -) -> (TextSize, String, String) { +) -> ImportEdit { let use_absolute_import = match import_format { ImportFormat::Absolute => true, ImportFormat::Relative => { @@ -216,6 +226,7 @@ pub fn insert_import_edit( handle_to_import_from, export_name, use_absolute_import, + true, ) } @@ -240,12 +251,8 @@ pub fn insert_import_edit_with_forced_import_format( handle_to_import_from: Handle, export_name: &str, use_absolute_import: bool, -) -> (TextSize, String, String) { - let position = if let Some(first_stmt) = ast.body.iter().find(|stmt| !is_docstring_stmt(stmt)) { - first_stmt.range().start() - } else { - ast.range.end() - }; + merge_with_existing: bool, +) -> ImportEdit { let module_name_to_import = if use_absolute_import { handle_to_import_from.module() } else if let Some(relative_module) = ModuleName::relative_module_name_between( @@ -256,12 +263,38 @@ pub fn insert_import_edit_with_forced_import_format( } else { handle_to_import_from.module() }; + let display_text = format!( + "from {} import {}", + module_name_to_import.as_str(), + export_name + ); + if merge_with_existing + && let Some(edit) = try_extend_existing_from_import( + ast, + module_name_to_import.as_str(), + export_name, + display_text.clone(), + module_name_to_import.as_str(), + ) + { + return edit; + } + let position = if let Some(first_stmt) = ast.body.iter().find(|stmt| !is_docstring_stmt(stmt)) { + first_stmt.range().start() + } else { + ast.range.end() + }; let insert_text = format!( "from {} import {}\n", module_name_to_import.as_str(), export_name ); - (position, insert_text, module_name_to_import.to_string()) + ImportEdit { + position, + insert_text, + display_text, + module_name: module_name_to_import.to_string(), + } } /// Some handles must be imported in absolute style, @@ -285,3 +318,48 @@ fn handle_require_absolute_import(config_finder: &ConfigFinder, handle: &Handle) .site_package_path() .any(|search_path| handle.path().as_path().starts_with(search_path)) } + +fn try_extend_existing_from_import( + ast: &ModModule, + target_module_name: &str, + export_name: &str, + display_text: String, + module_name: &str, +) -> Option { + for stmt in &ast.body { + if let Stmt::ImportFrom(import_from) = stmt + && import_from_module_name(import_from) == target_module_name + { + if import_from + .names + .iter() + .any(|alias| alias.asname.is_none() && alias.name.as_str() == export_name) + { + // Already imported; don't propose a duplicate edit. + return None; + } + if let Some(last_alias) = import_from.names.last() { + let position = last_alias.range.end(); + let insert_text = format!(", {}", export_name); + return Some(ImportEdit { + position, + insert_text, + display_text, + module_name: module_name.to_owned(), + }); + } + } + } + None +} + +fn import_from_module_name(import_from: &StmtImportFrom) -> String { + let mut module_name = String::new(); + if import_from.level > 0 { + module_name.push_str(&".".repeat(import_from.level as usize)); + } + if let Some(module) = &import_from.module { + module_name.push_str(module.as_str()); + } + module_name +} diff --git a/pyrefly/lib/state/lsp.rs b/pyrefly/lib/state/lsp.rs index 39dc62c4d..e76ba5729 100644 --- a/pyrefly/lib/state/lsp.rs +++ b/pyrefly/lib/state/lsp.rs @@ -1642,7 +1642,7 @@ impl<'a> Transaction<'a> { if error_range.contains_range(range) { let unknown_name = module_info.code_at(error_range); for handle_to_import_from in self.search_exports_exact(unknown_name) { - let (position, insert_text, _) = insert_import_edit( + let import_edit = insert_import_edit( &ast, self.config_finder(), handle.dupe(), @@ -1650,9 +1650,18 @@ impl<'a> Transaction<'a> { unknown_name, import_format, ); - let range = TextRange::at(position, TextSize::new(0)); - let title = format!("Insert import: `{}`", insert_text.trim()); - code_actions.push((title, module_info.dupe(), range, insert_text)); + // If the symbol was already imported we get an empty edit; skip it. + if import_edit.insert_text.is_empty() { + continue; + } + let range = TextRange::at(import_edit.position, TextSize::new(0)); + let title = format!("Insert import: `{}`", import_edit.display_text); + code_actions.push(( + title, + module_info.dupe(), + range, + import_edit.insert_text, + )); } for module_name in self.search_modules_fuzzy(unknown_name) { @@ -2176,9 +2185,11 @@ impl<'a> Transaction<'a> { && let Some(ast) = self.get_ast(handle) && let Some(module_info) = self.get_module_info(handle) { - for (handle_to_import_from, name, export) in - self.search_exports_fuzzy(identifier.as_str()) - { + let search_results = self.search_exports_fuzzy(identifier.as_str()); + for (handle_to_import_from, name, export) in search_results { + if !identifier.as_str().starts_with('_') && name.starts_with('_') { + continue; + } // Using handle itself doesn't always work because handles can be made separately and have different hashes if handle_to_import_from.module() == handle.module() || handle_to_import_from.module() == ModuleName::builtins() @@ -2186,8 +2197,8 @@ impl<'a> Transaction<'a> { continue; } let module_description = handle_to_import_from.module().as_str().to_owned(); - let (insert_text, additional_text_edits, imported_module) = { - let (position, insert_text, module_name) = insert_import_edit( + let (detail_text, additional_text_edits, imported_module) = { + let import_edit = insert_import_edit( &ast, self.config_finder(), handle.dupe(), @@ -2195,11 +2206,19 @@ impl<'a> Transaction<'a> { &name, import_format, ); + if import_edit.insert_text.is_empty() { + continue; + } let import_text_edit = TextEdit { - range: module_info.to_lsp_range(TextRange::at(position, TextSize::new(0))), - new_text: insert_text.clone(), + range: module_info + .to_lsp_range(TextRange::at(import_edit.position, TextSize::new(0))), + new_text: import_edit.insert_text.clone(), }; - (insert_text, Some(vec![import_text_edit]), module_name) + ( + Some(import_edit.insert_text.clone()), + Some(vec![import_text_edit]), + import_edit.module_name, + ) }; let auto_import_label_detail = format!(" (import {imported_module})"); let (label, label_details) = if supports_completion_item_details { @@ -2215,7 +2234,7 @@ impl<'a> Transaction<'a> { }; completions.push(CompletionItem { label, - detail: Some(insert_text), + detail: detail_text, kind: export .symbol_kind .map_or(Some(CompletionItemKind::VARIABLE), |k| { diff --git a/pyrefly/lib/test/lsp/code_actions.rs b/pyrefly/lib/test/lsp/code_actions.rs index 203406f46..1d4a36afb 100644 --- a/pyrefly/lib/test/lsp/code_actions.rs +++ b/pyrefly/lib/test/lsp/code_actions.rs @@ -210,8 +210,7 @@ fn insertion_test_duplicate_imports() { ], get_test_report, ); - // The insertion won't attempt to merge imports from the same module. - // It's not illegal, but it would be nice if we do merge. + // When another import from the same module already exists, we should append to it. assert_eq!( r#" # a.py @@ -227,8 +226,7 @@ from a import another_thing my_export # ^ ## After: -from a import my_export -from a import another_thing +from a import another_thing, my_export my_export # ^ "# diff --git a/pyrefly/lib/test/lsp/completion.rs b/pyrefly/lib/test/lsp/completion.rs index 27ac3fbab..a5a528ef5 100644 --- a/pyrefly/lib/test/lsp/completion.rs +++ b/pyrefly/lib/test/lsp/completion.rs @@ -107,7 +107,7 @@ fn get_test_report( report.push_str("[DEPRECATED] "); } report.push_str(&label); - if let Some(detail) = detail { + if let Some(detail) = &detail { report.push_str(": "); report.push_str(&detail); } @@ -120,7 +120,7 @@ fn get_test_report( report.push_str(" with text edit: "); report.push_str(&format!("{:?}", &text_edit)); } - if let Some(documentation) = documentation { + if let Some(ref documentation) = documentation { report.push('\n'); match documentation { lsp_types::Documentation::String(s) => {