diff --git a/pyproject.toml b/pyproject.toml index 96bf653..ec37efd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,6 +74,7 @@ markers = [ "network: marks tests requiring network access", "slow: marks other tests that cause bottlenecks", "hypothesis: tests that require hypothesis", + "hypothesis_dbf: hypothesis tests that test dbf functionality", ] python_files = "test_*.py *_test.py *_tests.py" diff --git a/src/shapefile.py b/src/shapefile.py index 89c52a4..2a047f6 100644 --- a/src/shapefile.py +++ b/src/shapefile.py @@ -632,6 +632,7 @@ def from_unchecked( # Raise Exception or trigger warning early, before user adds more fields # (fields are only written when first record added, and on close) + # Validates field_type, size and decimal. Name already validated and cached above. inst.encode_field_descriptor( encoding=encoding, encodingErrors=encodingErrors, diff --git a/tests/hypothesis_tests.py b/tests/hypothesis_tests.py index ea95e8b..cca90ca 100644 --- a/tests/hypothesis_tests.py +++ b/tests/hypothesis_tests.py @@ -5,6 +5,7 @@ import io import itertools import string +import warnings import pytest from hypothesis import HealthCheck, given, settings, reproduce_failure @@ -26,6 +27,16 @@ import shapefile as shp +@contextlib.contextmanager +def ignore_warnings(category=None): + with warnings.catch_warnings(): + if category: + warnings.simplefilter("ignore", category) + else: + warnings.simplefilter("ignore") + yield + + float_nums = floats(allow_nan=False, allow_infinity=False) xs = float_nums ys = float_nums @@ -542,7 +553,13 @@ def test_shx_reader_writer_roundtrip(codes_and_shapes)-> None: ENCODINGS = [ "ascii", + "latin1", "utf-8", + "utf-16-be", + "utf-16-le", + "utf-16", + "utf-32-be", + "utf-32-le", ] encodings = sampled_from(ENCODINGS) @@ -567,6 +584,12 @@ def _dbf_fields_strategy(draw, encoding: str) -> dict[str, str | int]: max_length = bounds_dict.get("max_length", 254) min_length = bounds_dict.get("min_length", 1) + if field_type in {"C", "M"}: + # Make 
sure the field is big enough to store the BOM + # emitted by codecs that do not specify endianness + # (e.g. utf-16 and utf-32) + min_length = max(min_length, len("".encode(encoding))) + max_length = max(max_length, min_length) max_decimal = bounds_dict.get("max_decimal", 0) size = draw(integers(min_value=min_length, max_value=max_length)) decimal = draw(integers(min_value=0, max_value=max(0,min(size - 3, max_decimal)))) @@ -581,7 +604,7 @@ def encodings_and_dbf_fields(draw): field = draw(fields_strategy) return encoding, field -def _get_fields_context(fields, codec, strict=False): +def _get_fields_w_context(fields, codec, strict=False): for field in fields: if (len(field["name"].encode(codec)) > 10 or "\x00" in field["name"] or @@ -592,14 +615,27 @@ def _get_fields_context(fields, codec, strict=False): return pytest.warns(shp.PossibleDataLoss), False return contextlib.nullcontext(), False +def _get_fields_r_context(codec): + # For utf-16/utf-32 codecs that are not explicitly big-endian (the -le + # variants and plain utf-16/utf-32, assumed little-endian here), many low code points encode to code units ending in null bytes, causing warnings in field + # names (which use trailing null bytes for padding). 
+ normalised = codec.lower().replace("-","").replace("_","") + if (normalised.startswith(("utf16", "utf32")) and + not normalised.endswith("be")): + + return ignore_warnings(shp.PossibleDataLoss) + return contextlib.nullcontext() + + @pytest.mark.hypothesis +@pytest.mark.hypothesis_dbf @settings(suppress_health_check=[HealthCheck.too_slow, HealthCheck.data_too_large]) @given(encoding_and_dbf_field=encodings_and_dbf_fields()) def test_dbf_Field_roundtrips(encoding_and_dbf_field: dict) -> None: encoding, field_kwargs = encoding_and_dbf_field - w_context, error_expected = _get_fields_context([field_kwargs], encoding, strict=True) + w_context, error_expected = _get_fields_w_context([field_kwargs], encoding, strict=True) with w_context: expected = shp.Field.from_unchecked( @@ -607,17 +643,19 @@ def test_dbf_Field_roundtrips(encoding_and_dbf_field: dict) -> None: strict=True, **field_kwargs, ) - encoded = expected.encode_field_descriptor(strict=True) + encoded = expected.encode_field_descriptor(encoding=encoding, strict=True) if error_expected: return stream = io.BytesIO() stream.write(encoded) stream.seek(0) - actual = shp.Field.from_byte_stream( - stream, - encoding=encoding, - ) + + with _get_fields_r_context(encoding): + actual = shp.Field.from_byte_stream( + stream, + encoding=encoding, + ) assert isinstance(actual, shp.Field) assert actual.name == expected.name @@ -753,13 +791,15 @@ def _write_fields_and_records_to_strict(w, fields, records): return written_fields, written_records @pytest.mark.hypothesis +@pytest.mark.hypothesis_dbf +@settings(suppress_health_check=[HealthCheck.too_slow, HealthCheck.data_too_large]) @given(codec_fields_and_records=dbf_encoding_fields_and_records()) def test_dbf_reader_writer_roundtrip(codec_fields_and_records)-> None: codec, fields, records = codec_fields_and_records stream = io.BytesIO() # pytest.raises and pytest.warns can obscure other - # exceptions inside them + # exceptions inside them, when 
iterating on the test code w = shp.DbfWriter(dbf=stream, encoding=codec, strict=True) written_fields, written_records = _write_fields_and_records_to_strict(w, fields, records) @@ -770,7 +810,7 @@ def test_dbf_reader_writer_roundtrip(codec_fields_and_records)-> None: w.close() - with shp.DbfReader(dbf=stream, encoding=codec) as r: + with _get_fields_r_context(codec), shp.DbfReader(dbf=stream, encoding=codec) as r: _assert_reader_matches_expected_fields(r, written_fields, True) _assert_reader_matches_expected_records(r, written_fields, written_records) @@ -786,6 +826,7 @@ def codes_codecs_fields_shapes_and_records(draw): @pytest.mark.hypothesis +@pytest.mark.hypothesis_dbf @settings(suppress_health_check=[HealthCheck.too_slow, HealthCheck.data_too_large]) @given(codes_codecs_fields_shapes_and_records=codes_codecs_fields_shapes_and_records()) def test_shapefile_reader_writer_roundtrip(codes_codecs_fields_shapes_and_records)-> None: @@ -813,7 +854,7 @@ def test_shapefile_reader_writer_roundtrip(codes_codecs_fields_shapes_and_record w.close() - with shp.Reader(encoding=encoding, **streams) as r: + with _get_fields_r_context(encoding), shp.Reader(encoding=encoding, **streams) as r: _assert_reader_matches_expected_fields(r, written_fields, True) _assert_reader_matches_expected_records(r, written_fields, written_records) - _assert_reader_matches_expected_shapes(r, code_ex, expected_shapes) \ No newline at end of file + _assert_reader_matches_expected_shapes(r, code_ex, expected_shapes)