Skip to content

Commit df5401f

Browse files
fix
1 parent 8c0a1be commit df5401f

File tree

3 files changed

+87
-30
lines changed

3 files changed

+87
-30
lines changed

pyrefly/lib/commands/infer.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ impl InferArgs {
290290
if let Some(ast) = transaction.get_ast(&handle) {
291291
let error_range = error.range();
292292
let unknown_name = module_info.code_at(error_range);
293-
let imports: Vec<(TextSize, String, String)> = transaction
293+
let imports: Vec<ImportEdit> = transaction
294294
.search_exports_exact(unknown_name)
295295
.into_iter()
296296
.map(|handle_to_import_from| {
@@ -335,16 +335,16 @@ impl InferArgs {
335335
fs_anyhow::write(file_path, result)
336336
}
337337

338-
fn add_imports_to_file(
339-
file_path: &Path,
340-
imports: Vec<(TextSize, String, String)>,
341-
) -> anyhow::Result<()> {
338+
fn add_imports_to_file(file_path: &Path, imports: Vec<ImportEdit>) -> anyhow::Result<()> {
342339
let file_content = fs_anyhow::read_to_string(file_path)?;
343340
let mut result = file_content;
344-
for (position, import, _) in imports {
345-
let offset = (position).into();
346-
if !result.contains(&import) {
347-
result.insert_str(offset, &import);
341+
for import_edit in imports {
342+
if import_edit.insert_text.is_empty() {
343+
continue;
344+
}
345+
let offset = (import_edit.position).into();
346+
if offset <= result.len() && !result.contains(&import_edit.insert_text) {
347+
result.insert_str(offset, &import_edit.insert_text);
348348
}
349349
}
350350
fs_anyhow::write(file_path, result)

pyrefly/lib/state/ide.rs

Lines changed: 62 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ pub struct ImportEdit {
3939
pub position: TextSize,
4040
pub insert_text: String,
4141
pub display_text: String,
42+
pub module_name: String,
4243
}
4344

4445
pub enum IntermediateDefinition {
@@ -198,7 +199,7 @@ pub fn insert_import_edit(
198199
handle_to_import_from: Handle,
199200
export_name: &str,
200201
import_format: ImportFormat,
201-
) -> (TextSize, String, String) {
202+
) -> ImportEdit {
202203
let use_absolute_import = match import_format {
203204
ImportFormat::Absolute => true,
204205
ImportFormat::Relative => {
@@ -235,12 +236,7 @@ pub fn insert_import_edit_with_forced_import_format(
235236
handle_to_import_from: Handle,
236237
export_name: &str,
237238
use_absolute_import: bool,
238-
) -> (TextSize, String, String) {
239-
let position = if let Some(first_stmt) = ast.body.iter().find(|stmt| !is_docstring_stmt(stmt)) {
240-
first_stmt.range().start()
241-
} else {
242-
ast.range.end()
243-
};
239+
) -> ImportEdit {
244240
let module_name_to_import = if use_absolute_import {
245241
handle_to_import_from.module()
246242
} else if let Some(relative_module) = ModuleName::relative_module_name_between(
@@ -256,14 +252,14 @@ pub fn insert_import_edit_with_forced_import_format(
256252
module_name_to_import.as_str(),
257253
export_name
258254
);
259-
if let Some((position, insert_text)) =
260-
try_extend_existing_from_import(ast, module_name_to_import.as_str(), export_name)
261-
{
262-
return ImportEdit {
263-
position,
264-
insert_text,
265-
display_text,
266-
};
255+
if let Some(edit) = try_extend_existing_from_import(
256+
ast,
257+
module_name_to_import.as_str(),
258+
export_name,
259+
display_text.clone(),
260+
module_name_to_import.as_str(),
261+
) {
262+
return edit;
267263
}
268264
let position = if let Some(first_stmt) = ast.body.iter().find(|stmt| !is_docstring_stmt(stmt)) {
269265
first_stmt.range().start()
@@ -275,7 +271,12 @@ pub fn insert_import_edit_with_forced_import_format(
275271
module_name_to_import.as_str(),
276272
export_name
277273
);
278-
(position, insert_text, module_name_to_import.to_string())
274+
ImportEdit {
275+
position,
276+
insert_text,
277+
display_text,
278+
module_name: module_name_to_import.to_string(),
279+
}
279280
}
280281

281282
/// Some handles must be imported in absolute style,
@@ -299,3 +300,48 @@ fn handle_require_absolute_import(config_finder: &ConfigFinder, handle: &Handle)
299300
.site_package_path()
300301
.any(|search_path| handle.path().as_path().starts_with(search_path))
301302
}
303+
304+
fn try_extend_existing_from_import(
305+
ast: &ModModule,
306+
target_module_name: &str,
307+
export_name: &str,
308+
display_text: String,
309+
module_name: &str,
310+
) -> Option<ImportEdit> {
311+
for stmt in &ast.body {
312+
if let Stmt::ImportFrom(import_from) = stmt {
313+
if import_from_module_name(import_from) == target_module_name {
314+
if import_from
315+
.names
316+
.iter()
317+
.any(|alias| alias.asname.is_none() && alias.name.as_str() == export_name)
318+
{
319+
// Already imported; don't propose a duplicate edit.
320+
return None;
321+
}
322+
if let Some(last_alias) = import_from.names.last() {
323+
let position = last_alias.range.end();
324+
let insert_text = format!(", {}", export_name);
325+
return Some(ImportEdit {
326+
position,
327+
insert_text,
328+
display_text,
329+
module_name: module_name.to_owned(),
330+
});
331+
}
332+
}
333+
}
334+
}
335+
None
336+
}
337+
338+
fn import_from_module_name(import_from: &StmtImportFrom) -> String {
339+
let mut module_name = String::new();
340+
if import_from.level > 0 {
341+
module_name.push_str(&".".repeat(import_from.level as usize));
342+
}
343+
if let Some(module) = &import_from.module {
344+
module_name.push_str(module.as_str());
345+
}
346+
module_name
347+
}

pyrefly/lib/state/lsp.rs

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1907,14 +1907,18 @@ impl<'a> Transaction<'a> {
19071907
if error_range.contains_range(range) {
19081908
let unknown_name = module_info.code_at(error_range);
19091909
for handle_to_import_from in self.search_exports_exact(unknown_name) {
1910-
let (position, insert_text, _) = insert_import_edit(
1910+
let import_edit = insert_import_edit(
19111911
&ast,
19121912
self.config_finder(),
19131913
handle.dupe(),
19141914
handle_to_import_from,
19151915
unknown_name,
19161916
import_format,
19171917
);
1918+
// If the symbol was already imported we get an empty edit; skip it.
1919+
if import_edit.insert_text.is_empty() {
1920+
continue;
1921+
}
19181922
let range = TextRange::at(import_edit.position, TextSize::new(0));
19191923
let title = format!("Insert import: `{}`", import_edit.display_text);
19201924
code_actions.push((
@@ -2440,21 +2444,28 @@ impl<'a> Transaction<'a> {
24402444
continue;
24412445
}
24422446
let module_description = handle_to_import_from.module().as_str().to_owned();
2443-
let (insert_text, additional_text_edits, imported_module) = {
2444-
let (position, insert_text, module_name) = insert_import_edit(
2447+
let (detail_text, additional_text_edits, imported_module) = {
2448+
let import_edit = insert_import_edit(
24452449
&ast,
24462450
self.config_finder(),
24472451
handle.dupe(),
24482452
handle_to_import_from,
24492453
&name,
24502454
import_format,
24512455
);
2456+
if import_edit.insert_text.is_empty() {
2457+
continue;
2458+
}
24522459
let import_text_edit = TextEdit {
24532460
range: module_info
24542461
.to_lsp_range(TextRange::at(import_edit.position, TextSize::new(0))),
24552462
new_text: import_edit.insert_text.clone(),
24562463
};
2457-
(insert_text, Some(vec![import_text_edit]), module_name)
2464+
(
2465+
Some(import_edit.display_text),
2466+
Some(vec![import_text_edit]),
2467+
import_edit.module_name,
2468+
)
24582469
};
24592470
let auto_import_label_detail = format!(" (import {imported_module})");
24602471
let (label, label_details) = if supports_completion_item_details {
@@ -2470,7 +2481,7 @@ impl<'a> Transaction<'a> {
24702481
};
24712482
completions.push(CompletionItem {
24722483
label,
2473-
detail: Some(insert_text),
2484+
detail: detail_text,
24742485
kind: export
24752486
.symbol_kind
24762487
.map_or(Some(CompletionItemKind::VARIABLE), |k| {

0 commit comments

Comments
 (0)