import copy
import logging
from pathlib import Path
from PyQt5.QtWidgets import QApplication, QMainWindow, QVBoxLayout, QHBoxLayout, QWidget, QComboBox, QLabel, \
QPushButton, QDoubleSpinBox, QLineEdit, QCheckBox, QMessageBox, QMenu, QHeaderView, QFileDialog, QAction, QToolButton, \
QToolBar, QGridLayout
from PyQt5.QtCore import Qt, QTimer
from PyQt5.QtGui import QGuiApplication
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar
import astropy.units as u
from astropy.coordinates import SkyCoord
from astropy.time import Time
from sunpy.coordinates import Heliocentric, Helioprojective
from pyampp.geometry import observer_fov_box_to_world_corners, world_to_local_cartesian_mm
from pyampp.gxbox.boxutils import validate_number, read_b3d_h5, write_b3d_h5, update_line_seeds_h5
from pyampp.gxbox.observer_restore import resolve_observer_with_info
from pyampp.io import load_model
import pickle
import vtk
import pyvista as pv
from pyvistaqt import BackgroundPlotter
from PyQt5.QtWidgets import QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, QLabel, QLineEdit, QPushButton, QTreeView, \
QGroupBox
from PyQt5.QtGui import QStandardItemModel, QStandardItem
import numpy as np
# Only let sunpy log records at WARNING level or above through, so lower-level
# sunpy messages do not reach the viewer's console.
logging.getLogger("sunpy").setLevel(logging.WARNING)
def _contains_viewer_field_payload(b3d: dict) -> bool:
if not isinstance(b3d, dict):
return False
for key in ("corona", "nlfff", "pot"):
group = b3d.get(key)
if isinstance(group, dict) and any(name in group for name in ("bx", "by", "bz", "bcube")):
return True
chromo = b3d.get("chromo")
if isinstance(chromo, dict):
if any(name in chromo for name in ("bx", "by", "bz", "bcube", "chromo_bcube")):
return True
return False
## TODO: is it possible to add a 3D crosshair to the plotter?
## TODO: integrate the NLFFF extrapolation module: https://github.com/Alexey-Stupishin/pyAMaFiL
[docs]
def minval(min_val, decimals=2):
    """
    Round a lower bound *up* to ``decimals`` decimal places (hundredths by default).

    Rounding up keeps the returned bound inside the data range. The scaled value
    is first snapped with ``np.round(..., 8)`` so binary floating-point noise
    (``1.1 * 100 == 110.00000000000001``) does not push the result one step
    too far (without the snap, ``1.1`` would become ``1.11``).

    :param min_val: float or array_like
        The minimum value to round.
    :param decimals: int, optional
        Number of decimal places to keep (default 2).
    :return: float or ndarray
        The rounded-up minimum value.
    """
    scale = 10.0 ** decimals
    return np.ceil(np.round(min_val * scale, 8)) / scale
[docs]
def maxval(max_val, decimals=2):
    """
    Round an upper bound *down* to ``decimals`` decimal places (hundredths by default).

    Rounding down keeps the returned bound inside the data range. The scaled value
    is first snapped with ``np.round(..., 8)`` so binary floating-point noise
    (``0.57 * 100 == 56.99999999999999``) does not pull the result one step too
    far (without the snap, ``0.57`` would become ``0.56``).

    :param max_val: float or array_like
        The maximum value to round.
    :param decimals: int, optional
        Number of decimal places to keep (default 2).
    :return: float or ndarray
        The rounded-down maximum value.
    """
    scale = 10.0 ** decimals
    return np.floor(np.round(max_val * scale, 8)) / scale
def _decode_seed_value(value):
if isinstance(value, bytes):
return value.decode("utf-8", errors="replace")
if isinstance(value, np.ndarray):
if value.shape == ():
return _decode_seed_value(value.item())
if value.size == 1:
return _decode_seed_value(value.reshape(-1)[0].item())
return value
[docs]
def generate_streamlines_from_line_seeds(box, b3dtype, line_seeds, max_length=5000):
    """
    Trace field lines for every spherical seed stored in *line_seeds*.

    The field components ``box.b3d[b3dtype]["bx"/"by"/"bz"]`` are placed on a
    uniform pyvista grid built from ``box.grid_coords``. The z axis is shifted
    so the grid starts at z = 0.

    :param box: object
        Provides ``grid_coords`` (``'x'``, ``'y'``, ``'z'`` quantities exposing
        ``.value``) and the ``b3d`` dict that holds the field group.
    :param b3dtype: str
        Key of the field group in ``box.b3d``.
    :param line_seeds: dict
        ``{key: seed_dict}`` entries; an ``"attrs"`` entry and non-dict entries
        are ignored. Only seeds whose ``seed_type`` is ``"sphere"`` and whose
        ``center`` has 3 components are traced. ``radius`` (default 0.0) and
        ``n_points`` (default 100) configure the seed sphere. Seeds with
        non-numeric values are skipped instead of aborting the others.
    :param max_length: float, optional
        Maximum streamline length passed to ``grid.streamlines`` (default 5000).
    :return: tuple
        ``(streamlines, z_base)``. ``streamlines`` lists the non-empty streamline
        datasets in key order. ``z_base`` is the native z minimum subtracted from
        the grid. Returns ``([], 0.0)`` when *line_seeds* is not a dict.
    """
    if not isinstance(line_seeds, dict):
        return [], 0.0

    def axis_spacing(coords):
        # A single-sample axis has no neighbour to measure against; use unit
        # spacing instead of raising IndexError on coords[1].
        return coords[1] - coords[0] if coords.size > 1 else 1.0

    x = np.asarray(box.grid_coords["x"].value, dtype=float)
    y = np.asarray(box.grid_coords["y"].value, dtype=float)
    z_native = np.asarray(box.grid_coords["z"].value, dtype=float)
    z_base = float(z_native.min())
    z = z_native - z_base

    field_group = box.b3d[b3dtype]
    bx = np.asarray(field_group["bx"])
    by = np.asarray(field_group["by"])
    bz = np.asarray(field_group["bz"])

    grid = pv.ImageData()
    grid.dimensions = (len(x), len(y), len(z))
    grid.spacing = (axis_spacing(x), axis_spacing(y), axis_spacing(z))
    grid.origin = (x.min(), y.min(), z.min())
    # Components are flattened in Fortran order (x fastest), presumably matching
    # (nx, ny, nz)-shaped arrays against the grid's point ordering -- TODO confirm.
    grid["bx"] = bx.ravel(order="F")
    grid["by"] = by.ravel(order="F")
    grid["bz"] = bz.ravel(order="F")
    grid["vectors"] = np.c_[grid["bx"], grid["by"], grid["bz"]]

    streamlines = []
    # Sort on the textual key so mixed key types cannot make sorted() compare dicts.
    for key, seed_data in sorted(line_seeds.items(), key=lambda item: str(item[0])):
        if key == "attrs" or not isinstance(seed_data, dict):
            continue
        seed_type = str(_decode_seed_value(seed_data.get("seed_type", "sphere")))
        if seed_type != "sphere":
            continue
        try:
            center = np.asarray(seed_data.get("center", ()), dtype=float).reshape(-1)
            radius = float(_decode_seed_value(seed_data.get("radius", 0.0)))
            n_points = int(_decode_seed_value(seed_data.get("n_points", 100)))
        except (TypeError, ValueError):
            # A malformed seed is skipped; the remaining seeds are still traced.
            continue
        if center.size != 3:
            continue
        sl = grid.streamlines(
            vectors="vectors",
            source_center=(float(center[0]), float(center[1]), float(center[2])),
            source_radius=radius,
            n_points=n_points,
            integration_direction="both",
            max_length=max_length,
            progress_bar=False,
        )
        if sl is not None and getattr(sl, "n_lines", 0) > 0:
            streamlines.append(sl)
    return streamlines, z_base
[docs]
class MagFieldViewer(BackgroundPlotter):
"""
A class to visualize the magnetic field of a box using PyVista. It inherits from the BackgroundPlotter class.
:param box: object
The box containing magnetic field data.
:param parent: object, optional
The parent object (default is None).
"""
def __init__(self, box, parent=None, box_norm_direction=None, box_view_up=None, time=None, b3dtype='nlfff', model_path=None, session_mode=None, *args, **kwargs):
    """
    Build the 3D magnetic-field viewer for *box*; the window is not shown here.

    :param box: object
        Box providing ``grid_coords`` (``'x'``, ``'y'``, ``'z'`` quantities) and the
        ``b3d`` dict of field groups. Stored as ``self.box``.
    :param parent: object, optional
        Owning application object, stored as ``self.parent``. When given, the
        session mode defaults to "embedded", and the embedded close hook calls
        ``parent.cancel_live_3d_edits()`` if it exists.
    :param box_norm_direction: optional
        Stored as-is. Together with *box_view_up* it enables the observer-camera
        button.
    :param box_view_up: optional
        See *box_norm_direction*.
    :param time: optional
        Object with ``to_datetime()`` (presumably an astropy ``Time``), used to
        build the ``_YYYYmmddTHHMMSS`` suffix in ``self.timestr``.
    :param b3dtype: str, optional
        Field group to display. ``'pot'`` and ``'nlfff'`` are both read from the
        ``'corona'`` group, with the choice kept in ``corona_type``. Other values
        are used directly. Default ``'nlfff'``.
    :param model_path: optional
        Stored as ``self.model_path``.
    :param session_mode: str, optional
        ``'embedded'`` or ``'standalone'``; inferred from *parent* when omitted.

    Extra ``*args``/``**kwargs`` are forwarded to ``BackgroundPlotter``, with
    ``show`` defaulting to False. The keyword ``source_model_path`` is consumed
    here.
    """
    # Build the scene fully before first paint; callers explicitly call .show().
    self.source_model_path = kwargs.pop("source_model_path", None)
    kwargs.setdefault("show", False)
    super().__init__(*args, **kwargs)
    # Everything below reads the field data and grid through self.box, and the
    # embedded close hook reads self.parent, so both must be stored first.
    self.box = box
    self.parent = parent
    self.model_path = model_path
    self.session_mode = session_mode or ("embedded" if parent is not None else "standalone")
    self.box_norm_direction = box_norm_direction
    self.box_view_up = box_view_up
    self.updating_flag = False  # Flag to avoid recursion

    # Sphere (field-line seed) bookkeeping.
    self.current_sphere_id = None
    self.next_sphere_id = 1
    self.current_sphere = None

    # Scene actors; None until created.
    self.sphere_actor = None
    self.plane_actor = None
    self.bottom_slice_actor = None
    self.base_map_actor = None
    self.model_box_actor = None
    self.fov_box_actor = None
    self.streamlines_actor = None
    self.streamlines = None

    # Visibility toggles.
    self.sphere_visible = True
    self.slice_visible = True
    self.base_map_visible = False
    self.model_box_visible = True
    self.fov_box_visible = True
    self.plane_visible = True

    self.previous_params = {}
    self.previous_valid_values = {}

    # Widgets; None until the UI is built.
    self.scalar_selector = None
    self.slice_checkbox = None
    self.slice_axis_selector = None
    self.slice_coord_label = None
    self.scalar_selector_items = []
    self.base_map_selector = None
    self.model_box_checkbox = None
    self.fov_box_checkbox = None
    self.base_map_checkbox = None
    self.base_map_items = []
    self.base_scalar_min = 0.0
    self.base_scalar_max = 0.0
    self.field_lines_control_group = None
    self.sphere_control_group = None

    self._streamline_controls_enabled = True
    self._los_label_text = ""
    self._embedded_close_mode = None
    self._close_hook_installed = False
    self.timestr = time.to_datetime().strftime("_%Y%m%dT%H%M%S") if time is not None else ''
    self._restoring_line_seeds = False

    # Snapshot of the seeds the box arrived with. A missing or None b3d means
    # "no seeds" instead of an AttributeError.
    b3d = getattr(self.box, "b3d", None)
    initial_seeds = b3d.get("line_seeds") if isinstance(b3d, dict) else None
    self._original_line_seeds = copy.deepcopy(initial_seeds) if isinstance(initial_seeds, dict) else None

    if b3dtype in ("pot", "nlfff"):
        self.b3dtype = "corona"
        self.corona_type = b3dtype
    else:
        self.b3dtype = b3dtype
        self.corona_type = None

    # Grid coordinates and their bounds. minval/maxval round the bounds inward
    # to hundredths.
    self.grid_x = self.box.grid_coords['x'].value
    self.grid_y = self.box.grid_coords['y'].value
    self.grid_z = self.box.grid_coords['z'].value
    self.grid_xmin, self.grid_xmax = minval(self.grid_x.min()), maxval(self.grid_x.max())
    self.grid_ymin, self.grid_ymax = minval(self.grid_y.min()), maxval(self.grid_y.max())
    self.grid_zmin, self.grid_zmax = minval(self.grid_z.min()), maxval(self.grid_z.max())

    # Shift z so the scene's z axis starts (approximately) at zero.
    self.grid_zbase = self.grid_zmin
    self.grid_z = self.grid_z - self.grid_zbase
    self.grid_zmin, self.grid_zmax = self.grid_z.min(), self.grid_z.max()

    self.slice_axis_positions = {
        'x': float(np.mean(self.grid_x)),
        'y': float(np.mean(self.grid_y)),
        'z': 0.0,
    }
    self.slice_coord_min = float(self.grid_zmin)
    self.slice_coord_max = float(self.grid_zmax)

    # Default seed sphere: horizontally centered, 10% of the height above the base.
    self.default_sph_cen_x = np.mean(self.grid_x)
    self.default_sph_cen_y = np.mean(self.grid_y)
    self.default_sph_cen_z = self.grid_zmin + np.ptp(self.grid_z) * 0.1

    self.init_grid()
    self.add_widgets_to_window()
    self.init_plot()
    self.show_axes_all()
    # Keep startup in observer LoS (do not override with isometric).
    self.plane_checkbox.setChecked(False)
    self.app_window.setWindowTitle("GxBox 3D viewer")
    self.add_menu_options()
    self.add_parallel_projection_button()
    if self.box_norm_direction is not None and self.box_view_up is not None:
        self.add_observer_cam_button()
    self._apply_streamline_control_state()
    self._restore_line_seeds_from_box()
    self._install_embedded_close_hook()
    # Debug aid: self.interactor.AddObserver('ModifiedEvent', self.print_camera_position)
[docs]
def print_camera_position(self, caller, event):
    """
    Print the camera position, focal point and view-up vector.

    The signature matches a VTK observer callback; *caller* and *event* are
    ignored.
    """
    cam = self.camera
    readings = (
        ("Camera position", cam.position),
        ("Focal point", cam.focal_point),
        ("View up", cam.up),
    )
    for label, vector in readings:
        print(f"{label}: {vector}")
[docs]
def set_camera_to_LOS_direction(self):
    """
    Point the camera along the observer line of sight (LoS), in parallel projection.

    The orientation comes only from the observer frame and the box frame. A
    reference point is nudged by a small step along each observer axis (right,
    up, toward the observer). The nudged points are transformed into the box
    frame, and their offsets *from that same reference point* give the local
    basis vectors. The optional FOV box only chooses the framing target (focal
    point and zoom); it never defines the basis.

    Returns without changing the camera when the box frame, observer frame,
    observer or box center is unavailable, or when no valid basis can be built.
    """
    log = logging.getLogger(__name__)

    def normalize(v):
        # Unit vector of *v*; None for missing input, non-3-vectors, or zero /
        # non-finite norms.
        if v is None:
            return None
        arr = np.asarray(v, dtype=float).reshape(-1)
        if arr.size != 3:
            return None
        norm = np.linalg.norm(arr)
        if not np.isfinite(norm) or norm <= 0:
            return None
        return arr / norm

    def position_mm(coord):
        # Cartesian box-frame position of *coord* in Mm, or None if unreadable.
        try:
            return np.array(
                [
                    float(coord.x.to_value(u.Mm)),
                    float(coord.y.to_value(u.Mm)),
                    float(coord.z.to_value(u.Mm)),
                ],
                dtype=float,
            )
        except Exception:
            return None

    b3d = self.box.b3d if isinstance(getattr(self.box, "b3d", None), dict) else None
    box_frame = getattr(getattr(self.box, "_center", None), "frame", None)
    frame_obs = getattr(self.box, "_frame_obs", None)
    observer = getattr(frame_obs, "observer", None)
    obstime = getattr(frame_obs, "obstime", None)
    if observer is None and box_frame is not None:
        observer = getattr(box_frame, "observer", None)
    if obstime is None and box_frame is not None:
        obstime = getattr(box_frame, "obstime", None)

    # An observer named in the box metadata overrides the frame's observer.
    observer_meta = b3d.get("observer", {}) if b3d is not None else {}
    observer_key = observer_meta.get("name") if isinstance(observer_meta, dict) else None
    if observer_key:
        try:
            resolved, warning, _used_key = resolve_observer_with_info(
                b3d if b3d is not None else {},
                observer_key,
                obstime,
            )
            if warning:
                print(f"Warning: {warning}")
            if resolved is not None:
                observer = resolved
        except Exception:
            # Best effort: keep the frame's observer, but leave a trace.
            log.debug("Could not resolve observer %r; keeping frame observer", observer_key, exc_info=True)

    center = getattr(self.box, "_center", None)
    if box_frame is None or observer is None or center is None or frame_obs is None:
        return

    # Compute the FOV corners once; a failure here means "no FOV box".
    try:
        fov_corners = self._fov_box_corners_local()
    except Exception:
        fov_corners = None
    has_fov = isinstance(fov_corners, np.ndarray) and fov_corners.shape == (8, 3)
    grid_center = np.asarray(
        [
            0.5 * (self.grid_xmin + self.grid_xmax),
            0.5 * (self.grid_ymin + self.grid_ymax),
            0.5 * (self.grid_zmin + self.grid_zmax),
        ],
        dtype=float,
    )
    # Framing target in scene coordinates (z measured from grid_zbase).
    focal_point_arr = np.mean(fov_corners, axis=0) if has_fov else grid_center

    step_arcsec = 10.0 * u.arcsec
    step_mm = 1.0 * u.Mm
    try:
        # Primary path: perturb the reference point in the observer frame (the
        # same sky plane used by the 2D view).
        if has_fov:
            ref_local = SkyCoord(
                x=focal_point_arr[0] * u.Mm,
                y=focal_point_arr[1] * u.Mm,
                z=(focal_point_arr[2] + float(self.grid_zbase)) * u.Mm,
                frame=box_frame,
            )
        else:
            ref_local = center
        ref_obs = ref_local.transform_to(frame_obs)
        ref_dist = getattr(ref_obs, "distance", None)
        if ref_dist is None:
            return
        right_ref_local = SkyCoord(
            Tx=ref_obs.Tx + step_arcsec,
            Ty=ref_obs.Ty,
            distance=ref_dist,
            frame=frame_obs,
        ).transform_to(box_frame)
        up_ref_local = SkyCoord(
            Tx=ref_obs.Tx,
            Ty=ref_obs.Ty + step_arcsec,
            distance=ref_dist,
            frame=frame_obs,
        ).transform_to(box_frame)
        toward_ref_local = SkyCoord(
            Tx=ref_obs.Tx,
            Ty=ref_obs.Ty,
            distance=ref_dist - step_mm,
            frame=frame_obs,
        ).transform_to(box_frame)
        origin_coord = ref_local
    except Exception:
        # Fallback: best-effort Heliocentric (HCC) approximation around the box center.
        frame_hcc = Heliocentric(observer=observer, obstime=obstime)
        step = 1.0 * u.Mm
        zero = 0.0 * u.Mm
        try:
            center_hcc = center.transform_to(frame_hcc)

            def hcc_offset(dx, dy, dz):
                return SkyCoord(
                    x=center_hcc.x + dx,
                    y=center_hcc.y + dy,
                    z=center_hcc.z + dz,
                    frame=frame_hcc,
                ).transform_to(box_frame)

            right_ref_local = hcc_offset(step, zero, zero)
            up_ref_local = hcc_offset(zero, step, zero)
            toward_ref_local = hcc_offset(zero, zero, step)
        except Exception:
            return
        origin_coord = center

    # Directions are offsets from the point that was perturbed, not from the
    # framing focal point; otherwise any gap between the two biases the basis.
    origin = position_mm(origin_coord)
    if origin is None:
        return

    def direction_from_origin(coord):
        point = position_mm(coord)
        return None if point is None else normalize(point - origin)

    right_local = direction_from_origin(right_ref_local)
    up_local = direction_from_origin(up_ref_local)
    toward_observer_local = direction_from_origin(toward_ref_local)
    if right_local is None or up_local is None or toward_observer_local is None:
        return

    # Re-orthogonalize to remove transform noise while keeping the observer
    # frame as the truth source. "Toward" is rebuilt as right x up.
    up_local = normalize(up_local - np.dot(up_local, right_local) * right_local)
    if up_local is None:
        return
    toward_observer_local = normalize(np.cross(right_local, up_local))
    if toward_observer_local is None:
        return
    up_local = normalize(np.cross(toward_observer_local, right_local))
    if up_local is None:
        return
    view_local = -toward_observer_local

    scene_span = max(
        float(self.grid_xmax - self.grid_xmin),
        float(self.grid_ymax - self.grid_ymin),
        float(self.grid_zmax - self.grid_zmin),
        1.0,
    )
    camera_distance = 4.0 * scene_span
    self.camera.up = [float(c) for c in up_local]
    self.camera.focal_point = [float(c) for c in focal_point_arr]
    self.camera.position = [
        float(focal - view * camera_distance) for focal, view in zip(focal_point_arr, view_local)
    ]
    self.camera.ParallelProjectionOn()

    # Use the FOV box only to control framing in the already-defined LoS basis;
    # otherwise fall back to the full model extents.
    if has_fov:
        centered = np.asarray(fov_corners, dtype=float) - focal_point_arr.reshape((1, 3))
        half_h = float(np.max(np.abs(centered @ up_local)))
        half_w = float(np.max(np.abs(centered @ right_local)))
        try:
            render_size = self.render_window.GetSize()
            win_w = max(1, int(render_size[0]))
            win_h = max(1, int(render_size[1]))
            aspect = max(1e-6, win_w / win_h)
        except Exception:
            aspect = 1.0
        # Generous margin so the whole projected FOV rectangle stays visible in
        # the LoS parallel view.
        self.camera.parallel_scale = max(half_h, half_w / max(aspect, 1e-6), 1e-3) * 1.18
    else:
        self.camera.parallel_scale = max(0.5 * scene_span, 1e-3)
    try:
        self.camera.SetClippingRange(0.1, max(10.0, 10.0 * camera_distance))
    except Exception:
        pass
    projection_button = getattr(self, "parallel_proj_button", None)
    if projection_button is not None:
        projection_button.setChecked(True)
    self._update_los_scene_label()
    self.render()
@staticmethod
def _normalize_observer_key(observer_key):
raw = observer_key
if isinstance(raw, (bytes, bytearray)):
raw = raw.decode("utf-8", "ignore")
if isinstance(raw, np.ndarray) and raw.shape == ():
raw = raw.item()
if isinstance(raw, (bytes, bytearray)):
raw = raw.decode("utf-8", "ignore")
key = str(raw or "earth").strip().lower()
aliases = {
"sdo": "SDO",
"earth": "Earth",
"solo": "Solar Orbiter",
"solar orbiter": "Solar Orbiter",
"solar-orbiter": "Solar Orbiter",
"solarorbiter": "Solar Orbiter",
"stereo-a": "STEREO-A",
"stereo a": "STEREO-A",
"stereoa": "STEREO-A",
"stereo-b": "STEREO-B",
"stereo b": "STEREO-B",
"stereob": "STEREO-B",
}
return aliases.get(key, str(raw))
def _current_los_label(self) -> str:
observer_meta = self.box.b3d.get("observer", {}) if isinstance(self.box.b3d, dict) else {}
if not isinstance(observer_meta, dict):
observer_meta = {}
if "name" in observer_meta:
return self._normalize_observer_key(observer_meta.get("name"))
fov_box = observer_meta.get("fov_box", {})
if isinstance(fov_box, dict) and "observer_key" in fov_box:
return self._normalize_observer_key(fov_box.get("observer_key"))
return "Earth"
def _update_los_scene_label(self) -> None:
    """
    Show or refresh the "Observer LOS: <name>" overlay in the upper-left corner.

    The text is cached in ``self._los_label_text``. Exceptions raised by the
    plotter's ``add_text`` are swallowed.
    """
    label = f"Observer LOS: {self._current_los_label()}"
    self._los_label_text = label
    overlay_options = dict(
        position="upper_left",
        font_size=10,
        color="black",
        name="observer_los_label",
        shadow=False,
    )
    try:
        self.add_text(label, **overlay_options)
    except Exception:
        pass
def _apply_startup_los_view(self) -> None:
    """Apply the LoS camera, reset the clipping range, render, and repaint the window."""
    self.set_camera_to_LOS_direction()
    self.reset_camera_clipping_range()
    self.render()
    # Some platforms defer the first paint; an explicit repaint avoids waiting
    # for a mouse interaction before the scene stabilizes.
    window = getattr(self, "app_window", None)
    if window is None:
        return
    try:
        window.repaint()
    except Exception:
        pass
[docs]
def schedule_startup_los_view(self) -> None:
    """
    Queue the startup LoS view on the Qt event loop twice.

    The 0 ms pass runs once Qt has realized the window size, so the framing
    matches the toolbar "LoS" action. The 120 ms pass catches late size/layout
    updates and removes startup jitter.
    """
    for delay_ms in (0, 120):
        QTimer.singleShot(delay_ms, self._apply_startup_los_view)
[docs]
def ensure_window_visible(self) -> None:
    """
    Resize and move the viewer window so it lies on an available screen area.

    The screen is taken from the window handle, then from the window's center,
    then the primary screen. The size is kept at least 900x650 and at most 92%
    of the available area (but never below that minimum). A window whose
    top-left corner is off screen, or within 80 px of the right/bottom edge, is
    re-centered; otherwise it is clamped into the available area. Any error
    aborts quietly.
    """
    window = getattr(self, "app_window", None)
    if window is None:
        return

    screen = None
    try:
        handle = window.windowHandle()
        if handle is not None:
            screen = handle.screen()
    except Exception:
        screen = None
    if screen is None:
        try:
            screen = QGuiApplication.screenAt(window.frameGeometry().center())
        except Exception:
            screen = None
    if screen is None:
        screen = QGuiApplication.primaryScreen()
    if screen is None:
        return

    try:
        avail = screen.availableGeometry()
        if not avail.isValid():
            return
        frame = window.frameGeometry()
        min_w, min_h = 900, 650
        width = min(max(frame.width(), min_w), max(min_w, int(avail.width() * 0.92)))
        height = min(max(frame.height(), min_h), max(min_h, int(avail.height() * 0.92)))
        left, top = frame.left(), frame.top()
        needs_recentre = (
            top < avail.top()
            or top > avail.bottom() - 80
            or left > avail.right() - 80
            or left < avail.left()
        )
        if needs_recentre:
            left = avail.left() + max(0, (avail.width() - width) // 2)
            top = avail.top() + max(0, (avail.height() - height) // 2)
        else:
            left = min(max(left, avail.left()), max(avail.left(), avail.right() - width + 1))
            top = min(max(top, avail.top()), max(avail.top(), avail.bottom() - height + 1))
        window.resize(width, height)
        window.move(left, top)
    except Exception:
        return
def _install_embedded_close_hook(self) -> None:
    """
    In embedded mode, wrap the window's ``closeEvent`` so a system close acts as "Cancel".

    The hook is installed at most once. When the window closes while
    ``_embedded_close_mode`` is None, ``parent.cancel_live_3d_edits()`` is
    called first (if the parent provides it). The original handler always runs
    afterwards, and ``_embedded_close_mode`` is reset to None either way. A
    failing cancel is logged instead of being silently dropped, and it does not
    prevent the close.
    """
    if self.session_mode != "embedded" or self._close_hook_installed:
        return
    window = getattr(self, "app_window", None)
    if window is None:
        return
    original_close_event = window.closeEvent

    def _wrapped_close_event(event):
        # Treat system-window close as "Cancel" for embedded mode.
        if self._embedded_close_mode is None:
            owner = getattr(self, "parent", None)
            cancel = getattr(owner, "cancel_live_3d_edits", None) if owner is not None else None
            if callable(cancel):
                try:
                    cancel()
                except Exception:
                    logging.getLogger(__name__).exception(
                        "cancel_live_3d_edits failed while closing the 3D viewer"
                    )
        try:
            original_close_event(event)
        finally:
            self._embedded_close_mode = None

    window.closeEvent = _wrapped_close_event
    self._close_hook_installed = True
[docs]
def toggle_parallel_projection(self, state):
    """
    Use parallel projection for a truthy *state*, perspective projection otherwise.
    """
    camera = self.camera
    if not state:
        camera.ParallelProjectionOff()
        return
    camera.ParallelProjectionOn()
[docs]
def save_state(self, default_filename='box_view3d_state.pkl'):
    """
    Save the spheres' seed parameters to a pickle file chosen in a save dialog.

    Only ``center``, ``radius``, ``n_points`` and ``sphere_visible`` are stored
    per sphere id. Nothing is written if the dialog is cancelled.

    :param default_filename: str
        Filename pre-filled in the dialog. A non-string value (for example a
        flag passed along by a Qt signal) falls back to 'box_view3d_state.pkl',
        mirroring the type guard in ``load_state``.
    """
    if not isinstance(default_filename, str):
        default_filename = 'box_view3d_state.pkl'
    options = QFileDialog.Options() | QFileDialog.DontUseNativeDialog
    filename, _ = QFileDialog.getSaveFileName(self.app_window, "Save State", default_filename,
                                              "Pickle Files (*.pkl)", options=options)
    if not filename:
        return
    # Keep only plain data; the VTK actors cannot be pickled.
    serializable_spheres = {
        sphere_id: {
            'center': sphere['center'],
            'radius': sphere['radius'],
            'n_points': sphere['n_points'],
            'sphere_visible': sphere['sphere_visible'],
        }
        for sphere_id, sphere in self.spheres.items()
    }
    with open(filename, 'wb') as f:
        pickle.dump(serializable_spheres, f)
    print(f"State saved to {filename}")
[docs]
def load_state(self, filename=None):
    """
    Replace the current spheres with those stored by ``save_state``.

    :param filename: str, optional
        File to load. Any non-string value (including None) opens a file
        dialog instead. Nothing happens if no file is chosen.

    Spheres are re-created in file order through ``_on_add_sphere``, so they get
    fresh ids. Spheres saved as hidden are hidden again, as
    ``_restore_line_seeds`` does.

    Security: ``pickle.load`` can execute arbitrary code; only load trusted files.
    """
    if not isinstance(filename, str):
        options = QFileDialog.Options() | QFileDialog.DontUseNativeDialog
        filename, _ = QFileDialog.getOpenFileName(self.app_window, "Load State", f'box_view3d_state{self.timestr}.pkl', "Pickle Files (*.pkl)",
                                                  options=options)
    if not filename:
        return
    # NOTE(security): unpickling a file runs code from it; this path is meant
    # only for state files written by save_state.
    with open(filename, 'rb') as f:
        serializable_spheres = pickle.load(f)
    self._on_clear_spheres()
    log = logging.getLogger(__name__)
    for sphere_id, sphere_data in serializable_spheres.items():
        log.debug("Restoring sphere %s: %s", sphere_id, sphere_data)
        center = sphere_data['center']
        # _on_add_sphere reads its parameters from the input widgets.
        self.center_x_input.setText(f"{center[0]:.2f}")
        self.center_y_input.setText(f"{center[1]:.2f}")
        self.center_z_input.setText(f"{center[2]:.2f}")
        self.radius_input.setText(f"{sphere_data['radius']:.2f}")
        self.n_points_input.setText(f"{sphere_data['n_points']}")
        self._on_add_sphere()
        if not sphere_data.get('sphere_visible', True) and self.current_sphere_id in self.spheres:
            self.update_sphere_visibility(False)
    print(f"State loaded from {filename}")
@staticmethod
def _decode_seed_type(value):
if isinstance(value, bytes):
return value.decode("utf-8", errors="replace")
if isinstance(value, np.ndarray) and value.shape == ():
return MagFieldViewer._decode_seed_type(value.item())
return str(value)
@staticmethod
def _as_scalar(value, default):
if value is None:
return default
if isinstance(value, np.ndarray):
if value.shape == ():
return value.item()
if value.size == 1:
return value.reshape(-1)[0].item()
return value
def _serialize_line_seeds(self):
seeds = {}
for sphere_id in sorted(self.spheres):
sphere = self.spheres[sphere_id]
seeds[f"seed_{int(sphere_id)}"] = {
"seed_type": np.bytes_("sphere"),
"center": np.asarray(sphere["center"], dtype=float),
"radius": float(sphere["radius"]),
"n_points": np.int64(sphere["n_points"]),
"sphere_visible": np.uint8(1 if sphere.get("sphere_visible", True) else 0),
}
seeds["attrs"] = {
"schema_version": np.int64(1),
"current_seed_id": np.int64(self.current_sphere_id if self.current_sphere_id is not None else -1),
"next_seed_id": np.int64(self.next_sphere_id),
}
return seeds
def _persist_line_seeds(self):
if self._restoring_line_seeds:
return
if not hasattr(self.box, "b3d") or self.box.b3d is None:
self.box.b3d = {}
if self.spheres:
self.box.b3d["line_seeds"] = self._serialize_line_seeds()
else:
self.box.b3d.pop("line_seeds", None)
def _restore_line_seeds(self, line_seeds):
# Rebuild the sphere seeds (and their streamlines) from a ``line_seeds`` dict laid
# out like the output of _serialize_line_seeds(): "seed_<id>" entries plus an
# optional "attrs" entry carrying current_seed_id / next_seed_id.
# NOTE(review): this view of the file has lost its indentation, so it is not
# visible whether the final _persist_line_seeds() call sits inside the ``finally``
# or after the try/finally. In the latter case the early ``return`` for an empty
# seed list skips it. Confirm against the real source.
attrs = line_seeds.get("attrs", {}) if isinstance(line_seeds.get("attrs"), dict) else {}
current_seed_id = int(self._as_scalar(attrs.get("current_seed_id"), -1))
next_seed_id = int(self._as_scalar(attrs.get("next_seed_id"), 1))
# Validate and collect every usable entry before touching the scene.
seed_entries = []
for key, seed_data in line_seeds.items():
if key == "attrs" or not isinstance(seed_data, dict):
continue
try:
# The seed id is the numeric suffix of the key ("seed_3" -> 3); keys
# without one are skipped.
sphere_id = int(str(key).split("_")[-1])
except Exception:
continue
seed_type = self._decode_seed_type(seed_data.get("seed_type", b"sphere"))
if seed_type != "sphere":
continue
center = np.asarray(seed_data.get("center", ()), dtype=float).reshape(-1)
if center.size != 3:
continue
radius = float(self._as_scalar(seed_data.get("radius"), 0.0))
n_points = int(self._as_scalar(seed_data.get("n_points"), 100))
sphere_visible = bool(int(self._as_scalar(seed_data.get("sphere_visible"), 1)))
seed_entries.append((sphere_id, center, radius, n_points, sphere_visible))
# While this flag is set, _persist_line_seeds() returns early, so the
# intermediate states of the rebuild are not written back to box.b3d.
self._restoring_line_seeds = True
try:
self._on_clear_spheres()
if not seed_entries:
return
seed_entries.sort(key=lambda item: item[0])
for sphere_id, center, radius, n_points, sphere_visible in seed_entries:
# _on_add_sphere() reads its parameters from these input widgets.
self.center_x_input.setText(f"{center[0]:.2f}")
self.center_y_input.setText(f"{center[1]:.2f}")
self.center_z_input.setText(f"{center[2]:.2f}")
self.radius_input.setText(f"{radius:.2f}")
self.n_points_input.setText(f"{n_points}")
# _on_add_sphere() uses next_sphere_id as the new sphere's id, so this
# keeps the stored ids.
self.next_sphere_id = sphere_id
self._on_add_sphere()
# Re-hide spheres that were stored as hidden.
if not sphere_visible and self.current_sphere_id in self.spheres:
self.update_sphere_visibility(False)
if current_seed_id in self.spheres:
self.select_sphere(current_seed_id)
# Never hand out an id below the stored next_seed_id.
self.next_sphere_id = max(next_seed_id, self.next_sphere_id)
finally:
self._restoring_line_seeds = False
self._persist_line_seeds()
def _restore_line_seeds_from_box(self):
line_seeds = getattr(self.box, "b3d", {}).get("line_seeds")
if not isinstance(line_seeds, dict):
return
self._restore_line_seeds(line_seeds)
def _model_stage_tag(self):
    """
    Return the stage suffix of ``box.b3d["metadata"]["id"]`` without its leading dot.

    The id is upper-cased and checked against the known suffixes, more specific
    ones first (e.g. ``"...POT.GEN.CHR"`` -> ``"POT.GEN.CHR"``). Returns ``""``
    when there is no metadata, the id is not text, or no suffix matches.
    """
    b3d = getattr(self.box, "b3d", None)
    meta = b3d.get("metadata", {}) if isinstance(b3d, dict) else {}
    raw_id = _decode_seed_value(meta.get("id", "")) if isinstance(meta, dict) else ""
    # A None, numeric or multi-element id has no stage tag; it no longer
    # raises AttributeError on .upper().
    model_id = raw_id.upper() if isinstance(raw_id, str) else ""
    for suffix in (
        ".POT.GEN.CHR", ".NAS.GEN.CHR",
        ".POT.CHR", ".NAS.CHR",
        ".POT.GEN", ".NAS.GEN",
        ".NAS", ".BND", ".POT", ".NONE",
    ):
        if model_id.endswith(suffix):
            return suffix[1:]
    return ""
def _has_usable_streamline_field(self):
    """
    Return True when ``box.b3d[self.b3dtype]`` holds a usable bx/by/bz field.

    Returns False when:
      * the group is missing;
      * its ``attrs["model_type"]`` is "none" (case-insensitive);
      * the model id carries the NONE stage tag;
      * a component is missing or empty;
      * all three components are entirely zero.
    """
    b3d = getattr(self.box, "b3d", None)
    if not isinstance(b3d, dict):
        return False
    field_group = b3d.get(self.b3dtype)
    if not isinstance(field_group, dict):
        return False
    attrs = field_group.get("attrs")
    if not isinstance(attrs, dict):
        attrs = {}
    raw_type = _decode_seed_value(attrs.get("model_type", ""))
    # Only textual model types are meaningful; None or numeric values are
    # ignored instead of raising AttributeError on .strip().
    model_type = raw_type.strip().lower() if isinstance(raw_type, str) else ""
    if model_type == "none":
        return False
    if self._model_stage_tag() == "NONE":
        return False
    try:
        bx = np.asarray(field_group["bx"])
        by = np.asarray(field_group["by"])
        bz = np.asarray(field_group["bz"])
    except Exception:
        return False
    if bx.size == 0 or by.size == 0 or bz.size == 0:
        return False
    return bool(np.any(bx) or np.any(by) or np.any(bz))
def _apply_streamline_control_state(self):
    """
    Enable or disable the field-line and sphere controls based on the box's field.

    The result of ``_has_usable_streamline_field()`` is cached in
    ``_streamline_controls_enabled``. When disabled, each existing control gets
    an explanatory tooltip and any spheres are cleared.
    """
    enabled = self._has_usable_streamline_field()
    self._streamline_controls_enabled = enabled
    disabled_hint = "Field-line seeding is unavailable for NONE/no-field boxes."
    controls = (
        self.field_lines_control_group,
        self.sphere_control_group,
        self.add_sphere_button,
        self.delete_sphere_button,
        self.clear_sphere_button,
        self.viz_sphere_button,
        self.tree_view,
        self.center_x_input,
        self.center_y_input,
        self.center_z_input,
        self.radius_input,
        self.n_points_input,
        self.lock_z_checkbox,
    )
    for control in (c for c in controls if c is not None):
        control.setEnabled(enabled)
        if not enabled:
            control.setToolTip(disabled_hint)
    if not enabled and self.spheres:
        self._on_clear_spheres()
def _on_add_sphere(self):
    """
    Add a sphere seed from the current input fields and trace its streamlines.

    The inputs are parsed first. Then the previously current sphere's widget is
    switched off, the streamlines are created, and the new sphere (id =
    ``next_sphere_id``) becomes current. It is appended to the tree view and
    selected, and the seeds are persisted to ``box.b3d``. Only a message is
    printed when streamline controls are disabled or an input is not a valid
    number.
    """
    if not self._streamline_controls_enabled:
        print("Sphere controls are disabled for this box because no volumetric field is available.")
        return
    # Parse everything before touching the scene, so bad input leaves the
    # current sphere exactly as it was.
    try:
        center_x = float(self.center_x_input.text())
        center_y = float(self.center_y_input.text())
        center_z = float(self.center_z_input.text())
        radius = float(self.radius_input.text())
        n_points = int(self.n_points_input.text())
    except ValueError as err:
        print(f"Cannot add sphere: invalid input ({err}).")
        return
    previous = self.spheres.get(self.current_sphere_id)
    if previous is not None and previous['sphere_actor'] is not None:
        previous['sphere_actor'].Off()
    sphere_id = self.next_sphere_id
    self.center_x_input.setText(f"{center_x:.2f}")
    self.center_y_input.setText(f"{center_y:.2f}")
    self.center_z_input.setText(f"{center_z:.2f}")
    # create_streamlines() leaves the new actors in self.sphere_actor /
    # self.streamlines / self.streamlines_actor, which are read right below.
    self.create_streamlines(center_x, center_y, center_z, radius, n_points)
    self.current_sphere_id = sphere_id
    self.spheres[sphere_id] = {
        'center': (center_x, center_y, center_z),
        'radius': radius,
        'n_points': n_points,
        'sphere_actor': self.sphere_actor,
        'streamlines': self.streamlines,
        'streamlines_actor': self.streamlines_actor,
        'sphere_visible': True,
    }
    # The sphere entry now owns the streamline objects.
    self.streamlines_actor = None
    self.streamlines = None
    self.update_sphere_visibility(True)
    # Label the tree row with the same id used as the dict key.
    sphere_item = QStandardItem(f"{sphere_id}")
    self.sphere_items.appendRow(sphere_item)
    self.tree_view.setCurrentIndex(self.sphere_items.indexFromItem(sphere_item))
    self.next_sphere_id += 1
    self._persist_line_seeds()
[docs]
def select_sphere(self, sphere_id):
sphere = self.spheres[sphere_id]
self.center_x_input.setText(f"{sphere['center'][0]:.2f}")
self.center_y_input.setText(f"{sphere['center'][1]:.2f}")
self.center_z_input.setText(f"{sphere['center'][2]:.2f}")
self.radius_input.setText(f"{sphere['radius']:.2f}")
self.n_points_input.setText(f"{sphere['n_points']}")
# self.spheres[self.current_sphere_id]['streamlines_actor'].SetVisibility(False)
if self.current_sphere_id in self.spheres:
self.spheres[self.current_sphere_id]['sphere_actor'].Off()
# Restore the streamlines actor for the selected sphere
streamlines_actor = sphere['streamlines_actor']
sphere_actor = sphere['sphere_actor']
# if streamlines_actor is not None:
# streamlines_actor.SetVisibility(True)
if sphere_actor is not None:
sphere_actor.On()
self.current_sphere_id = sphere_id
[docs]
def deselect_sphere(self):
"""
Handles the deselection of a sphere.
Clears the inputs and hides the current sphere and its streamlines.
"""
# self.center_x_input.clear()
# self.center_y_input.clear()
# self.center_z_input.clear()
# self.radius_input.clear()
# self.n_points_input.clear()
if self.current_sphere_id in self.spheres:
sphere_actor = self.spheres[self.current_sphere_id]['sphere_actor']
streamlines_actor = self.spheres[self.current_sphere_id]['streamlines_actor']
if sphere_actor is not None:
sphere_actor.Off()
if streamlines_actor is not None:
streamlines_actor.SetVisibility(False)
self.current_sphere_id = None
def _on_tb_selection_changed(self, selected, deselected):
    """
    Selection-model slot: make the first newly selected tree row's sphere current.

    A change that only deselects rows leaves the current sphere as it is.
    """
    indexes = selected.indexes()
    if not indexes:
        return
    row_item = self.sphere_items.itemFromIndex(indexes[0])
    self.select_sphere(int(row_item.text()))
def _on_delete_sphere(self):
"""
Delete the sphere whose row is selected (first selected index) in the tree view.

If spheres remain afterwards, update_sphere_visibility(True) is called.
"""
# NOTE(review): indentation is missing in this view of the file, so it is not
# visible whether the visibility refresh below runs only after a deletion
# (inside ``if indexes:``) or whenever the tree has rows. Confirm against the
# real source.
if self.sphere_items.rowCount() > 0:
indexes = self.tree_view.selectionModel().selectedIndexes()
if indexes:
item = self.sphere_items.itemFromIndex(indexes[0])
# Tree rows are labelled with the sphere id.
sphere_id = int(item.text())
self.delete_sphere_from_tb(sphere_id)
if len(self.spheres) > 0:
self.update_sphere_visibility(True)
def _on_clear_spheres(self):
    """
    Remove every sphere: tree rows, actors and stored data.

    Afterwards no sphere is selected, the id counter restarts at 1, and the
    (now empty) seed set is persisted.
    """
    while self.sphere_items.rowCount():
        first_row = self.sphere_items.item(0)
        self.delete_sphere_from_tb(int(first_row.text()))
    self.spheres.clear()
    self.current_sphere_id = None
    self.next_sphere_id = 1
    self._persist_line_seeds()
[docs]
def delete_sphere_from_tb(self, sphere_id):
sphere = self.spheres.pop(sphere_id, None)
if sphere and sphere['streamlines_actor'] is not None:
self.remove_actor(sphere['streamlines_actor'])
if sphere and sphere['streamlines'] is not None:
sphere['streamlines'] = None
if sphere and sphere['sphere_actor'] is not None:
sphere['sphere_actor'].Off()
sphere['sphere_actor'].RemoveAllObservers()
# Remove from tree view
nrows = self.sphere_items.rowCount()
for row in range(nrows):
item = self.sphere_items.item(row)
if item.text() == f"{sphere_id}":
self.sphere_items.removeRow(row)
break
# Update next_sphere_id to be 1 plus the largest sphere index
if nrows > 1:
max_sphere_id = max(int(self.sphere_items.item(row).text()) for row in range(nrows - 1))
else:
max_sphere_id = 0
self.next_sphere_id = max_sphere_id + 1
self._persist_line_seeds()
def _on_tb_right_click(self, pos):
    """
    Shows a context menu with a "Delete" entry for the tree-view row under the cursor.
    :param pos: position in tree-view viewport coordinates (presumably the QPoint
        emitted by customContextMenuRequested)
    Clicks that do not hit a row are ignored.
    """
    clicked = self.tree_view.indexAt(pos)
    if not clicked.isValid():
        return
    target_id = int(self.sphere_items.itemFromIndex(clicked).text())
    context_menu = QMenu()
    delete_entry = context_menu.addAction("Delete")
    chosen = context_menu.exec_(self.tree_view.viewport().mapToGlobal(pos))
    if chosen == delete_entry:
        self.delete_sphere_from_tb(target_id)
@validate_number
def _on_center_x_input_returnPressed(self, widget):
    """
    Return-key handler for the sphere-centre X field; re-applies the sphere
    parameters via update_sphere().
    :param widget: QLineEdit
        The edited input widget (not read in this body).
    """
    self.update_sphere()
@validate_number
def _on_center_y_input_returnPressed(self, widget):
    """
    Return-key handler for the sphere-centre Y field; re-applies the sphere
    parameters via update_sphere().
    :param widget: QLineEdit
        The edited input widget (not read in this body).
    """
    self.update_sphere()
@validate_number
def _on_center_z_input_returnPressed(self, widget):
    """
    Return-key handler for the sphere-centre Z field; re-applies the sphere
    parameters via update_sphere().
    :param widget: QLineEdit
        The edited input widget (not read in this body).
    """
    self.update_sphere()
@validate_number
def _on_radius_input_returnPressed(self, widget):
    """
    Return-key handler for the sphere-radius field; re-applies the sphere
    parameters via update_sphere().
    :param widget: QLineEdit
        The edited input widget (not read in this body).
    """
    self.update_sphere()
@validate_number
def _on_n_points_input_returnPressed(self, widget):
    """
    Return-key handler for the seed-count field; re-applies the sphere
    parameters via update_sphere().
    :param widget: QLineEdit
        The edited input widget (not read in this body).
    """
    self.update_sphere()
@validate_number
def _on_slice_z_input_returnPressed(self, widget):
    """
    Return-key handler for the slice-coordinate field.
    Stores the entered value as the position for the active slice axis, then redraws.
    :param widget: QDoubleSpinBox or QLineEdit
        The edited input; value() is read for a spin box, text() otherwise.
    """
    if isinstance(widget, QDoubleSpinBox):
        entered = float(widget.value())
    else:
        entered = float(widget.text())
    self.slice_axis_positions[self.slice_axis] = entered
    self.update_plot()
def _slice_axis_bounds(self, axis=None):
    """
    Returns the (min, max) grid extent along an axis, as floats.
    :param axis: str, optional
        'x' or 'y' (case-insensitive); defaults to the current slice axis.
        Any other value is treated as 'z'.
    """
    name = (axis or self.slice_axis).lower()
    low_attr, high_attr = {
        'x': ('grid_xmin', 'grid_xmax'),
        'y': ('grid_ymin', 'grid_ymax'),
    }.get(name, ('grid_zmin', 'grid_zmax'))
    return float(getattr(self, low_attr)), float(getattr(self, high_attr))
def _slice_normal_vector(self, axis=None):
    """
    Returns the unit normal of the slice plane as an (x, y, z) tuple of floats.
    :param axis: str, optional
        'x' or 'y' (case-insensitive); defaults to the current slice axis.
        Any other value yields the Z normal.
    """
    name = (axis or self.slice_axis).lower()
    component = {'x': 0, 'y': 1}.get(name, 2)
    return tuple(1.0 if i == component else 0.0 for i in range(3))
def _slice_origin(self, coord_value=None, axis=None):
    """
    Returns a point on the slice plane: the grid centre, with its coordinate along
    ``axis`` replaced by the slice position.
    :param coord_value: float, optional
        Slice position along ``axis`` [Mm]. Defaults to the stored position for
        that axis, or 0.0 if none is stored.
    :param axis: str, optional
        'x', 'y' or 'z' (case-insensitive); defaults to the current slice axis.
        Any other value is treated as 'z', as in _slice_axis_bounds and
        _slice_normal_vector.
    :return: tuple of 3 floats
    """
    axis = (axis or self.slice_axis).lower()
    if coord_value is None:
        coord_value = self.slice_axis_positions.get(axis, 0.0)
    origin = [
        0.5 * (self.grid_xmin + self.grid_xmax),
        0.5 * (self.grid_ymin + self.grid_ymax),
        0.5 * (self.grid_zmin + self.grid_zmax),
    ]
    # Unknown axes fall back to z instead of raising KeyError.
    origin[{'x': 0, 'y': 1}.get(axis, 2)] = float(coord_value)
    return tuple(origin)
def _on_slice_axis_changed(self, axis_text):
    """
    Switches the slice to a new normal axis.
    The current input value is remembered for the old axis, and the slider range
    is re-derived for the new one. The coordinate label and tooltips are then
    refreshed, and the plane widget and the plot are updated.
    :param axis_text: str
        Axis name as shown in the selector ('X', 'Y', 'Z'); empty or None means 'Z'.
    """
    target = "z" if not axis_text else axis_text.lower()
    spin = self.slice_z_input
    if spin is not None:
        self.slice_axis_positions[self.slice_axis] = float(spin.value())
    self.slice_axis = target
    self._set_slice_slider_range()
    upper = target.upper()
    label = self.slice_coord_label
    if label is not None:
        label.setText(f"{upper} [Mm]:")
        lo, hi = self._slice_axis_bounds(target)
        label.setToolTip(
            f"Enter the {upper} coordinate for the slice in the range of {lo:.2f} to {hi:.2f} Mm."
        )
    if spin is not None:
        spin.setToolTip(
            f"Use arrows or mouse wheel. Range: {self.slice_coord_min:.2f} to {self.slice_coord_max:.2f} Mm."
        )
    self.update_plane()
    self.update_plot()
@validate_number
def _on_vmin_input_returnPressed(self, widget):
    """
    Return-key handler for the slice colour-scale minimum; redraws via update_plot().
    :param widget: QLineEdit
        The edited input widget (not read in this body).
    """
    self.update_plot()
@validate_number
def _on_vmax_input_returnPressed(self, widget):
    """
    Return-key handler for the slice colour-scale maximum; redraws via update_plot().
    :param widget: QLineEdit
        The edited input widget (not read in this body).
    """
    self.update_plot()
@validate_number
def _on_base_vmin_input_returnPressed(self, widget):
    """
    Return-key handler for the base-map colour-scale minimum; redraws the base
    map from the current control values.
    :param widget: the edited input widget (not read in this body)
    """
    self._update_base_map_from_controls()
@validate_number
def _on_base_vmax_input_returnPressed(self, widget):
    """
    Return-key handler for the base-map colour-scale maximum; redraws the base
    map from the current control values.
    :param widget: the edited input widget (not read in this body)
    """
    self._update_base_map_from_controls()
def _on_base_map_changed(self, _map_name):
    """
    Handles a change of the selected base map.
    Resets the base-map colour limits to the full range of the newly selected map
    and redraws it.
    :param _map_name: str
        Name emitted by the selector. Not used; the selector's current text is
        read instead.
    """
    selected = self.base_map_selector.currentText()
    self._set_base_scalar_range(selected, reset_values=True)
    self._update_base_map_from_controls()
def _update_base_map_from_controls(self):
    """
    Redraws the base map using the map and colour limits currently set in the controls.
    Does nothing without a base-map selector. When the selected name is not an
    array of ``grid_bottom``, the base-map actor is removed via
    update_base_map(None, ...).
    """
    selector = self.base_map_selector
    if selector is None:
        return
    map_name = selector.currentText()
    if map_name not in self.grid_bottom.array_names:
        self.update_base_map(None, 0.0, 1.0, False)
        return
    low_widget, high_widget = self.base_vmin_input, self.base_vmax_input
    # Each limit is validated against the data range, paired with the other limit.
    lower = self.validate_input(
        low_widget, self.base_scalar_min, self.base_scalar_max,
        self.previous_valid_values.get(low_widget, self.base_scalar_min),
        paired_widget=high_widget, paired_type='vmin',
    )
    upper = self.validate_input(
        high_widget, self.base_scalar_min, self.base_scalar_max,
        self.previous_valid_values.get(high_widget, self.base_scalar_max),
        paired_widget=low_widget, paired_type='vmax',
    )
    self.update_base_map(map_name, lower, upper, self.base_map_visible)
def _set_slice_slider_range(self):
    """
    Syncs the slice-coordinate limits and the input widget with the active slice axis.
    The stored position for the axis is clamped into the axis bounds. The bounds
    are also copied to slice_z_min/slice_z_max. If the input widget exists, it
    receives the range, a step of 1/200 of the range (at least 0.1) and the
    clamped value, with its signals blocked.
    """
    lo, hi = self._slice_axis_bounds()
    self.slice_coord_min, self.slice_coord_max = lo, hi
    position = self.slice_axis_positions.get(self.slice_axis, lo)
    position = min(max(position, lo), hi)
    self.slice_axis_positions[self.slice_axis] = position
    self.slice_z_min, self.slice_z_max = lo, hi
    spin = self.slice_z_input
    if spin is None:
        return
    spin.blockSignals(True)
    spin.setRange(lo, hi)
    spin.setSingleStep(max((hi - lo) / 200.0, 0.1))
    spin.setValue(position)
    spin.blockSignals(False)
def _set_scalar_range(self, scalar_name):
    """
    Updates the scalar data range and the slice colour-limit inputs for a field.
    :param scalar_name: str
        Name of a volume array of ``self.grid``, or of the bottom map
        (``self.bottom_name``). Other names leave everything unchanged.
    Only finite samples define the range, so NaN/inf voxels cannot produce
    NaN/inf limits. When there is no finite sample, the previous range is kept.
    A constant field is widened by 1 on each side. The vmin/vmax inputs get the
    new range and a step of 1/200 of it (at least 1.0), with their signals blocked.
    """
    data = None
    if scalar_name in self.grid.array_names:
        data = self.grid[scalar_name]
    elif self.bottom_name is not None and scalar_name == self.bottom_name:
        data = self.grid_bottom[scalar_name]
    if data is None:
        return
    data_min = float(np.nanmin(data))
    data_max = float(np.nanmax(data))
    if not (np.isfinite(data_min) and np.isfinite(data_max)):
        # All-NaN or +/-inf samples: recompute from the finite values only.
        # The extra pass (and copy) is paid only in this uncommon case.
        arr = np.asarray(data)
        finite = arr[np.isfinite(arr)]
        if finite.size == 0:
            return  # nothing usable; keep the previous range
        data_min, data_max = float(finite.min()), float(finite.max())
    self.scalar_min, self.scalar_max = data_min, data_max
    if self.scalar_min == self.scalar_max:
        # Widen a constant field so the inputs get a non-empty range.
        self.scalar_min -= 1.0
        self.scalar_max += 1.0
    if self.vmin_input is not None and self.vmax_input is not None:
        self.vmin_input.blockSignals(True)
        self.vmax_input.blockSignals(True)
        self.vmin_input.setRange(self.scalar_min, self.scalar_max)
        self.vmax_input.setRange(self.scalar_min, self.scalar_max)
        step = max((self.scalar_max - self.scalar_min) / 200.0, 1.0)
        self.vmin_input.setSingleStep(step)
        self.vmax_input.setSingleStep(step)
        self.vmin_input.blockSignals(False)
        self.vmax_input.blockSignals(False)
def _set_base_scalar_range(self, base_map_name, reset_values=False):
    """
    Updates the base-map colour-limit inputs for the selected bottom map.
    :param base_map_name: str or None
        Name of a ``grid_bottom`` array. None, an unknown name, or a map without
        any finite value disables both inputs.
    :param reset_values: bool
        Also set the inputs, and their recorded last valid values, to the full
        data range.
    Only finite samples define the range, so NaN/inf pixels cannot produce
    NaN/inf limits. A constant map is widened by 1 on each side. The step is
    1/200 of the range (at least 1e-3). Signals are blocked while the inputs
    change.
    """
    if self.base_vmin_input is None or self.base_vmax_input is None:
        return
    if base_map_name is None or base_map_name not in self.grid_bottom.array_names:
        self.base_vmin_input.setEnabled(False)
        self.base_vmax_input.setEnabled(False)
        return
    data = self.grid_bottom[base_map_name]
    data_min = float(np.nanmin(data))
    data_max = float(np.nanmax(data))
    if not (np.isfinite(data_min) and np.isfinite(data_max)):
        # All-NaN or +/-inf pixels: use the finite values only (rare, extra pass).
        arr = np.asarray(data)
        finite = arr[np.isfinite(arr)]
        if finite.size == 0:
            self.base_vmin_input.setEnabled(False)
            self.base_vmax_input.setEnabled(False)
            return
        data_min, data_max = float(finite.min()), float(finite.max())
    self.base_scalar_min, self.base_scalar_max = data_min, data_max
    if self.base_scalar_min == self.base_scalar_max:
        self.base_scalar_min -= 1.0
        self.base_scalar_max += 1.0
    self.base_vmin_input.blockSignals(True)
    self.base_vmax_input.blockSignals(True)
    self.base_vmin_input.setEnabled(True)
    self.base_vmax_input.setEnabled(True)
    self.base_vmin_input.setRange(self.base_scalar_min, self.base_scalar_max)
    self.base_vmax_input.setRange(self.base_scalar_min, self.base_scalar_max)
    step = max((self.base_scalar_max - self.base_scalar_min) / 200.0, 1.0e-3)
    self.base_vmin_input.setSingleStep(step)
    self.base_vmax_input.setSingleStep(step)
    if reset_values:
        self.base_vmin_input.setValue(self.base_scalar_min)
        self.base_vmax_input.setValue(self.base_scalar_max)
        self.previous_valid_values[self.base_vmin_input] = self.base_scalar_min
        self.previous_valid_values[self.base_vmax_input] = self.base_scalar_max
    self.base_vmin_input.blockSignals(False)
    self.base_vmax_input.blockSignals(False)
[docs]
def init_grid(self):
    """
    Builds the PyVista grids used by the 3-D view.
    ``self.grid`` is a ``pv.ImageData`` volume on the axes ``self.grid_x``,
    ``self.grid_y`` and ``self.grid_z``. It holds the field components 'bx', 'by'
    and 'bz' of ``self.box.b3d[self.b3dtype]``, plus their stacked 'vectors'
    array (the array the streamline code traces).
    ``self.grid_bottom`` is a one-layer ``pv.ImageData`` at ``z.min()`` holding
    2-D bottom maps:
    - the entries of ``b3d['base']`` ('bx', 'by', 'bz', 'ic', 'chromo_mask')
      whose shape matches;
    - the ``b3d['refmaps']`` entries whose shape matches;
    - the parent's selected bottom map, when the parent exposes one.
    Their names are collected in ``self.base_map_items``. Finally the slice
    slider range and the colour range of ``self.scalar`` are initialised.
    """
    x = self.grid_x
    y = self.grid_y
    z = self.grid_z
    self.bottom_name = None
    self.base_map_items = []
    bx = self.box.b3d[self.b3dtype]['bx']
    by = self.box.b3d[self.b3dtype]['by']
    bz = self.box.b3d[self.b3dtype]['bz']
    self.grid = pv.ImageData()
    self.grid.dimensions = (len(x), len(y), len(z))
    # Uniform spacing is taken from the first two samples of each axis
    # (assumes a regular grid with at least two samples per axis).
    self.grid.spacing = (x[1] - x[0], y[1] - y[0], z[1] - z[0])
    self.grid.origin = (x.min(), y.min(), z.min())
    self.grid_dims = (len(x), len(y), len(z))
    self.grid_spacing = self.grid.spacing
    # Fortran-order flattening makes the first index vary fastest, which matches
    # the x-fastest point order of ImageData. This assumes the cubes are indexed
    # (x, y, z), consistent with the dimensions above — TODO confirm.
    self.grid['bx'] = bx.ravel(order='F')
    self.grid['by'] = by.ravel(order='F')
    self.grid['bz'] = bz.ravel(order='F')
    self.grid['vectors'] = np.c_[self.grid['bx'] , self.grid['by'], self.grid['bz']]
    self.scalar_selector_items = ['bx', 'by', 'bz']
    # Single-layer grid for 2-D maps, placed at the bottom of the volume.
    self.grid_bottom = pv.ImageData()
    self.grid_bottom.dimensions = (len(x), len(y), 1)
    self.grid_bottom.spacing = (x[1] - x[0], y[1] - y[0], 0)
    self.grid_bottom.origin = (x.min(), y.min(), z.min())
    base_group = self.box.b3d.get("base", {}) if isinstance(self.box.b3d, dict) else {}
    if isinstance(base_group, dict):
        for key in ("bx", "by", "bz", "ic", "chromo_mask"):
            if key not in base_group:
                continue
            arr = np.asarray(base_group[key])
            if arr.ndim != 2:
                continue
            # Base maps are stored as (y, x); grid_bottom expects flattened (x, y).
            if arr.shape != (len(y), len(x)):
                continue
            self.grid_bottom[key] = arr.T.ravel(order='F')
            self.base_map_items.append(key)
    # Include compatible refmaps (e.g., Vert_current) when they match bottom dimensions.
    refmaps_group = self.box.b3d.get("refmaps", {}) if isinstance(self.box.b3d, dict) else {}
    if isinstance(refmaps_group, dict):
        for ref_name, ref_obj in refmaps_group.items():
            if not isinstance(ref_obj, dict) or "data" not in ref_obj:
                continue
            arr = np.asarray(ref_obj["data"])
            if arr.ndim != 2:
                continue
            if arr.shape != (len(y), len(x)):
                continue
            key = str(ref_name)
            # Never overwrite a base map that is already on the bottom grid.
            if key in self.grid_bottom.array_names:
                continue
            self.grid_bottom[key] = arr.T.ravel(order='F')
            if key not in self.base_map_items:
                self.base_map_items.append(key)
    # Optional bottom map supplied by the parent window. Unlike the maps above,
    # its shape is not checked here.
    # NOTE(review): ``self.parent`` is used as an attribute; if it is the QWidget
    # method parent(), the hasattr checks fail and this is skipped — confirm.
    if self.parent is not None and hasattr(self.parent, "mapBottomSelector") and hasattr(self.parent, "map_bottom"):
        self.bottom_name = self.parent.mapBottomSelector.currentText()
        self.grid_bottom[self.bottom_name] = self.parent.map_bottom.data.T.ravel(order='F')
        if self.bottom_name not in self.base_map_items:
            self.base_map_items.append(self.bottom_name)
    self._set_slice_slider_range()
    self._set_scalar_range(self.scalar)
[docs]
def init_plot(self):
    """
    Initializes and displays the plot with the magnetic field data.
    First the data-driven ranges are set up: the slice slider, the slice colour
    limits, and the base-map colour limits (with values reset). Then the current
    value of every numeric input is recorded in ``self.previous_valid_values``;
    these are the last-valid values that update_plot passes to validate_input.
    Missing base-map inputs are recorded with the placeholders -1000.0 and 1000.0.
    Finally the scene is drawn with ``update_plot(init=True)``.
    """
    self._set_slice_slider_range()
    self._set_scalar_range(self.scalar)
    if self.base_map_selector is not None:
        initial_base = self.base_map_selector.currentText()
    else:
        initial_base = "none"
    self._set_base_scalar_range(initial_base, reset_values=True)

    def read_float(field):
        # Spin boxes expose value(); line edits expose text().
        if isinstance(field, QDoubleSpinBox):
            return float(field.value())
        return float(field.text())

    recorded = {}
    for field in (self.center_x_input, self.center_y_input, self.center_z_input,
                  self.radius_input, self.slice_z_input):
        recorded[field] = read_float(field)
    recorded[self.n_points_input] = int(self.n_points_input.text())
    recorded[self.vmin_input] = read_float(self.vmin_input)
    recorded[self.vmax_input] = read_float(self.vmax_input)
    low_field, high_field = self.base_vmin_input, self.base_vmax_input
    recorded[low_field] = read_float(low_field) if low_field is not None else -1000.0
    recorded[high_field] = read_float(high_field) if high_field is not None else 1000.0
    self.previous_valid_values = recorded
    self.update_plot(init=True)
[docs]
def update_plot(self, init=False):
    """
    Updates the plot based on the current input parameters.
    Every input is checked with ``validate_input`` (given the widget's last valid
    value) and the results are gathered into a parameter dict. Only the scene
    objects whose parameters differ from the previous update are rebuilt. Then
    the camera clipping range is reset and the scene rendered. If no parameter
    changed, nothing is redrawn.
    :param init: bool
        First draw (from init_plot). The sphere/plane widgets are not refreshed,
        streamlines are not recomputed, and the model/FOV boxes are always
        (re)built.
    Re-entrant calls return immediately; update_sphere and update_plane call back
    into this method. The re-entrancy flag is cleared in a ``finally`` block, so
    an exception raised during an update can no longer leave every later update
    silently ignored.
    """
    if self.updating_flag:  # already inside an update
        return
    self.updating_flag = True
    try:
        # Seed-sphere parameters, checked against the grid extent.
        center_x = self.validate_input(self.center_x_input, self.grid_xmin, self.grid_xmax,
                                       self.previous_valid_values[self.center_x_input])
        center_y = self.validate_input(self.center_y_input, self.grid_ymin, self.grid_ymax,
                                       self.previous_valid_values[self.center_y_input])
        # NOTE(review): the lower bound is 0 rather than self.grid_zmin (unlike x/y);
        # confirm this is intended.
        center_z = self.validate_input(self.center_z_input, 0, self.grid_zmax,
                                       self.previous_valid_values[self.center_z_input])
        radius = self.validate_input(self.radius_input, 0,
                                     min(np.ptp(self.grid_x), np.ptp(self.grid_y), np.ptp(self.grid_z)),
                                     self.previous_valid_values[self.radius_input])
        n_points = self.validate_input(self.n_points_input, 1, 1000,
                                       self.previous_valid_values[self.n_points_input], to_int=True)
        if not init:
            self.update_sphere()
            self.update_plane()
        # Slice and base-map parameters; the data ranges are refreshed first.
        scalar = self.scalar_selector.currentText()
        self._set_scalar_range(scalar)
        base_map = self.base_map_selector.currentText() if self.base_map_selector is not None else None
        self._set_base_scalar_range(base_map, reset_values=False)
        slice_z = self.validate_input(self.slice_z_input, self.slice_coord_min, self.slice_coord_max,
                                      self.previous_valid_values[self.slice_z_input])
        self.slice_axis_positions[self.slice_axis] = slice_z
        vmin = self.validate_input(self.vmin_input, -5e4, 5e4, self.previous_valid_values[self.vmin_input],
                                   paired_widget=self.vmax_input, paired_type='vmin')
        vmax = self.validate_input(self.vmax_input, -5e4, 5e4, self.previous_valid_values[self.vmax_input],
                                   paired_widget=self.vmin_input, paired_type='vmax')
        if base_map in self.grid_bottom.array_names:
            bmin = self.validate_input(self.base_vmin_input, self.base_scalar_min, self.base_scalar_max,
                                       self.previous_valid_values[self.base_vmin_input],
                                       paired_widget=self.base_vmax_input, paired_type='vmin')
            bmax = self.validate_input(self.base_vmax_input, self.base_scalar_min, self.base_scalar_max,
                                       self.previous_valid_values[self.base_vmax_input],
                                       paired_widget=self.base_vmin_input, paired_type='vmax')
        else:
            # No usable base map: reuse the slice limits.
            bmin = vmin
            bmax = vmax
        sphere_visible = self.viz_sphere_button.isChecked()
        plane_visible = self.plane_visible
        use_interp = self.interp_checkbox.isChecked() if self.interp_checkbox is not None else True
        slice_visible = self.slice_visible
        base_map_visible = self.base_map_visible
        model_box_visible = self.model_box_visible
        fov_box_visible = self.fov_box_visible
        current_params = {
            "center_x": center_x,
            "center_y": center_y,
            "center_z": center_z,
            "radius": radius,
            "slice_z": slice_z,
            "slice_axis": self.slice_axis,
            "n_points": n_points,
            "vmin": vmin,
            "vmax": vmax,
            "scalar": scalar,
            "base_map": base_map,
            "base_vmin": bmin,
            "base_vmax": bmax,
            "base_map_visible": base_map_visible,
            "slice_visible": slice_visible,
            "use_interp": use_interp,
            "sphere_visible": sphere_visible,
            "plane_visible": plane_visible,
            "model_box_visible": model_box_visible,
            "fov_box_visible": fov_box_visible,
        }
        if current_params == self.previous_params:
            return  # nothing changed; the finally block clears the flag

        def changed(*keys):
            # True when any of ``keys`` differs from the previous update.
            return any(current_params[key] != self.previous_params.get(key) for key in keys)

        # Rebuild only the scene objects whose inputs changed.
        if changed('slice_z', 'slice_axis', 'scalar', 'vmin', 'vmax', 'slice_visible', 'use_interp'):
            self.update_slice(current_params['slice_axis'], current_params['slice_z'], current_params['scalar'],
                              current_params['vmin'], current_params['vmax'], current_params['use_interp'],
                              current_params['slice_visible'])
        if changed('base_map', 'base_map_visible', 'base_vmin', 'base_vmax'):
            self.update_base_map(
                current_params['base_map'],
                current_params['base_vmin'],
                current_params['base_vmax'],
                current_params['base_map_visible'],
            )
        if changed('plane_visible'):
            self.update_plane_visibility(current_params['plane_visible'])
        if changed('model_box_visible') or init:
            self.update_model_box(current_params['model_box_visible'], do_render=False)
        if changed('fov_box_visible') or init:
            self.update_fov_box(current_params['fov_box_visible'], do_render=False)
        if not init and self.current_sphere_id in self.spheres:
            if changed('center_x', 'center_y', 'center_z', 'radius', 'n_points'):
                self.update_streamlines(current_params['center_x'], current_params['center_y'],
                                        current_params['center_z'],
                                        current_params['radius'], current_params['n_points'])
        if changed('sphere_visible'):
            self.update_sphere_visibility(current_params['sphere_visible'])
        self.previous_params = current_params
    finally:
        self.updating_flag = False
    self.reset_camera_clipping_range()
    self.render()
[docs]
def update_slice(self, slice_axis, slice_z, scalar, vmin, vmax, use_interp=True, slice_visible=True):
    """
    Updates the slice plot based on the given parameters.
    :param slice_axis: str
        The axis normal to the slice plane ('x', 'y' or 'z', case-insensitive).
    :param slice_z: float
        The slice coordinate along the selected axis.
    :param scalar: str
        The scalar field to use for the slice.
    :param vmin: float
        The minimum value for the color scale.
    :param vmax: float
        The maximum value for the color scale.
    :param use_interp: bool
        True: resample the volume on the plane (``grid.slice``, point data).
        False: show the raw voxel layer nearest to ``slice_z`` as cell data. If
        the cube layout is not understood, the interpolated slice is used instead.
    :param slice_visible: bool
        False removes the slice actor and returns.
    """
    if not slice_visible:
        if self.bottom_slice_actor is not None:
            self.remove_actor(self.bottom_slice_actor)
            self.bottom_slice_actor = None
        return

    def show(mesh, scalars, preference):
        # (Re)create the slice actor. The first creation uses add_mesh's default
        # camera handling; replacements keep the camera where it is.
        if self.bottom_slice_actor is None:
            self.bottom_slice_actor = self.add_mesh(mesh, scalars=scalars, clim=(vmin, vmax), show_edges=False,
                                                    cmap='gray', pickable=False, show_scalar_bar=False,
                                                    preference=preference)
        else:
            self.remove_actor(self.bottom_slice_actor)
            self.bottom_slice_actor = self.add_mesh(mesh, scalars=scalars, clim=(vmin, vmax), show_edges=False,
                                                    cmap='gray', pickable=False, reset_camera=False,
                                                    show_scalar_bar=False, preference=preference)

    axis = slice_axis.lower()
    if slice_z == 0:
        # Nudge an exact 0 off the grid boundary (presumably so the slice is not empty).
        slice_z = 1.0e-6
    slice_origin = self._slice_origin(slice_z, axis)

    def show_interpolated():
        # Resample the volume on the slice plane (point data of ``scalar``).
        show(self.grid.slice(normal=axis, origin=slice_origin), scalar, 'point')

    if use_interp:
        show_interpolated()
        return

    # Raw mode: pick the voxel layer nearest to slice_z.
    axis_idx = {'x': 0, 'y': 1, 'z': 2}[axis]
    spacing_axis = self.grid_spacing[axis_idx]
    idx = int(round((slice_z - self.grid.origin[axis_idx]) / spacing_axis))
    idx = max(0, min(idx, self.grid_dims[axis_idx] - 1))
    nx, ny, nz = self.grid_dims
    # Scalars other than the field components fall back to the bz cube here.
    component_key = scalar if scalar in ('bx', 'by', 'bz') else 'bz'
    cube = np.asarray(self.box.b3d[self.b3dtype][component_key])
    if cube.ndim == 4 and cube.shape[-1] == 3 and scalar in ('bx', 'by', 'bz'):
        cube = cube[..., {'bx': 0, 'by': 1, 'bz': 2}[scalar]]
    if cube.ndim != 3 and cube.size == nx * ny * nz:
        cube = cube.reshape((nx, ny, nz), order='F')
    if cube.ndim == 3:
        if axis == 'x':
            slice_data = cube[idx, :, :]
        elif axis == 'y':
            slice_data = cube[:, idx, :]
        else:
            slice_data = cube[:, :, idx]
    elif cube.ndim == 2 and cube.size == nx * ny:
        slice_data = cube
    else:
        # Unexpected cube layout: fall back to the interpolated slice.
        show_interpolated()
        return
    plane_shape = {'x': (ny, nz), 'y': (nx, nz), 'z': (nx, ny)}[axis]
    expected_size = plane_shape[0] * plane_shape[1]
    if slice_data.ndim != 2 or slice_data.size != expected_size:
        if slice_data.size != expected_size:
            # Reshaping is impossible: fall back to the interpolated slice.
            show_interpolated()
            return
        slice_data = slice_data.reshape(plane_shape, order='F')
    flat_slice = slice_data.ravel(order='F')
    # One cell per voxel: the image has n+1 points along each in-plane axis.
    spacing_x = (self.grid_xmax - self.grid_xmin) / float(nx)
    spacing_y = (self.grid_ymax - self.grid_ymin) / float(ny)
    spacing_z = (self.grid_zmax - self.grid_zmin) / float(max(nz, 1))
    scalar_name = "slice_scalar"
    if axis == 'x':
        new_slice = pv.ImageData(
            dimensions=(1, ny + 1, nz + 1),
            spacing=(1, spacing_y, spacing_z),
            origin=(slice_z, self.grid_ymin, self.grid_zmin),
        )
    elif axis == 'y':
        new_slice = pv.ImageData(
            dimensions=(nx + 1, 1, nz + 1),
            spacing=(spacing_x, 1, spacing_z),
            origin=(self.grid_xmin, slice_z, self.grid_zmin),
        )
    else:
        new_slice = pv.ImageData(
            dimensions=(nx + 1, ny + 1, 1),
            spacing=(spacing_x, spacing_y, 1),
            origin=(self.grid_xmin, self.grid_ymin, slice_z),
        )
    new_slice.cell_data[scalar_name] = flat_slice
    new_slice.set_active_scalars(scalar_name, preference='cell')
    show(new_slice, scalar_name, 'cell')
[docs]
def update_base_map(self, base_map, vmin, vmax, base_map_visible=True):
    """
    Render a fixed bottom-plane base map independently of the moving z-slice.
    :param base_map: str or None
        Name of a ``grid_bottom`` array. None or an unknown name removes the map.
    :param vmin: float
        Lower colour limit.
    :param vmax: float
        Upper colour limit.
    :param base_map_visible: bool
        False removes the base-map actor.
    """
    hidden = (not base_map_visible) or base_map is None or base_map not in self.grid_bottom.array_names
    if hidden:
        if self.base_map_actor is not None:
            self.remove_actor(self.base_map_actor)
            self.base_map_actor = None
        return
    mesh_options = dict(
        scalars=base_map,
        clim=(vmin, vmax),
        show_edges=False,
        cmap='gray',
        pickable=False,
        show_scalar_bar=False,
    )
    if self.base_map_actor is not None:
        # Replacing an existing actor: keep the current camera.
        self.remove_actor(self.base_map_actor)
        mesh_options['reset_camera'] = False
    self.base_map_actor = self.add_mesh(self.grid_bottom, **mesh_options)
@staticmethod
def _wireframe_box_from_points(points: np.ndarray):
    """
    Builds a 12-edge wireframe PolyData from 8 box corners.
    :param points: array-like
        Corner coordinates, reshaped to (-1, 3). The corners are assumed to
        follow the ordering implied by the edge list below: 0-1-3-2 and 4-5-7-6
        are opposite faces, and corner i is joined to corner i+4.
    :return: pv.PolyData or None
        None unless exactly 8 corners are given.
    """
    corners = np.asarray(points, dtype=float).reshape((-1, 3))
    if corners.shape != (8, 3):
        return None
    edge_pairs = np.array([
        [0, 1], [1, 3], [3, 2], [2, 0],   # first face
        [4, 5], [5, 7], [7, 6], [6, 4],   # opposite face
        [0, 4], [1, 5], [2, 6], [3, 7],   # connecting edges
    ], dtype=np.int32)
    # Each line cell is encoded as [n_points=2, id0, id1].
    counts = np.full((len(edge_pairs), 1), 2, dtype=np.int32)
    wire = pv.PolyData()
    wire.points = corners
    wire.lines = np.hstack([counts, edge_pairs]).ravel()
    return wire
def _model_box_mesh(self):
    """
    Returns the wireframe mesh of the model box, with corners taken from
    ``self.box.model_box_corners_local_mm()``. Returns None when those corners
    are not exactly 8 points.
    """
    return self._wireframe_box_from_points(self.box.model_box_corners_local_mm())
def _fov_box_corners_local(self):
    """
    Computes the corners of the observer field-of-view box in local box coordinates [Mm].
    The box is described by ``b3d['observer']['fov_box']``: centre and size in
    arcsec, zmin/zmax in Mm. The observer is taken from ``self.box._frame_obs``,
    unless the metadata names an ``observer_key`` that resolve_observer_with_info
    resolves.
    :return: numpy.ndarray of float or None
        The corner array (as returned by world_to_local_cartesian_mm). None is
        returned, with a console message, when the metadata, observer or box
        frame is missing, or when the conversion fails. Swallowed exceptions are
        logged at debug level with their traceback.
    """
    logger = logging.getLogger(__name__)
    b3d_obj = getattr(self.box, "b3d", None)
    b3d = b3d_obj if isinstance(b3d_obj, dict) else {}
    observer_meta = b3d.get("observer", {})
    fov_box = observer_meta.get("fov_box") if isinstance(observer_meta, dict) else None
    if not isinstance(fov_box, dict):
        print("FOV box overlay: missing observer['fov_box'] metadata.")
        return None
    box_frame = getattr(getattr(self.box, "_center", None), "frame", None)
    frame_obs = getattr(self.box, "_frame_obs", None)
    obstime = getattr(frame_obs, "obstime", None)
    observer = getattr(frame_obs, "observer", None)
    observer_key = fov_box.get("observer_key")
    if observer_key:
        try:
            resolved, _warning, _used = resolve_observer_with_info(b3d, observer_key, obstime)
        except Exception:
            # Best effort: keep the frame observer, but leave a trace of why.
            logger.debug("FOV box overlay: could not resolve observer %r; using the frame observer.",
                         observer_key, exc_info=True)
        else:
            if resolved is not None:
                observer = resolved
    if observer is None or box_frame is None:
        print("FOV box overlay: incomplete or invalid observer['fov_box'] metadata.")
        return None
    try:
        corners_world = observer_fov_box_to_world_corners(
            xc_arcsec=float(fov_box["xc_arcsec"]),
            yc_arcsec=float(fov_box["yc_arcsec"]),
            xsize_arcsec=float(fov_box["xsize_arcsec"]),
            ysize_arcsec=float(fov_box["ysize_arcsec"]),
            zmin_mm=float(fov_box["zmin_mm"]),
            zmax_mm=float(fov_box["zmax_mm"]),
            observer=observer,
            obstime=obstime,
            target_frame=box_frame,
        )
        corners = world_to_local_cartesian_mm(corners_world, z_base_mm=float(self.grid_zbase))
    except Exception:
        logger.debug("FOV box overlay: corner conversion failed.", exc_info=True)
        corners = None
    if corners is None:
        print("FOV box overlay: incomplete or invalid observer['fov_box'] metadata.")
        return None
    return np.asarray(corners, dtype=float)
def _fov_box_mesh(self):
    """
    Returns the wireframe mesh of the observer FOV box, or None when its corners
    cannot be computed.
    """
    fov_corners = self._fov_box_corners_local()
    return None if fov_corners is None else self._wireframe_box_from_points(fov_corners)
[docs]
def update_model_box(self, visible=True, do_render=True):
    """
    Shows or hides the red outline of the model box.
    The tube actor is created on first use. When the box mesh cannot be built,
    any existing actor is removed instead.
    :param visible: bool
        Visibility of the outline.
    :param do_render: bool
        Re-render the scene afterwards.
    """
    box_mesh = self._model_box_mesh()
    if box_mesh is None:
        if self.model_box_actor is not None:
            self.remove_actor(self.model_box_actor)
            self.model_box_actor = None
    else:
        if self.model_box_actor is None:
            self.model_box_actor = self.add_mesh(
                box_mesh.tube(radius=0.35),
                color="red",
                pickable=False,
                reset_camera=False,
                lighting=False,
            )
        self.model_box_actor.SetVisibility(bool(visible))
    if do_render:
        self.render()
[docs]
def update_fov_box(self, visible=True, do_render=True):
    """
    Shows or hides the sky-blue outline of the observer field-of-view box.
    The tube actor is created on first use. When the FOV mesh cannot be built,
    any existing actor is removed instead.
    :param visible: bool
        Visibility of the outline.
    :param do_render: bool
        Re-render the scene afterwards.
    """
    fov_mesh = self._fov_box_mesh()
    if fov_mesh is None:
        if self.fov_box_actor is not None:
            self.remove_actor(self.fov_box_actor)
            self.fov_box_actor = None
    else:
        if self.fov_box_actor is None:
            self.fov_box_actor = self.add_mesh(
                fov_mesh.tube(radius=0.35),
                color="deepskyblue",
                pickable=False,
                reset_camera=False,
                scalars=None,
                lighting=False,
            )
        self.fov_box_actor.SetVisibility(bool(visible))
    if do_render:
        self.render()
[docs]
def create_streamlines(self, center_x, center_y, center_z, radius, n_points):
    """
    Traces field lines through ``self.grid['vectors']`` from seeds in a sphere and
    shows them as tubes.
    :param center_x: float
        X coordinate of the seed-sphere centre.
    :param center_y: float
        Y coordinate of the seed-sphere centre.
    :param center_z: float
        Z coordinate of the seed-sphere centre.
    :param radius: float
        Radius of the seed sphere.
    :param n_points: int
        Number of seed points.
    The traced lines are stored in ``self.streamlines`` and the tube actor in
    ``self.streamlines_actor`` is replaced. When nothing is traced, the previous
    actor is removed from the scene (not just forgotten) and the attribute is set
    to None.
    """
    self.streamlines = self.grid.streamlines(vectors='vectors', source_center=(center_x, center_y, center_z),
                                             source_radius=radius, n_points=n_points,
                                             integration_direction='both', max_length=5000,
                                             progress_bar=False)
    tube = self.streamlines.tube(radius=0.1) if self.streamlines.n_points > 0 else None
    # The old tube is outdated either way; remove it so it cannot linger in the scene.
    if self.streamlines_actor is not None:
        self.remove_actor(self.streamlines_actor)
    if tube is None or tube.n_points <= 0:
        self.streamlines_actor = None
        print("No streamlines generated.")
        return
    self.streamlines_actor = self.add_mesh(tube, pickable=False,
                                           reset_camera=False, show_scalar_bar=False)
[docs]
def update_streamlines(self, center_x, center_y, center_z, radius, n_points):
    """
    Re-traces the field lines of the current sphere and replaces its tube actor.
    :param center_x: float
        The X coordinate of the center of the sphere.
    :param center_y: float
        The Y coordinate of the center of the sphere.
    :param center_z: float
        The Z coordinate of the center of the sphere.
    :param radius: float
        The radius of the sphere.
    :param n_points: int
        The number of seed points for the streamlines.
    The traced lines are stored under the sphere's 'streamlines' key. If nothing
    is traced, the old actor is removed and 'streamlines_actor' becomes None.
    """
    sphere = self.spheres[self.current_sphere_id]
    old_actor = sphere['streamlines_actor']
    traced = self.grid.streamlines(vectors='vectors', source_center=(center_x, center_y, center_z),
                                   source_radius=radius, n_points=n_points, integration_direction='both',
                                   max_length=5000, progress_bar=False)
    sphere['streamlines'] = traced
    tube = traced.tube(radius=0.1) if traced.n_points > 0 else None
    if tube is None or tube.n_points <= 0:
        if old_actor is not None:
            self.remove_actor(old_actor)
        sphere['streamlines_actor'] = None
        print("No streamlines generated.")
        return
    if old_actor is not None:
        self.remove_actor(old_actor)
    sphere['streamlines_actor'] = self.add_mesh(tube, pickable=False,
                                                reset_camera=False, show_scalar_bar=False)
[docs]
def update_sphere(self):
    """
    Updates the sphere widget based on the current input parameters.
    If the current sphere has a widget, its stored 'center' and 'radius' are
    taken from the centre/radius text fields and applied to the widget. In every
    case the plot is then updated and the line seeds persisted.
    """
    record = self.spheres[self.current_sphere_id] if self.current_sphere_id in self.spheres else {}
    widget = record.get('sphere_actor')
    if widget is not None:
        fields = (self.center_x_input, self.center_y_input, self.center_z_input)
        centre = tuple(float(field.text()) for field in fields)
        size = float(self.radius_input.text())
        record['center'] = centre
        record['radius'] = size
        widget.SetCenter(centre)
        widget.SetRadius(size)
    self.update_plot()
    self._persist_line_seeds()
[docs]
def on_lock_z_changed(self, state):
    """
    Re-creates the current sphere widget with the move callback that matches the
    'Lock Z' state.
    :param state: check state of the 'Lock Z' checkbox
        ``Qt.Checked`` selects the z-constrained callback; anything else selects
        the free-move callback.
    With the lock on, the widget is rebuilt at the centre/radius typed in the
    input fields. With it off, it is rebuilt at the centre/radius stored for the
    sphere. Nothing happens when there is no current sphere, or when its widget
    has not been created yet (update_sphere_visibility picks the callback from
    the checkbox when it creates the widget). Previously that case crashed on
    ``None.Off()``.
    """
    sphere = self.spheres.get(self.current_sphere_id) if self.current_sphere_id in self.spheres else None
    if sphere is None or sphere.get('sphere_actor') is None:
        return
    if state == Qt.Checked:
        callback = self._on_sphere_constrained_move
        center = (float(self.center_x_input.text()),
                  float(self.center_y_input.text()),
                  float(self.center_z_input.text()))
        radius = float(self.radius_input.text())
    else:
        callback = self._on_sphere_moved
        center = sphere['center']
        radius = sphere['radius']
    old_widget = sphere['sphere_actor']
    old_widget.Off()
    old_widget.RemoveAllObservers()
    sphere['sphere_actor'] = self.add_sphere_widget(
        callback,
        center=center,
        radius=radius,
        theta_resolution=18,
        phi_resolution=18,
        style='wireframe'
    )
[docs]
def update_sphere_visibility(self, sphere_visible):
    """
    Updates the visibility of the current sphere widget and syncs the toggle button.
    :param sphere_visible: bool
        Whether the sphere widget is visible.
    Without a current sphere, only the button state is synced. Showing a sphere
    whose widget does not exist yet creates it at the centre/radius from the
    input fields, with the move callback that matches the 'Lock Z' checkbox.
    The button is updated with its signals blocked, so toggle_sphere_visibility
    does not fire. Its existing connections are left intact; a bare
    ``disconnect()`` would have removed all of the button's connections, not just
    this one. For an existing sphere, the line seeds are persisted afterwards.
    """
    def sync_button():
        # Reflect the state on the button without emitting toggled().
        button = self.viz_sphere_button
        if button.isChecked() != sphere_visible:
            button.blockSignals(True)
            try:
                button.setChecked(sphere_visible)
            finally:
                button.blockSignals(False)

    if self.current_sphere_id not in self.spheres:
        sync_button()
        return
    sphere = self.spheres[self.current_sphere_id]
    sphere_actor = sphere.get('sphere_actor')
    sphere['sphere_visible'] = sphere_visible
    if sphere_visible:
        if sphere_actor is None:
            center_x = float(self.center_x_input.text())
            center_y = float(self.center_y_input.text())
            center_z = float(self.center_z_input.text())
            radius = float(self.radius_input.text())
            move_callback = (self._on_sphere_constrained_move if self.lock_z_checkbox.isChecked()
                             else self._on_sphere_moved)
            sphere['sphere_actor'] = self.add_sphere_widget(move_callback,
                                                            center=(center_x, center_y, center_z),
                                                            radius=radius, theta_resolution=18,
                                                            phi_resolution=18, style='wireframe')
        else:
            sphere_actor.On()
    elif sphere_actor is not None:
        sphere_actor.Off()
    sync_button()
    self._persist_line_seeds()
def _on_sphere_moved(self, center):
"""
Handles the event when the sphere widget is moved.
:param center: list of float
The new center coordinates of the sphere.
"""
self.center_x_input.setText(f"{center[0]:.2f}")
self.center_y_input.setText(f"{center[1]:.2f}")
self.center_z_input.setText(f"{center[2]:.2f}")
self.update_sphere()
def _on_sphere_constrained_move(self, center):
"""
Moves the sphere in the plane z = center_z_input when 'Lock Z' is checked.
:param center: list of float
The new center coordinates of the sphere.
"""
fixed_z = float(self.center_z_input.text())
# Update the sphere's position but constrain the Z coordinate to fixed_z
new_sphere_pos = [center[0], center[1], fixed_z]
# Update the sphere actor position
if self.spheres[self.current_sphere_id]['sphere_actor'] is not None:
self.spheres[self.current_sphere_id]['sphere_actor'].SetCenter(new_sphere_pos)
# Update the input fields
self.center_x_input.setText(f"{center[0]:.2f}")
self.center_y_input.setText(f"{center[1]:.2f}")
self.update_sphere()
[docs]
def toggle_sphere_visibility(self, state):
    """
    Toggles the visibility of the sphere widget.

    Updates the toggle button's tooltip from its checked state, records the
    requested visibility in ``self.sphere_visible`` and redraws when at least
    one sphere exists.

    :param state: bool or int
        Either the ``checked`` flag delivered by a button's ``toggled(bool)``
        signal (this slot is connected to ``viz_sphere_button.toggled``), or a
        ``Qt.CheckState`` value as delivered by ``QCheckBox.stateChanged``.
    """
    if self.viz_sphere_button.isChecked():
        self.viz_sphere_button.setToolTip("Hide the sphere")
    else:
        self.viz_sphere_button.setToolTip("Show the sphere")
    # A bool from ``toggled`` never equals Qt.Checked (== 2), so it must be
    # taken as-is; ints keep the original CheckState comparison.
    if isinstance(state, bool):
        self.sphere_visible = state
    else:
        self.sphere_visible = state == Qt.Checked
    if len(self.spheres) > 0:
        self.update_plot()
[docs]
def update_plane(self):
    """
    Re-synchronizes an existing plane widget with the slice-position input.

    Does nothing when no plane widget exists. Otherwise the position is read
    from ``slice_z_input`` (``value()`` for a QDoubleSpinBox, ``text()``
    otherwise) and stored for the current slice axis. The widget is
    re-oriented when it supports ``SetNormal`` and moved to the new origin,
    and the plot is redrawn.
    """
    if self.plane_actor is None:
        return
    position_widget = self.slice_z_input
    if isinstance(position_widget, QDoubleSpinBox):
        position = float(position_widget.value())
    else:
        position = float(position_widget.text())
    self.slice_axis_positions[self.slice_axis] = position
    new_origin = self._slice_origin(position)
    if hasattr(self.plane_actor, "SetNormal"):
        self.plane_actor.SetNormal(self._slice_normal_vector())
    self.plane_actor.SetOrigin(new_origin)
    self.update_plot()
[docs]
def update_plane_visibility(self, plane_visible):
    """
    Shows or hides the slice-plane widget.

    Hiding only switches an existing widget Off. Showing creates the widget on
    first use: it is bounded by the model grid, normal rotation is disabled,
    and the slice position is read from ``slice_z_input``. An existing widget
    is instead re-oriented (when it supports ``SetNormal``), moved to
    ``_slice_origin()`` and switched On.

    :param plane_visible: bool
        Whether the plane widget should be visible.
    """
    if not plane_visible:
        if self.plane_actor is not None:
            self.plane_actor.Off()
        return

    if self.plane_actor is not None:
        # NOTE(review): unlike the creation path, this does not re-read
        # slice_z_input; it relies on _slice_origin()'s own default position.
        if hasattr(self.plane_actor, "SetNormal"):
            self.plane_actor.SetNormal(self._slice_normal_vector())
        self.plane_actor.SetOrigin(self._slice_origin())
        self.plane_actor.On()
        return

    position_widget = self.slice_z_input
    if isinstance(position_widget, QDoubleSpinBox):
        position = float(position_widget.value())
    else:
        position = float(position_widget.text())
    self.slice_axis_positions[self.slice_axis] = position
    grid_bounds = (
        self.grid_xmin, self.grid_xmax,
        self.grid_ymin, self.grid_ymax,
        self.grid_zmin, self.grid_zmax,
    )
    self.plane_actor = self.add_plane_widget(
        self._on_plane_moved,
        normal=self._slice_normal_vector(),
        origin=self._slice_origin(position),
        bounds=grid_bounds,
        normal_rotation=False,
    )
def _on_plane_moved(self, normal, origin):
    """
    Callback for a drag of the plane widget.

    Picks the component of ``origin`` that matches the active slice axis,
    stores it as that axis' slice position and mirrors it into
    ``slice_z_input`` (``setValue`` for a QDoubleSpinBox, two-decimal text
    otherwise). ``update_plane`` is then called.

    :param normal: list of float
        The normal vector of the plane (not used here).
    :param origin: list of float
        The new origin coordinates of the plane.
    """
    axis_to_index = {'x': 0, 'y': 1, 'z': 2}
    coordinate = float(origin[axis_to_index[self.slice_axis]])
    self.slice_axis_positions[self.slice_axis] = coordinate
    position_widget = self.slice_z_input
    if not isinstance(position_widget, QDoubleSpinBox):
        position_widget.setText(f"{coordinate:.2f}")
    else:
        position_widget.setValue(coordinate)
    self.update_plane()
[docs]
def toggle_plane_visibility(self, state):
    """
    Toggles the visibility of the plane widget and redraws the plot.

    :param state: bool or int
        Either the ``checked`` flag from a button's ``toggled(bool)`` signal or
        a ``Qt.CheckState`` value from ``QCheckBox.stateChanged``.
    """
    # A bool never equals Qt.Checked (== 2), so take it as-is; ints keep the
    # CheckState comparison.
    if isinstance(state, bool):
        self.plane_visible = state
    else:
        self.plane_visible = state == Qt.Checked
    self.update_plot()
[docs]
def toggle_slice_visibility(self, state):
    """
    Toggles the visibility of the slice actor and redraws the plot.

    :param state: bool or int
        Either the ``checked`` flag from a button's ``toggled(bool)`` signal or
        a ``Qt.CheckState`` value from ``QCheckBox.stateChanged``.
    """
    # A bool never equals Qt.Checked (== 2), so take it as-is; ints keep the
    # CheckState comparison.
    if isinstance(state, bool):
        self.slice_visible = state
    else:
        self.slice_visible = state == Qt.Checked
    self.update_plot()
[docs]
def toggle_base_map_visibility(self, state):
    """
    Shows or hides the base map and re-renders.

    The currently selected map name (``"none"`` without a selector) and the
    vmin/vmax inputs are passed to ``update_base_map``; the defaults are
    -1000.0 / 1000.0 when those inputs are missing. The visibility flag and map
    name are recorded in ``previous_params``.

    :param state: bool or int
        Either the ``checked`` flag from a button's ``toggled(bool)`` signal or
        a ``Qt.CheckState`` value from ``QCheckBox.stateChanged``.
    """
    # A bool never equals Qt.Checked (== 2), so take it as-is; ints keep the
    # CheckState comparison.
    if isinstance(state, bool):
        self.base_map_visible = state
    else:
        self.base_map_visible = state == Qt.Checked
    base_map = self.base_map_selector.currentText() if self.base_map_selector is not None else "none"
    vmin = float(self.base_vmin_input.value()) if self.base_vmin_input is not None else -1000.0
    vmax = float(self.base_vmax_input.value()) if self.base_vmax_input is not None else 1000.0
    self.update_base_map(base_map, vmin, vmax, self.base_map_visible)
    self.previous_params["base_map_visible"] = self.base_map_visible
    self.previous_params["base_map"] = base_map
    self.reset_camera_clipping_range()
    self.render()
[docs]
def toggle_model_box_visibility(self, state):
    """
    Shows or hides the model box outline and re-renders.

    :param state: bool or int
        Either the ``checked`` flag from a button's ``toggled(bool)`` signal or
        a ``Qt.CheckState`` value from ``QCheckBox.stateChanged``.
    """
    # A bool never equals Qt.Checked (== 2), so take it as-is; ints keep the
    # CheckState comparison.
    if isinstance(state, bool):
        self.model_box_visible = state
    else:
        self.model_box_visible = state == Qt.Checked
    self.update_model_box(self.model_box_visible)
    self.previous_params["model_box_visible"] = self.model_box_visible
    self.reset_camera_clipping_range()
    self.render()
[docs]
def toggle_fov_box_visibility(self, state):
    """
    Shows or hides the field-of-view box outline and re-renders.

    :param state: bool or int
        Either the ``checked`` flag from a button's ``toggled(bool)`` signal or
        a ``Qt.CheckState`` value from ``QCheckBox.stateChanged``.
    """
    # A bool never equals Qt.Checked (== 2), so take it as-is; ints keep the
    # CheckState comparison.
    if isinstance(state, bool):
        self.fov_box_visible = state
    else:
        self.fov_box_visible = state == Qt.Checked
    self.update_fov_box(self.fov_box_visible)
    self.previous_params["fov_box_visible"] = self.fov_box_visible
    self.reset_camera_clipping_range()
    self.render()
[docs]
def send_streamlines(self):
    """
    Sends the non-empty streamlines of all spheres to the parent object (if any).

    Streamlines are gathered with ``_collect_streamlines``. Spheres whose
    ``streamlines_actor`` or ``streamlines`` entry is missing or ``None`` are
    therefore skipped instead of raising. ``parent.plot_fieldlines`` is only
    called when at least one non-empty streamline set exists.
    """
    print(f"Sending streamlines to {self.parent}")
    if self.parent is None:
        return
    streamlines = self._collect_streamlines()
    if streamlines:
        self.parent.plot_fieldlines(streamlines, z_base=self.grid_zbase)
def _collect_streamlines(self):
streamlines = []
for sphere in self.spheres.values():
if sphere.get('streamlines_actor') is not None and sphere.get('streamlines') is not None:
if sphere['streamlines'].n_lines > 0:
streamlines.append(sphere['streamlines'])
return streamlines
def _collect_line_seeds_snapshot(self):
line_seeds = self.box.b3d.get("line_seeds")
if isinstance(line_seeds, dict):
return copy.deepcopy(line_seeds)
return None
[docs]
def save_current_model(self):
    """
    Write the current line-seed state back into the attached model file.

    Embedded sessions never write and return False. Without a model path a
    warning dialog is shown. If the file already exists the user must confirm
    the overwrite. On success the written snapshot also becomes the new
    ``_original_line_seeds`` baseline.

    :returns: bool
        True when the seeds were written; False when skipped, declined or
        failed. A failure is also reported in a warning dialog.
    """
    if self.session_mode == "embedded":
        return False
    if not self.model_path:
        QMessageBox.warning(self.app_window, "Save Failed", "No writable .h5 model path is attached to this 3D viewer.")
        return False

    # Ask before touching a file that already exists on disk.
    if Path(self.model_path).exists():
        answer = QMessageBox.question(
            self.app_window,
            "Overwrite Model",
            f"Overwrite existing model file?\n{self.model_path}",
            QMessageBox.Yes | QMessageBox.No,
            QMessageBox.No,
        )
        if answer != QMessageBox.Yes:
            return False

    try:
        update_line_seeds_h5(str(self.model_path), self._collect_line_seeds_snapshot())
        self._original_line_seeds = self._collect_line_seeds_snapshot()
        print(f"Saved line seeds to {self.model_path}")
        return True
    except Exception as err:
        QMessageBox.warning(self.app_window, "Save Failed", f"Could not save line seeds to the current model:\n{err}")
        return False
def _close_window(self):
if hasattr(self, "app_window"):
self.app_window.close()
else:
self.close()
[docs]
def accept_and_close(self):
    """
    Accept the current 3D edits and close the viewer.

    There are three cases:

    * Embedded session whose parent offers ``commit_live_3d_edits``: the line
      seeds and streamlines are handed to the parent, the close mode is
      recorded as ``"accept"``, and the window closes.
    * ``"pipeline_child"`` session: the window closes only if
      ``save_current_model`` succeeds.
    * Anything else: the window simply closes.
    """
    commit_to_parent = (
        self.session_mode == "embedded"
        and self.parent is not None
        and hasattr(self.parent, "commit_live_3d_edits")
    )
    if commit_to_parent:
        self.parent.commit_live_3d_edits(
            self._collect_line_seeds_snapshot(),
            self._collect_streamlines(),
            z_base=self.grid_zbase,
        )
        self._embedded_close_mode = "accept"
    elif self.session_mode == "pipeline_child" and not self.save_current_model():
        # Save was declined or failed: keep the window open.
        return
    self._close_window()
[docs]
def cancel_and_close(self):
    """
    Discard the current 3D edits and close the viewer.

    In an embedded session whose parent offers ``cancel_live_3d_edits``, the
    parent is told to cancel first. The close mode is always recorded as
    ``"cancel"`` before the window closes.
    """
    parent = self.parent
    in_embedded_session = self.session_mode == "embedded"
    if in_embedded_session and parent is not None and hasattr(parent, "cancel_live_3d_edits"):
        parent.cancel_live_3d_edits()
    self._embedded_close_mode = "cancel"
    self._close_window()
[docs]
def undo_and_restore(self):
    """
    Restore the line seeds to the ``_original_line_seeds`` baseline.

    The baseline is set on model load and after a successful save. An empty
    dict is restored when no dict baseline exists.
    """
    baseline = self._original_line_seeds
    if not isinstance(baseline, dict):
        baseline = {}
    self._restore_line_seeds(baseline)
[docs]
def save_box(self):
    """
    Save the whole model (``self.box.b3d``) to a user-chosen HDF5 file.

    The suggested file name reuses the stem of the attached model (or the
    source model it was converted from) and falls back to ``b3d_data.h5``. The
    dialog starts in the current working directory, and existing files require
    confirmation.

    After a successful write, the new file becomes ``self.model_path`` and a
    disabled "save model" button is enabled and wired to
    ``save_current_model``. Write errors are shown in a warning dialog instead
    of escaping the Qt slot.
    """
    suggested_stem = "b3d_data"
    candidate = self.model_path if self.model_path is not None else self.source_model_path
    if candidate is not None:
        try:
            suggested_stem = Path(candidate).expanduser().stem
        except (TypeError, ValueError, RuntimeError):
            # Best effort only: an unusable path keeps the generic default name.
            pass
    default_path = Path.cwd() / f"{suggested_stem}.h5"
    parent_widget = self.app_window if hasattr(self, "app_window") else None
    filename = QFileDialog.getSaveFileName(
        parent_widget,
        "Save Box As",
        str(default_path),
        "HDF5 Files (*.h5)",
    )[0]
    if not filename:
        return
    # Ask before overwriting an existing file.
    if Path(filename).exists():
        btn = QMessageBox.question(
            parent_widget,
            "File Exists",
            f"File already exists:\n{filename}\n\nOverwrite?",
            QMessageBox.Yes | QMessageBox.No,
            QMessageBox.No,
        )
        if btn != QMessageBox.Yes:
            return
    try:
        write_b3d_h5(filename, self.box.b3d)
    except Exception as exc:
        # An exception escaping a PyQt5 slot aborts the application, so
        # report the failure and leave the attached model path unchanged.
        QMessageBox.warning(parent_widget, "Save Failed", f"Could not save model to:\n{filename}\n\n{exc}")
        return
    self.model_path = Path(filename)
    if self.save_model_button is not None and not self.save_model_button.isEnabled():
        self.save_model_button.setEnabled(True)
        self.save_model_button.setToolTip("Save the current seed state back into the opened model file.")
        self.save_model_button.clicked.connect(self.save_current_model)
[docs]
def load_box(self):
default_filename = "b3d_data.h5"
filename = QFileDialog.getOpenFileName(self, "Load Box", default_filename, "Model Files (*.h5 *.sav)")[0]
if not filename:
return
try:
from pyampp.gxbox.view_h5 import prepare_model_for_viewer
box, obs_time, b3dtype, temp_h5_path = prepare_model_for_viewer(filename)
except Exception as exc:
QMessageBox.critical(self.app_window, "Load Failed", f"Could not read model file:\n{exc}")
return
model_path = Path(filename).expanduser().resolve()
self.box = box
self.b3dtype = b3dtype
self.model_path = model_path if temp_h5_path is None else None
self.source_model_path = model_path
self.timestr = obs_time.to_datetime().strftime("_%Y%m%dT%H%M%S") if obs_time is not None else ""
self._on_clear_spheres()
self.init_grid()
self._apply_streamline_control_state()
self._original_line_seeds = copy.deepcopy(self.box.b3d.get("line_seeds")) if isinstance(self.box.b3d.get("line_seeds"), dict) else None
self._restore_line_seeds_from_box()
self.previous_params = {}
self.update_plot()