Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions radiacode/bytes_buffer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import struct
from typing import Any


class BytesBuffer:
Expand All @@ -13,7 +14,7 @@ class BytesBuffer:
data (bytes): The binary data to read from.
"""

def __init__(self, data: bytes):
def __init__(self, data: bytes) -> None:
"""Initialize the BytesBuffer with binary data.

Args:
Expand All @@ -38,7 +39,7 @@ def data(self) -> bytes:
"""
return self._data[self._pos :]

def unpack(self, fmt: str) -> tuple:
def unpack(self, fmt: str) -> tuple[Any, ...]:
"""Unpack binary data according to the given format string.

Uses the struct module's format syntax to unpack binary data from the
Expand All @@ -54,7 +55,7 @@ def unpack(self, fmt: str) -> tuple:
Raises:
Exception: If there isn't enough data remaining in the buffer for the requested format.
"""
sz = struct.calcsize(fmt)
sz: int = struct.calcsize(fmt)
if self._pos + sz > len(self._data):
raise ValueError(f'BytesBuffer: {sz} bytes required for {fmt}, but have only {len(self._data) - self._pos}')
self._pos += sz
Expand All @@ -73,5 +74,5 @@ def unpack_string(self) -> str:
Exception: If there isn't enough data in the buffer.
UnicodeDecodeError: If the string data cannot be decoded as ASCII.
"""
slen = self.unpack('<B')[0]
slen: int = self.unpack('<B')[0]
return self.unpack(f'<{slen}s')[0].decode('ascii')
2 changes: 1 addition & 1 deletion radiacode/decoders/databuf.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def decode_VS_DATA_BUF(
next_seq = None
while br.size() >= 7:
seq, eid, gid, ts_offset = br.unpack('<BBBi')
dt = base_time + datetime.timedelta(milliseconds=ts_offset * 10)
dt: datetime.datetime = base_time + datetime.timedelta(milliseconds=ts_offset * 10)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it required by Pyrefly to annotate the types of all variables?
In my opinion, annotations like nvsfr: int = len(vsfr_ids) or target_date: str = r.unpack_string() make the code feel overloaded - these types are trivial for a reader to infer and should also be easy for a static analyzer to deduce.
I do like annotations such as ret: list[int] = [] and annotations in function arguments.

Even if these annotations are mandatory for Pyrefly, I’m open to including them in this repo if you find them useful.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there are levels of checking and annotation you can tune; there were some variables that didn't appear to need type annoations. 🤷🏻

I agree that some things are trivial to infer - len(thing) -> int - and that function argument and return value annotations are useful. It seems like the philosophy of behind pyre and pyrefly is to annotate everything (handy when your IDE can complain when you stuff the wrong data type into a variable) since python isn't strictly typed, so requiring type annotations is a way to make you be more careful.

I think you can close this PR, or just keep it for discussion about type checking and annotations in general; I'll do a cleaner one to ensure that as many functions as possible have annotated arguments and return values, at least where it's not ugly.

Copy link

@yangdanny97 yangdanny97 May 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pyrefly infers and displays types for local variables mostly as a development aid, they don't have to be added to the code unless it helps with clarity.

It's probably a good idea to add annotations for function params and returns though, and Pyrefly will help infer the latter.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pyrefly also doesn't guess the annotation properly depending on how you import something (IMO).

import datetime

now = datetime.datetime.now()

the inferred type of now is datetime, but then datetime.now doesn't exist. If was to say from datetime import datetime then now:datetime = datetime.now() would do what I want, but since I said import datetime I'd have expected pyrefly to infer now:datetime.datetime.

if next_seq is not None and next_seq != seq:
if not ignore_errors:
print(f'seq jump while processing {eid=} {gid=}, expect:{next_seq}, got:{seq} {br.size()=}')
Expand Down
13 changes: 7 additions & 6 deletions radiacode/decoders/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,20 @@


def decode_counts_v0(br: BytesBuffer) -> list[int]:
ret = []
ret: list[int] = []
while br.size() > 0:
ret.append(br.unpack('<I')[0])
return ret


def decode_counts_v1(br: BytesBuffer) -> list[int]:
ret = []
ret: list[int] = []
last = 0
while br.size() > 0:
u16 = br.unpack('<H')[0]
cnt = (u16 >> 4) & 0x0FFF
vlen = u16 & 0x0F
u16: int = br.unpack('<H')[0]
cnt: int = (u16 >> 4) & 0x0FFF
vlen: int = u16 & 0x0F
v: int = 0
for _ in range(cnt):
if vlen == 0:
v = 0
Expand All @@ -44,7 +45,7 @@ def decode_RC_VS_SPECTRUM(br: BytesBuffer, format_version: int) -> Spectrum:
ts, a0, a1, a2 = br.unpack('<Ifff')

assert format_version in {0, 1}, f'unspported format_version={format_version}'
counts = decode_counts_v0(br) if format_version == 0 else decode_counts_v1(br)
counts: list[int] = decode_counts_v0(br) if format_version == 0 else decode_counts_v1(br)

return Spectrum(
duration=datetime.timedelta(seconds=ts),
Expand Down
110 changes: 55 additions & 55 deletions radiacode/radiacode.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,20 +110,20 @@ def base_time(self) -> datetime.datetime:
return self._base_time

def execute(self, reqtype: COMMAND, args: Optional[bytes] = None) -> BytesBuffer:
req_seq_no = 0x80 + self._seq
req_seq_no: int = 0x80 + self._seq
self._seq = (self._seq + 1) % 32

req_header = struct.pack('<HBB', int(reqtype), 0, req_seq_no)
request = req_header + (args or b'')
full_request = struct.pack('<I', len(request)) + request
req_header: bytes = struct.pack('<HBB', int(reqtype), 0, req_seq_no)
request: bytes = req_header + (args or b'')
full_request: bytes = struct.pack('<I', len(request)) + request

response = self._connection.execute(full_request)
resp_header = response.unpack('<4s')[0]
assert req_header == resp_header, f'req={req_header.hex()} resp={resp_header.hex()}'
return response

def read_request(self, command_id: int | VS | VSFR) -> BytesBuffer:
r = self.execute(COMMAND.RD_VIRT_STRING, struct.pack('<I', int(command_id)))
r: BytesBuffer = self.execute(COMMAND.RD_VIRT_STRING, struct.pack('<I', int(command_id)))
retcode, flen = r.unpack('<II')
assert retcode == 1, f'{command_id}: got retcode {retcode}'
# HACK: workaround for new firmware bug(?)
Expand All @@ -134,8 +134,8 @@ def read_request(self, command_id: int | VS | VSFR) -> BytesBuffer:
return r

def write_request(self, command_id: int | VSFR, data: Optional[bytes] = None) -> None:
r = self.execute(COMMAND.WR_VIRT_SFR, struct.pack('<I', int(command_id)) + (data or b''))
retcode = r.unpack('<I')[0]
r: BytesBuffer = self.execute(COMMAND.WR_VIRT_SFR, struct.pack('<I', int(command_id)) + (data or b''))
retcode: int = r.unpack('<I')[0]
assert retcode == 1
assert r.size() == 0

Expand All @@ -153,7 +153,7 @@ def batch_read_vsfrs(self, vsfr_ids: list[VSFR], unpack_format: str) -> list[int
Repeat count is not supported (use "ffff" instead of "4f") as the length
of the unpack_format string must equal the number of VSFRs being fetched.
"""
nvsfr = len(vsfr_ids)
nvsfr: int = len(vsfr_ids)
if nvsfr == 0:
raise ValueError('No VSFRs specified')

Expand All @@ -164,23 +164,23 @@ def batch_read_vsfrs(self, vsfr_ids: list[VSFR], unpack_format: str) -> list[int
if not (isinstance(unpack_format, str) and len(unpack_format) == nvsfr):
raise ValueError(f'invalid unpack_format `{unpack_format}`')

msg = [struct.pack('<I', nvsfr)]
msg: list[bytes] = [struct.pack('<I', nvsfr)]
msg.extend([struct.pack('<I', int(c)) for c in vsfr_ids])
r = self.execute(COMMAND.RD_VIRT_SFR_BATCH, b''.join(msg))
r: BytesBuffer = self.execute(COMMAND.RD_VIRT_SFR_BATCH, b''.join(msg))

valid_flags = r.unpack('<I')[0]
expected_flags = (1 << nvsfr) - 1
valid_flags: int = r.unpack('<I')[0]
expected_flags: int = (1 << nvsfr) - 1
if valid_flags != expected_flags:
raise ValueError(f'Unexpected validity flags, bad vsfr_id? {valid_flags:08b} != {expected_flags:08b}')

ret = [r.unpack(f'<{unpack_format[i]}')[0] for i in range(nvsfr)]
ret: list[int | float] = [r.unpack(f'<{unpack_format[i]}')[0] for i in range(nvsfr)]

assert r.size() == 0
return ret

def status(self) -> str:
r = self.execute(COMMAND.GET_STATUS)
flags = r.unpack('<I')
r: BytesBuffer = self.execute(COMMAND.GET_STATUS)
flags: int = r.unpack('<I')[0]
assert r.size() == 0
return f'status flags: {flags}'

Expand All @@ -192,14 +192,14 @@ def set_local_time(self, dt: datetime.datetime) -> None:
The time components used are: year, month, day, hour, minute, second.
Microseconds are ignored.
"""
d = struct.pack('<BBBBBBBB', dt.day, dt.month, dt.year - 2000, 0, dt.second, dt.minute, dt.hour, 0)
d: bytes = struct.pack('<BBBBBBBB', dt.day, dt.month, dt.year - 2000, 0, dt.second, dt.minute, dt.hour, 0)
self.execute(COMMAND.SET_TIME, d)

def fw_signature(self) -> str:
r = self.execute(COMMAND.FW_SIGNATURE)
signature = r.unpack('<I')[0]
filename = r.unpack_string()
idstring = r.unpack_string()
r: BytesBuffer = self.execute(COMMAND.FW_SIGNATURE)
signature: int = r.unpack('<I')[0]
filename: str = r.unpack_string()
idstring: str = r.unpack_string()
return f'Signature: {signature:08X}, FileName="{filename}", IdString="{idstring}"'

def fw_version(self) -> tuple[tuple[int, int, str], tuple[int, int, str]]:
Expand All @@ -210,11 +210,11 @@ def fw_version(self) -> tuple[tuple[int, int, str], tuple[int, int, str]]:
- Boot version: (major, minor, date string)
- Target version: (major, minor, date string)
"""
r = self.execute(COMMAND.GET_VERSION)
r: BytesBuffer = self.execute(COMMAND.GET_VERSION)
boot_minor, boot_major = r.unpack('<HH')
boot_date = r.unpack_string()
boot_date: str = r.unpack_string()
target_minor, target_major = r.unpack('<HH')
target_date = r.unpack_string()
target_date: str = r.unpack_string()
assert r.size() == 0
return ((boot_major, boot_minor, boot_date), (target_major, target_minor, target_date.strip('\x00')))

Expand All @@ -225,19 +225,19 @@ def hw_serial_number(self) -> str:
str: Hardware serial number formatted as hyphen-separated hexadecimal groups
(e.g. "12345678-9ABCDEF0")
"""
r = self.execute(COMMAND.GET_SERIAL)
serial_len = r.unpack('<I')[0]
r: BytesBuffer = self.execute(COMMAND.GET_SERIAL)
serial_len: int = r.unpack('<I')[0]
assert serial_len % 4 == 0
serial_groups = [r.unpack('<I')[0] for _ in range(serial_len // 4)]
serial_groups: list[int] = [r.unpack('<I')[0] for _ in range(serial_len // 4)]
assert r.size() == 0
return '-'.join(f'{v:08X}' for v in serial_groups)

def configuration(self) -> str:
r = self.read_request(VS.CONFIGURATION)
r: BytesBuffer = self.read_request(VS.CONFIGURATION)
return r.data().decode('cp1251')

def text_message(self) -> str:
r = self.read_request(VS.TEXT_MESSAGE)
r: BytesBuffer = self.read_request(VS.TEXT_MESSAGE)
return r.data().decode('ascii')

def serial_number(self) -> str:
Expand All @@ -246,11 +246,11 @@ def serial_number(self) -> str:
Returns:
str: The device serial number as an ASCII string
"""
r = self.read_request(VS.SERIAL_NUMBER)
r: BytesBuffer = self.read_request(VS.SERIAL_NUMBER)
return r.data().decode('ascii')

def commands(self) -> str:
br = self.read_request(VS.SFR_FILE)
br: BytesBuffer = self.read_request(VS.SFR_FILE)
return br.data().decode('ascii')

# called with 0 after init!
Expand All @@ -264,7 +264,7 @@ def device_time(self, v: int) -> None:

def data_buf(self) -> list[DoseRateDB | RareData | RealTimeData | RawData | Event]:
"""Get buffered measurement data from the device."""
r = self.read_request(VS.DATA_BUF)
r: BytesBuffer = self.read_request(VS.DATA_BUF)
return decode_VS_DATA_BUF(r, self._base_time)

def spectrum(self) -> Spectrum:
Expand All @@ -273,7 +273,7 @@ def spectrum(self) -> Spectrum:
Returns:
Spectrum: Object containing the current spectrum data
"""
r = self.read_request(VS.SPECTRUM)
r: BytesBuffer = self.read_request(VS.SPECTRUM)
return decode_RC_VS_SPECTRUM(r, self._spectrum_format_version)

def spectrum_accum(self) -> Spectrum:
Expand All @@ -282,7 +282,7 @@ def spectrum_accum(self) -> Spectrum:
Returns:
Spectrum: Object containing the accumulated spectrum data
"""
r = self.read_request(VS.SPEC_ACCUM)
r: BytesBuffer = self.read_request(VS.SPEC_ACCUM)
return decode_RC_VS_SPECTRUM(r, self._spectrum_format_version)

def dose_reset(self) -> None:
Expand All @@ -298,8 +298,8 @@ def spectrum_reset(self) -> None:
This clears the current spectrum data buffer, effectively resetting the spectrum
measurement to start fresh.
"""
r = self.execute(COMMAND.WR_VIRT_STRING, struct.pack('<II', int(VS.SPECTRUM), 0))
retcode = r.unpack('<I')[0]
r: BytesBuffer = self.execute(COMMAND.WR_VIRT_STRING, struct.pack('<II', int(VS.SPECTRUM), 0))
retcode: int = r.unpack('<I')[0]
assert retcode == 1
assert r.size() == 0

Expand All @@ -312,7 +312,7 @@ def energy_calib(self) -> list[float]:
- a1: Linear term coefficient (keV/channel)
- a2: Quadratic term coefficient (keV/channel^2)
"""
r = self.read_request(VS.ENERGY_CALIB)
r: BytesBuffer = self.read_request(VS.ENERGY_CALIB)
return list(r.unpack('<fff'))

def set_energy_calib(self, coef: list[float]) -> None:
Expand All @@ -325,12 +325,12 @@ def set_energy_calib(self, coef: list[float]) -> None:
- a2: Quadratic term coefficient (keV/channel^2)
"""
assert len(coef) == 3
pc = struct.pack('<fff', *coef)
r = self.execute(COMMAND.WR_VIRT_STRING, struct.pack('<II', int(VS.ENERGY_CALIB), len(pc)) + pc)
retcode = r.unpack('<I')[0]
pc: bytes = struct.pack('<fff', *coef)
r: BytesBuffer = self.execute(COMMAND.WR_VIRT_STRING, struct.pack('<II', int(VS.ENERGY_CALIB), len(pc)) + pc)
retcode: int = r.unpack('<I')[0]
assert retcode == 1

def set_language(self, lang='ru') -> None:
def set_language(self, lang: str = 'ru') -> None:
"""Set the device interface language.

Args:
Expand Down Expand Up @@ -370,7 +370,7 @@ def set_sound_ctrl(self, ctrls: list[CTRL]) -> None:
Args:
ctrls: List of CTRL enum values specifying which events should trigger sounds
"""
flags = 0
flags: int = 0
for c in ctrls:
flags |= int(c)
self.write_request(VSFR.SOUND_CTRL, struct.pack('<I', flags))
Expand All @@ -383,7 +383,7 @@ def set_display_off_time(self, seconds: int) -> None:
Must be one of: 5, 10, 15, or 30 seconds.
"""
assert seconds in {5, 10, 15, 30}
v = 3 if seconds == 30 else (seconds // 5) - 1
v: int = 3 if seconds == 30 else (seconds // 5) - 1
self.write_request(VSFR.DISP_OFF_TIME, struct.pack('<I', v))

def set_display_brightness(self, brightness: int) -> None:
Expand Down Expand Up @@ -419,7 +419,7 @@ def set_vibro_ctrl(self, ctrls: list[CTRL]) -> None:

def get_alarm_limits(self) -> AlarmLimits:
"Retrieve the alarm limits"
regs = [
regs: list[VSFR] = [
VSFR.CR_LEV1_cp10s,
VSFR.CR_LEV2_cp10s,
VSFR.DR_LEV1_uR_h,
Expand All @@ -430,10 +430,10 @@ def get_alarm_limits(self) -> AlarmLimits:
VSFR.CR_UNITS,
]

resp = self.batch_read_vsfrs(regs, 'I' * len(regs))
resp: list[int] = self.batch_read_vsfrs(regs, 'I' * len(regs))

dose_multiplier = 100 if resp[6] else 1
count_multiplier = 60 if resp[7] else 1
dose_multiplier: int = 100 if resp[6] else 1
count_multiplier: int = 60 if resp[7] else 1
return AlarmLimits(
l1_count_rate=resp[0] / 10 * count_multiplier,
l2_count_rate=resp[1] / 10 * count_multiplier,
Expand Down Expand Up @@ -483,12 +483,12 @@ def set_alarm_limits(
unit will be set to Sv.
"""

which_limits = []
limit_values = []
which_limits: list[VSFR] = []
limit_values: list[int] = []

dose_multiplier = 100 if dose_unit_sv is True else 1
dose_multiplier: int = 100 if dose_unit_sv is True else 1
if isinstance(count_unit_cpm, bool):
count_multiplier = 1 / 6 if count_unit_cpm else 10
count_multiplier: float = 1 / 6 if count_unit_cpm else 10
else:
count_multiplier = 1

Expand Down Expand Up @@ -536,12 +536,12 @@ def set_alarm_limits(
which_limits.append(VSFR.CR_UNITS)
limit_values.append(int(count_unit_cpm))

num_to_set = len(which_limits)
num_to_set: int = len(which_limits)
if not num_to_set:
raise ValueError('No limits specified')

pack_items = [num_to_set] + [int(x) for x in which_limits] + limit_values
pack_format = f'<I{num_to_set}I{num_to_set}I'
resp = self.execute(COMMAND.WR_VIRT_SFR_BATCH, struct.pack(pack_format, *pack_items))
expected_valid = (1 << len(which_limits)) - 1
pack_items: list[int] = [num_to_set] + [int(x) for x in which_limits] + limit_values
pack_format: str = f'<I{num_to_set}I{num_to_set}I'
resp: BytesBuffer = self.execute(COMMAND.WR_VIRT_SFR_BATCH, struct.pack(pack_format, *pack_items))
expected_valid: int = (1 << len(which_limits)) - 1
return expected_valid == resp.unpack('<I')[0]
Loading